#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import tripleblind as tb


# Start the session using example_user1's API token.
user1_token = tb.config.example_user1["token"]
tb.initialize(api_token=user1_token, example=True)

#############################################################################
# Find training datasets on the Router's index
prefix = "EXAMPLE - "

# (asset name, owning example-user config) for each training table; every
# table is looked up under a different example user's team.
training_tables = (
    (f"{prefix}Linear PSI Regression vertical train 0-40", tb.config.example_user1),
    (f"{prefix}Linear PSI Regression vertical train 41-100", tb.config.example_user2),
    (f"{prefix}Linear PSI Regression vertical train 101-120", tb.config.example_user3),
)

# Lookups run in the listed order, one find() call per table.
asset0, asset1, asset2 = (
    tb.TableAsset.find(table_name, owned_by=owner["team_id"])
    for table_name, owner in training_tables
)

# Stop here unless every lookup produced a truthy asset.
if not (asset0 and asset1 and asset2):
    raise SystemError("Datasets not found.")


# Preprocessing recipe: take all columns, then register "y" as the target.
preproc = (
    tb.TabularPreprocessor.builder()
    .all_columns()
    .add_column("y", target=True)
)

# Define job to train a PSI Vertical Regression model on the three datasets
# located above; "ID" is passed as the match column and "y" as the target.
training_options = {
    "datasets": [asset0, asset1, asset2],
    "match_column": "ID",
    "regression_type": tb.RegressionType.LINEAR,
    "target": "y",
    "preprocessor": preproc,
    "job_name": "PSI Vertical Regression Training",
}
result_asset = tb.regression_asset.PSIVerticalRegressionModel.train(
    **training_options
)

# Fetch the trained model to a local file, replacing any existing copy and
# showing progress while it is retrieved.
model_archive = "psi_vert_linear_reg_model.zip"
filename = result_asset.retrieve(
    save_as=model_archive, overwrite=True, show_progress=True
)
print(f"Model saved as {filename}")

# Save trained asset ID for later use in inference scripts
asset_id_file = "psi_vert_linear_reg_model_asset_id.out"
tb.util.save_to(asset_id_file, result_asset.uuid)
