#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import tripleblind as tb


# Initialize the SDK as example user 1, who will run the training job.
tb.initialize(api_token=tb.config.example_user1["token"], example=True)

# Locate the two training datasets by name, each restricted to the team that
# owns it (example users 2 and 3 respectively).
asset0 = tb.Asset.find(
    "EXAMPLE - Gene Training Data 0", owned_by=tb.config.example_user2["team_id"]
)
asset1 = tb.Asset.find(
    "EXAMPLE - Gene Training Data 1", owned_by=tb.config.example_user3["team_id"]
)
if not asset0 or not asset1:
    # SystemError is reserved for internal interpreter faults; RuntimeError is
    # the appropriate exception for a missing prerequisite. Name the missing
    # dataset(s) so setup problems are easy to diagnose.
    missing = [
        name
        for name, asset in (
            ("EXAMPLE - Gene Training Data 0", asset0),
            ("EXAMPLE - Gene Training Data 1", asset1),
        )
        if not asset
    ]
    raise RuntimeError(f"Datasets not found: {', '.join(missing)}")

# Configure the tabular preprocessor step by step: mark the "target" column as
# the training target, enable all_columns, and request float32 values. Each
# call is made on the object returned by the previous one, exactly as in a
# chained builder expression.
preproc = tb.TabularPreprocessor.builder()
preproc = preproc.add_column("target", target=True)
preproc = preproc.all_columns(True)
preproc = preproc.dtype("float32")

# Available regression algorithms (see the linked scikit-learn pages for their parameters):
# -- LinearRegression - https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html
# -- LogisticRegression - https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html
# -- Lasso - https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Lasso.html
# -- Ridge - https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Ridge.html
# Select the algorithm with the regression_type parameter and pass any of its
# optional parameters in model_params, as shown in the example below.

# Launch a distributed logistic-regression training job over both datasets.
# Estimator-specific options (here max_iter) are passed through model_params;
# test_size=0.1 presumably reserves 10% of the rows for evaluation.
result = tb.regression_asset.RegressionModel.train(
    job_name="Gene Logistic Regression Training",
    datasets=[asset0, asset1],
    preprocessor=preproc,
    target="target",
    regression_type=tb.RegressionType.LOGISTIC,
    model_params=dict(max_iter=100),
    test_size=0.1,
)

if result:
    print(f"Trained Asset ID: {result.uuid}")

    # Download a local copy of the trained model. retrieve() returns the name
    # of the downloaded file (printed below), which is reused when loading the
    # package instead of repeating the literal path.
    filename = result.retrieve("gene_model.zip", overwrite=True)

    # Also save the trained model's Asset ID to a file for later usage.
    tb.util.save_to("model_asset_id_distributed.out", result.uuid)
    print("Sklearn model has been downloaded as:")
    print(f"    {filename}")

    # Load the package from the exact file retrieve() reported, so the
    # download location and the load location cannot drift apart.
    pack = tb.Package.load(filename)
    model = pack.model()

    print("\nCoefficients:")
    print(model.coef_)

    # Create an agreement that allows one of the other teams (example user 2's
    # team) to use this trained model in subsequent steps.
    agreement = result.add_agreement(
        with_team=tb.config.example_user2["team_id"],
        operation=tb.Operation.EXECUTE,
        algorithm_security="smpc",
    )
    if agreement:
        print("Created Agreement for use of trained Asset.")
else:
    # Previously a falsy result ended the script silently; make it visible.
    print("Training did not return a trained Asset; no model was downloaded.")
