#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import tripleblind as tb


# Authenticate as the third example user (the script later reports running
# "as organization-three"); example=True selects the example configuration.
tb.initialize(example=True, api_token=tb.config.example_user3["token"])


#############################################################################
# Find training datasets on the Router's index.
#
# Each dataset belongs to a different example organization, so the owning
# team id is passed as owned_by (presumably to disambiguate same-named assets
# from other teams -- confirm against the SDK docs).

_san_diego_owner = tb.config.example_user1["team_id"]
asset0 = tb.TableAsset.find(
    "EXAMPLE - San Diego Housing Census 1990",
    owned_by=_san_diego_owner,
)

_los_angeles_owner = tb.config.example_user2["team_id"]
asset1 = tb.TableAsset.find(
    "EXAMPLE - Los Angeles Housing Census 1990",
    owned_by=_los_angeles_owner,
)


# Preprocessing applied to the training data: eight housing features are used
# as model inputs, "Price" is the regression target, and "HouseAge"/"MedInc"
# are bucketed into 4 bins by the "KBinsDiscretizer" transformer (presumably
# scikit-learn's KBinsDiscretizer -- confirm against the SDK docs).
_feature_columns = [
    "MedInc",
    "HouseAge",
    "AveRooms",
    "AveBedrms",
    "Population",
    "AveOccup",
    "Latitude",
    "Longitude",
]

csv_pre = tb.TabularPreprocessor.builder()
for _column in _feature_columns:
    # add_column returns the builder, so keep rebinding to continue the chain.
    csv_pre = csv_pre.add_column(_column)

csv_pre = csv_pre.add_column("Price", target=True).add_data_transformer(
    "KBinsDiscretizer",
    columns=["HouseAge", "MedInc"],
    params={"n_bins": 4},
)

# Define a job that trains a Random Forest regressor across the two remote
# datasets, using the preprocessor defined above.
#
# "random_forest_params" supports all Scikit-learn Random Forest parameters.
# See:
#   https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestRegressor.html
_forest_params = {"random_state": 0, "max_depth": 2}
_train_params = {
    "train_type": "regression",
    # Optional setting, disabled here:
    # "fill_missing_categories": "random",
    "random_forest_params": _forest_params,
}

job = tb.create_job(
    job_name="Random Forest Regression Training",
    operation=tb.Operation.RANDOM_FOREST_TRAIN,
    dataset=[asset0, asset1],
    preprocessor=csv_pre,
    params=_train_params,
)

# Submit the job, wait until it finishes (the data owners, organizations one
# and two, must grant permission first), then download the trained model and
# record its asset ID for the inference scripts.
#
# Both a rejected submission and a failed training run raise, so the script
# never exits "successfully" while leaving model_asset_id.2b.out missing or
# stale.
if job.submit():
    print("Performing Random Forest Regression training as organization-three...")
    print("Permission is needed from organizations one and two")
    job.wait_for_completion()

    if not job.success:
        raise SystemError("Model Training Failed")

    # The trained model is produced as a new asset; keep a local copy of it.
    model_asset = job.result.asset
    filename = model_asset.retrieve(
        save_as="forest_model_regression.zip", overwrite=True
    )
    print()
    print("Trained Model Asset ID:")
    print("    ===============================================")
    print(f"    ===>  {model_asset.uuid} <===")
    print("    ===============================================")
    print("    Algorithm: Random Forest Regression")
    # NOTE(review): this label says "Job ID" but prints job.job_name -- confirm
    # which attribute is intended.
    print(f"    Job ID:    {job.job_name}")
    print(f"    Model in file:  {filename}")
    print()

    # Save asset ID for later use in inference scripts
    tb.util.save_to("model_asset_id.2b.out", model_asset.uuid)
else:
    # Previously a failed submission fell through silently with exit status 0.
    raise SystemError("Job submission failed")
