#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import tripleblind as tb


# Start an SDK session using example user 1's API token.
tb.initialize(api_token=tb.config.example_user1["token"], example=True)


# Part 1 of the training data, owned by example user 1. Its preprocessor
# includes all columns and marks the "target" column as the training target.
table1 = tb.TableAsset.find(
    "EXAMPLE - PSI VP XGBoost Training (part 1)",
    owned_by=tb.config.example_user1["team_id"],
)
csv_pre1 = tb.TabularPreprocessor.builder()
csv_pre1 = csv_pre1.all_columns()
csv_pre1 = csv_pre1.add_column("target", target=True)

# Part 2 of the training data, owned by example user 2. Its preprocessor
# includes all columns; no column is marked as the target.
table2 = tb.TableAsset.find(
    "EXAMPLE - PSI VP XGBoost Training (part 2)",
    owned_by=tb.config.example_user2["team_id"],
)
csv_pre2 = tb.TabularPreprocessor.builder()
csv_pre2 = csv_pre2.all_columns()

# Job parameters. "match_column" names one key column per dataset —
# presumably in the same order as `dataset` below ("identifier" for part 1,
# "id" for part 2); confirm against the SDK docs.
xgb_params = {
    "match_column": ["identifier", "id"],
    "objective": "binary:logistic",
    "n_estimators": 10,
    "max_depth": 3,
}

# Datasets and preprocessors are passed as parallel lists (part 1, part 2).
job = tb.create_job(
    "Example PSI VP XGBoost Training",
    operation=tb.Operation.PSI_VP_XGBOOST_TRAIN,
    dataset=[table1, table2],
    params=xgb_params,
    preprocessor=[csv_pre1, csv_pre2],
)
# Submit the training job and block until it finishes.
job.submit()
job.wait_for_completion()

if job.success:
    # Read the result once and reuse it rather than accessing `job.result`
    # repeatedly.
    result = job.result
    # Save the trained model asset's UUID to a local file (presumably so
    # follow-up examples can load the model — confirm).
    tb.util.save_to("trained_model_asset_id.out", result.asset.uuid)
    print("Training finished.")
else:
    # Print the job for diagnosis, then abort. The uncaught exception makes
    # the script exit with a non-zero status.
    print(job)
    raise SystemError("PSI VP XGBoost training job failed; see job details above.")
