#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import os
import sys
from pathlib import Path

import tripleblind as tb


# Make the relative paths below resolve from this script's location —
# presumably what set_script_dir_current() arranges (TODO confirm).
tb.util.set_script_dir_current()

# Local output folder for the retrieved model; reused if it already exists.
data_dir = Path("example_data")
data_dir.mkdir(exist_ok=True)

# Connect as example user 3; the training job below runs under this session.
example_token = tb.config.example_user3["token"]
tb.initialize(api_token=example_token, example=True)

#############################################################################
# Find the customer account datasets in the Router index for training

# Small-test runs (TB_TEST_SMALL set) use the "TEST - ..." assets, otherwise
# the full "EXAMPLE - ..." assets.
prefix = "TEST" if "TB_TEST_SMALL" in os.environ else "EXAMPLE"

# (bank code, owning example user) for each training partner. The order here
# is the order the datasets are passed to training below.
_training_sources = [
    ("SAN", tb.config.example_user1),
    ("JPM", tb.config.example_user2),
    ("PNB", tb.config.example_user3),
]

_found_datasets = []
_missing_names = []
for _bank, _owner in _training_sources:
    _asset_name = f"{prefix} - {_bank} Customer Database"
    _asset = tb.Asset.find(_asset_name, owned_by=_owner["team_id"])
    _found_datasets.append(_asset)
    if not _asset:
        _missing_names.append(_asset_name)

# Abort before training if any partner dataset is absent, and name each
# missing asset so the user knows which one still has to be created.
if _missing_names:
    print("Datasets not found.")
    for _asset_name in _missing_names:
        print(f"  - {_asset_name}")
    sys.exit(1)

dataset_train0, dataset_train1, dataset_train2 = _found_datasets


# Train the XGBoost model jointly across the three partner datasets found above.
# NOTE(review): variables="ALL" presumably uses every non-target column as a
# feature — confirm against the tb.asset.XGBoostModel.train documentation.
model = tb.asset.XGBoostModel.train(
    training_data=[dataset_train0, dataset_train1, dataset_train2],
    datatype="float32",
    target_var="target",
    variables="ALL",
    job_name="XGBoost Distributed Santander Training",
)

print("\nSaving model locally...")
# Download the trained model into example_data/, replacing any copy left by a
# previous run (overwrite=True).
filename = model.retrieve(
    data_dir / "xgboost_sant_model.zip",
    overwrite=True,
    show_progress=True,
)

# Retain the Router Asset ID for later usage.
# NOTE(review): presumably tb.util.save_to writes model.uuid into
# "model_asset_id.out" and returns it — confirm.
asset_id = tb.util.save_to("model_asset_id.out", model.uuid)
