#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import os

import tripleblind as tb


# Initialize the SDK with example user 3's API token; this is the account
# that launches the training job below.
tb.initialize(api_token=tb.config.example_user3["token"], example=True)

# Small test runs (TB_TEST_SMALL set in the environment) look up the "TEST"
# copies of the datasets; all other runs use the "EXAMPLE" copies.
if "TB_TEST_SMALL" in os.environ:
    prefix = "TEST"
else:
    prefix = "EXAMPLE"

# Find the two training datasets in the Router's index: one owned by example
# user 1's team and one owned by example user 2's team.
experian_name = f"{prefix} - Experian Credit Score Data"
dataset_train0 = tb.TableAsset.find(
    experian_name,
    owned_by=tb.config.example_user1["team_id"],
)

transunion_name = f"{prefix} - TransUnion Credit Score Data"
dataset_train1 = tb.TableAsset.find(
    transunion_name,
    owned_by=tb.config.example_user2["team_id"],
)


# Train an XGBoost regression model over both training datasets (the job is
# labelled as distributed training), using "target" as the target variable.
# The training configuration is collected in one mapping and unpacked into
# the call.
train_kwargs = {
    "training_data": [dataset_train0, dataset_train1],
    "datatype": "float32",
    "target_var": "target",
    "variables": "ALL",
    "is_regression": True,
    "job_name": "XGBoost Distributed Regression Training",
}
model = tb.asset.XGBoostModel.train(**train_kwargs)

# Retrieve the trained model into a local zip file. overwrite=True presumably
# replaces any existing copy; show_progress=True displays transfer progress.
print("\nSaving model locally...")
model_zip = "xgboost_rand_reg_split_model.zip"
filename = model.retrieve(model_zip, overwrite=True, show_progress=True)

# Record the model's Router asset ID in a file so it can be reused later.
asset_id_path = "model_asset_id_distributed.out"
tb.util.save_to(asset_id_path, model.uuid)

# Organization 3 owns the trained model and will also run the inferences.
# If Organization 3 wants to allow other organizations to use this trained
# model to run inferences, it can create agreements here:

# agreement = model.add_agreement(
#     with_team=tb.config.example_user1["team_id"], operation=tb.Operation.EXECUTE
# )
# if agreement:
#     print("Created Agreement for use of trained model Asset.")
