#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import sys

import tripleblind as tb


tb.util.set_script_dir_current()
tb.initialize(api_token=tb.config.example_user1["token"])


# Look up the training dataset in the Router's index.  It is positioned by
# an earlier step of this example, so stop with a hint when it is missing.
asset0 = tb.Asset.find("college-acceptance-mongo")
if not asset0:
    print(
        "Datasets not found.\n"
        "You must run 1_position_data_on_accesspoint.py first."
    )
    sys.exit(1)


# Train an XGBoost model on the dataset located above, with "admitted" as
# the target column.
model = tb.asset.XGBoostModel.train(
    job_name="XGBoost Distributed College Acceptance Training",
    training_data=[asset0],
    target_var="admitted",
    variables="ALL",
    datatype="float32",
)

# Download a local copy of the trained model (replacing any previous copy).
print("\nSaving model locally...")
filename = model.retrieve("xgboost_ca_model.zip", show_progress=True, overwrite=True)

# Record the trained model's Router asset ID so later steps can reuse it.
with open("model_asset_id.out", "w") as id_file:
    id_file.write(str(model.uuid))

# Create an agreement which allows the other team to use this trained model
# in subsequent steps.  This script authenticates as example_user1 (see the
# tb.initialize call at the top), so the agreement must name a *different*
# team.  Granting it to example_user1's own team would leave the other
# party without access.
# NOTE(review): assumes example_user2 is the "other team" for this example.
agreement = model.add_agreement(
    with_team=tb.config.example_user2["team_id"], operation=tb.Operation.EXECUTE
)
if agreement:
    print("Created Agreement for use of trained model Asset.")
else:
    # Report the failure explicitly instead of silently continuing.
    print("Unable to create Agreement for use of trained model Asset.")
