#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import os
import time

import tripleblind as tb


# Start an SDK session authenticated with the first example user's token.
# NOTE(review): example=True is simply forwarded to the SDK; its exact effect
# is defined by tripleblind, not visible here — confirm.
tb.initialize(example=True, api_token=tb.config.example_user1["token"])
# Use a shallower tree (depth 2 instead of 3) whenever TB_TEST_SMALL is set
# in the environment, regardless of its value.
MAX_DEPTH = 2 if os.environ.get("TB_TEST_SMALL") is not None else 3


#############################################################################
# Find training datasets on the Router's index
# Each part of the training data is published under its own name and owned by
# a different example team; look each one up by (name, owning team).
_TRAINING_DATA_PARTS = (
    (
        "EXAMPLE - Decision Tree classification, training data (part 1)",
        tb.config.example_user1,
    ),
    (
        "EXAMPLE - Decision Tree classification, training data (part 2)",
        tb.config.example_user2,
    ),
    (
        "EXAMPLE - Decision Tree classification, training data (part 3)",
        tb.config.example_user3,
    ),
)
asset0, asset1, asset2 = [
    tb.TableAsset.find(asset_name, owned_by=owner["team_id"])
    for asset_name, owner in _TRAINING_DATA_PARTS
]


# Preprocessing: keep all columns and flag column "y" as the training target.
csv_pre = tb.TabularPreprocessor.builder().all_columns().add_column("y", target=True)

# Define job to train a Decision Tree model on the three remote datasets
job = tb.create_job(
    job_name="Decision Tree Training",
    operation=tb.Operation.PSI_VERTICAL_DECISION_TREE_TRAIN,
    dataset=[asset0, asset1, asset2],
    preprocessor=csv_pre,
    params={
        "decision_tree": {
            # Classification tree (not regression), depth chosen at startup.
            "regression": False,
            "max_depth": MAX_DEPTH,
        },
        # Rows from the separate datasets are matched on the "ID" column.
        "psi": {"match_column": "ID"},
        "target_column": "y",
    },
)
# Submit the training job, wait for it to finish, then report and persist the
# trained model. Elapsed time uses a monotonic clock so the reported training
# time is unaffected by system clock adjustments during the run.
start = time.perf_counter()
if not job.submit():
    # Fail loudly (non-zero exit status) rather than finishing silently when
    # the job could not be submitted.
    raise SystemExit("Model Training Job Submission Failed")

# NOTE(review): the session is initialized with example_user1's token, but this
# message says "organization-three" — confirm which organization runs the job.
print("Performing Decision Tree training as organization-three...")
job.wait_for_completion()
end = time.perf_counter()

if not job.success:
    raise SystemExit("Model Training Failed")

# Download the trained model locally, overwriting any previous copy.
model_asset = job.result.asset
filename = model_asset.retrieve(save_as="dt_model.zip", overwrite=True)
print()
print("Trained Model Asset ID:")
print("    ===============================================")
print(f"    ===>  {model_asset.uuid} <===")
print("    ===============================================")
print("    Algorithm: Decision Tree Classification")
# NOTE(review): labeled "Job ID" but prints the job's name — confirm intent.
print(f"    Job ID:    {job.job_name}")
print(f"    Model in file:  {filename}")
print(f"    Training time: {end-start} seconds")
print()

# Save trained asset ID for later use in inference scripts
tb.util.save_to("clf_model_asset_id.out", model_asset.uuid)
