#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

from pathlib import Path

import tripleblind as tb


# Authenticate as example user 1, then look up the two patient datasets used
# below: the imperial-unit table owned by user 1's team and the metric-unit
# table owned by user 2's team.
tb.initialize(api_token=tb.config.example_user1["token"], example=True)

imperial_owner = tb.config.example_user1["team_id"]
metric_owner = tb.config.example_user2["team_id"]

asset_a = tb.TableAsset.find(
    "EXAMPLE - Patient Data (imperial)",
    owned_by=imperial_owner,
)
asset_b = tb.TableAsset.find(
    "EXAMPLE - Patient Data (metric)",
    owned_by=metric_owner,
)


# A preprocessor can modify the shape of an input file just before an
# operation is performed.  Here we create a new calculated field, BMI, from
# the height and weight.
#
# Preprocessing can also apply filters to the input data.  Here we will extract
# the cohort of study participants who are over 50 years old.

# NOTE: When defining the sql_transform portion of the preprocessor, the input
# comes from a table named "data".
# Preprocessor for the imperial-unit dataset.  The SQL keeps patients older
# than 50 and computes the imperial BMI formula:
#     bmi = 703 * weight_lb / height_in^2
# Rows whose BMI cannot be computed are excluded: a NULL weight, or a NULL or
# non-positive height, would otherwise produce a NULL (or divide-by-zero)
# value for the "bmi" regression target.
preprocess_a = (
    tb.TabularPreprocessor.builder()
    # "bmi" is the column the regression predicts.
    .add_column("bmi", target=True)
    .all_columns(True)
    .sql_transform(
        """SELECT Patient_Id as pid, Height_IN as height, Weight_LBS as weight,
            (cast(Weight_LBS as real) / (cast(Height_IN as real) * cast(Height_IN as real))) * 703.0 as bmi
            FROM data
            WHERE Age > 50 and Weight_LBS IS NOT NULL and cast(Height_IN as real) > 0"""
    )
    .dtype("float32")
)

# Different preprocessors can be defined for each input dataset.  This
# preprocessor performs a data transformation from metric to imperial units
# as well as calculating a new BMI field.  No filtering is necessary since the
# dataset has only patients over 50.
#
# See "python_transform_script.py" for more details on how this style of data
# transformer works.
# Preprocessor for the metric-unit dataset.  The BMI target is configured the
# same way as for the imperial dataset, but the row-level transformation is
# delegated to an external Python script instead of an SQL query.
preprocess_b = tb.TabularPreprocessor.builder()
preprocess_b = preprocess_b.add_column("bmi", target=True)
preprocess_b = preprocess_b.all_columns(True)
preprocess_b = preprocess_b.python_transform("python_transform_script.py")
preprocess_b = preprocess_b.dtype("float32")

# Run a regression against the datasets.  This should show a clear relationship
# between the BMI and height/weight.
# Training options for the regression operation.
regression_params = {
    # Linear, Logistic, Lasso, and Ridge are allowed
    "regression_algorithm": "Linear",
    # Share of rows held out for evaluation (presumably a train/test split
    # fraction -- confirm against the SDK docs).  The held-out report is r2
    # for regression and an accuracy report otherwise.
    "test_size": 0.1,
}

# One preprocessor per dataset, matched by position.
job = tb.create_job(
    job_name="Calculated BMI example",
    operation=tb.Operation.REGRESSION,
    dataset=[asset_a, asset_b],
    preprocessor=[preprocess_a, preprocess_b],
    params=regression_params,
)

# Submit the job and block until it finishes.  Both a rejected submission and
# a failed training run stop the script with a non-zero exit status, so neither
# falls through silently.
if not job.submit():
    raise SystemExit("Job submission failed")

job.wait_for_completion()
if not job.success:
    raise SystemExit("Training failed")

print(f"Trained Asset ID: {job.result.asset.uuid}")

# Download a local copy of the trained model.
model_file = job.result.asset.retrieve("bmi_model.zip", overwrite=True)
print("Sklearn model has been downloaded as:")
print(f"    {model_file}")

# Save the trained model's asset ID to a file for later use.
tb.util.save_to("model_asset_id_distributed.out", job.result.asset.uuid)

# Load the package from the path retrieve() reported, rather than repeating
# the filename, and show the fitted regression coefficients.
pack = tb.Package.load(model_file)
model = pack.model()

print("\nCoefficients:")
print(model.coef_)
