#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import os
import sys
from datetime import datetime
from pathlib import Path

import pandas as pd

import tripleblind as tb


# NOTE(review): presumably makes this script's folder the working directory so
# the relative paths below resolve consistently -- confirm in tb.util.
tb.util.set_script_dir_current()

# Local folder that receives the downloaded and generated CSV files.
data_dir = Path("example_data")
data_dir.mkdir(exist_ok=True)

# Initialize the SDK with the API token configured for example_user3.
tb.initialize(api_token=tb.config.example_user3["token"])

# Defining TB_TEST_SMALL in the environment (with any value, even an empty
# one) selects the small variant of the regression test set.
if "TB_TEST_SMALL" in os.environ:
    test_dataset = "regression_test_small.csv"
else:
    test_dataset = "regression_test.csv"

data_file = tb.util.download_tripleblind_resource(
    test_dataset, save_to_dir=data_dir, cache_dir="../../.cache"
)

# Look for a model Asset ID saved by a previous run of 2_model_train.py.
try:
    model_asset_id = tb.util.load_from("model_asset_id_distributed.out")
except Exception:
    # Deliberately not a bare ``except:``: KeyboardInterrupt and SystemExit
    # must propagate instead of being reported as a missing training run.
    print("You must run 2_model_train.py first.")
    sys.exit(1)

model = tb.asset.XGBoostModel(model_asset_id)

# The timestamp suffix gives each run its own job name.
job_name = f"XGBoost Regression Split Inference - {datetime.now()}"


# Build a one-record inference input from the downloaded test set.
test_data_file = data_dir / "small_reg_test.csv"

# small_data_y holds the first record with every column (including "target").
# It is not used again in this script.
small_data_y = pd.read_csv(data_file).head(1)

# Prefix purely numeric column names with an underscore.
# NOTE(review): presumably the inference side requires non-numeric column
# names -- confirm against the training script.
small_data_y.columns = [
    col if not (isinstance(col, int) or col.isnumeric()) else f"_{col}"
    for col in small_data_y.columns
]

# small_data_X is the same record without the "target" column; this is what
# gets written to disk and sent for prediction.
small_data_X = small_data_y.drop(columns="target")
small_data_X.to_csv(test_data_file, index=False)

# Run the prediction on the prepared CSV under this run's job name, with the
# use_smpc flag enabled.
result = model.predict(
    test_data_file,
    use_smpc=True,
    job_name=job_name,
)

# Print the predicted values, then save the full result next to the script.
# NOTE(review): result is presumably a pandas DataFrame (.values / .to_csv are
# used) -- confirm against the SDK's predict() documentation.
print(result.values)
result.to_csv("xgboost_reg_results_smpc.csv")
