#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import sys
from datetime import datetime
from pathlib import Path

import pandas as pd

import tripleblind as tb


# Switch to the script's own directory (per the helper's name) so the
# relative paths used below resolve the same way regardless of caller cwd.
tb.util.set_script_dir_current()

# Local folder for the downloaded dataset and the derived inference file.
data_dir = Path("example_data")
data_dir.mkdir(exist_ok=True)

# Authenticate against the platform as example user 3.
example_token = tb.config.example_user3["token"]
tb.initialize(api_token=example_token)

# Fetch the demo CSV into data_dir, passing a cache directory two levels up.
data_file = tb.util.download_tripleblind_resource(
    "test_small_demo.csv", save_to_dir=data_dir, cache_dir="../../.cache"
)

# Look up the model Asset ID saved by a previous run of 2_model_train.py.
try:
    model_asset_id = tb.util.load_from("model_asset_id.out")
except Exception:
    # Catch ordinary errors only: a bare ``except:`` would also swallow
    # KeyboardInterrupt/SystemExit and misreport them as a missing model.
    print("You must run 2_model_train.py first.")
    sys.exit(1)
model = tb.asset.XGBoostModel(model_asset_id)
# Timestamped so each inference run gets a distinct job name.
job_name = f"XGBoost Inference - {datetime.now()}"

# Create a test inference file from the Santander demo data by removing the
# "target" (label) column, leaving only the feature columns for prediction.
df = pd.read_csv(data_file, header=0)
del df["target"]
# index=False keeps pandas' row index out of the file, so only the original
# feature columns are written.
df.to_csv(data_dir / "inference_data.csv", index=False)

# Score the inference file with the trained model; SMPC is disabled here.
inference_path = data_dir / "inference_data.csv"
result = model.predict_proba(inference_path, job_name=job_name, use_smpc=False)

# The call may yield None. Only display and persist the output when a result
# came back.
if result is not None:
    print(result.values)
    result.to_csv("xgboost_results_fed.csv")
