#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import tripleblind as tb


# Authenticate with the example "user3" credentials; the training job below
# is run on behalf of this organization.
tb.initialize(api_token=tb.config.example_user3["token"], example=True)

#############################################################################
# Find training datasets on the Router's index

# Each dataset is looked up by name and restricted to the team that owns it.
asset0 = tb.TableAsset.find(
    "EXAMPLE - NYU Student Admissions Data",
    owned_by=tb.config.example_user1["team_id"],
)
asset1 = tb.TableAsset.find(
    "EXAMPLE - UMN Student Admissions Data",
    owned_by=tb.config.example_user2["team_id"],
)

# Stop here with a clear message instead of handing a missing dataset to the
# job definition further down, where the failure would be much harder to read.
# NOTE(review): assumes find() returns None when nothing matches -- confirm
# against the SDK documentation.
if asset0 is None or asset1 is None:
    raise SystemExit("Datasets not found.")


# Preprocessing recipe for the tabular data: four input columns plus the
# "admitted" column flagged as the target, with "work_experience" passed
# through a OneHotEncoder transformer.
csv_pre = tb.TabularPreprocessor.builder()
csv_pre = csv_pre.add_column("gmat")
csv_pre = csv_pre.add_column("gpa")
csv_pre = csv_pre.add_column("work_experience")
csv_pre = csv_pre.add_column("age")
csv_pre = csv_pre.add_column("admitted", target=True)
# NOTE: The category list must be spelled out explicitly so that it covers
# the values present at every data provider.
csv_pre = csv_pre.add_data_transformer(
    "OneHotEncoder",
    columns=["work_experience"],
    params=dict(
        # One list of categories for the single encoded column: 0.0 .. 8.0
        categories=[[float(level) for level in range(9)]],
        handle_unknown="ignore",
    ),
)


# Describe the Random Forest training job that runs against both remote
# datasets, using the preprocessor built above.

# This supports all Scikit-learn Random Forest parameters.  See:
#   https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html
forest_options = {"random_state": 0, "max_depth": 2}

job = tb.create_job(
    job_name="Random Forest Training",
    operation=tb.Operation.RANDOM_FOREST_TRAIN,
    dataset=[asset0, asset1],
    preprocessor=csv_pre,
    params=dict(
        train_type="classification",
        fill_missing_categories="random",
        random_forest_params=forest_options,
    ),
)

# Submit the job and wait for it to finish.  Every failure path exits with a
# message (and therefore a non-zero status) so that a caller never mistakes a
# run that produced no model for a successful one.
if not job.submit():
    raise SystemExit("Model Training job could not be submitted")

print("Performing Random Forest training as organization-three...")
print("Permission is needed from organizations one and two")
job.wait_for_completion()

if not job.success:
    raise SystemExit("Model Training Failed")

# Download a local copy of the trained model, replacing any earlier download.
model_asset = job.result.asset
filename = model_asset.retrieve(save_as="forest_model.zip", overwrite=True)

print()
print("Trained Model Asset ID:")
print("    ===============================================")
print(f"    ===>  {model_asset.uuid} <===")
print("    ===============================================")
print("    Algorithm: Random Forest Classification")
# The value shown is the job's name, so it is labelled as such.
print(f"    Job name:  {job.job_name}")
print(f"    Model in file:  {filename}")
print()

# Save trained asset ID for later use in inference scripts
tb.util.save_to("model_asset_id.out", model_asset.uuid)
