#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

# Explain up front why running an XGBoost model locally may need extra setup.
print(
    """The XGBoost library produces models which can be tricky to run locally
due to dependencies on other libraries which may not be easily installed on some
hardware and operating system combinations, e.g. MacOS on Apple M1/M2
processors. This script demonstrates how to load a trained XGBoost model
decoupled from the TripleBlind ecosystem and may work unaltered if the
necessary dependencies are available, but some customization may be
required to run on some systems.
"""
)

# Remove the following line to run this script.
raise SystemExit("This script is for reference only.")

import os
from pathlib import Path

import pandas as pd
from sklearn.metrics import r2_score

import tripleblind as tb


# Importing xgboost can fail when its native dependencies are missing (the
# guidance below names libomp). The failure is presumably not always an
# ImportError, which would explain catching any Exception here: the raw error
# is shown, followed by platform-specific setup hints, and the script exits.
try:
    import xgboost
except Exception as e:
    print(f"Exception: {e}")  # Show raw error

    import platform

    from colorama import Fore, Style

    if platform.processor() != "arm":
        # Non-ARM host: point the user at the libomp 11 install steps.
        tb.util.wrap(
            f"{Fore.RED}\nDependency missing: XGBoost requires libomp{Style.RESET_ALL}"
        )
        print()
        tb.util.wrap(
            "Local xgboost inference requires an additional library "
            "install. On MacOS xgboost requires libomp 11 which is not "
            "available from brew. Run the following commands to get the "
            "correct version:"
        )
        print(
            f"{Fore.CYAN}"
            "  mamba install conda-forge::llvm-openmp=11"
            f"{Style.RESET_ALL}"
        )
        print()
        tb.util.wrap("Then update your dylib reference so xgboost can find it:")
        print(
            f"{Fore.CYAN}"
            "  cp ~/anaconda3/envs/tripleblind/lib/libomp.dylib /usr/local/bin/libomp.dylib"
            f"{Style.RESET_ALL}"
        )
    else:
        # ARM host (e.g. Apple M1/M2): local inference is not supported, so
        # steer the user to running inference through TripleBlind instead.
        tb.util.wrap(
            "\nLocal xgboost inference requires a library which doesn't yet run on "
            "ARM-based machines, such as the Apple M1/M2 processors. "
            "XGBoost models can still be trained and used for inference via "
            "TripleBlind, see the 3b_ and 3c_ script examples."
        )
    raise SystemExit(1)


# Presumably switches the working directory to this script's folder (per the
# helper's name) so the relative paths used below resolve next to the script.
tb.util.set_script_dir_current()

# Local folder that downloaded example data is saved into; an existing folder
# from a previous run is reused rather than treated as an error.
data_dir = Path("example_data")
data_dir.mkdir(exist_ok=True)


def load_xgmodel(filename):
    """Load a TripleBlind package file and return the model it contains.

    Args:
        filename: Path of the package file handed to ``tb.Package.load``
            (for example a ``.zip`` package produced by TripleBlind).

    Returns:
        The object returned by the package's ``model()`` method. Presumably
        this is the trained XGBoost model; confirm against the tripleblind SDK.
    """
    return tb.Package.load(filename).model()


# Use a test dataset for "batch" testing. If TB_TEST_SMALL is present in the
# environment (with any value), the smaller variant of the file is used.

test_dataset = (
    "regression_test_small.csv"
    if "TB_TEST_SMALL" in os.environ
    else "regression_test.csv"
)

data_file = tb.util.download_tripleblind_resource(
    test_dataset,
    save_to_dir=data_dir,
    cache_dir="../../.cache",
)

# Load the test data and split it into independent X (features) and y (target).
# DataFrame.pop removes the "target" column from data_X and returns it as a
# Series, so data_X is left holding only the feature columns.
data_X = pd.read_csv(data_file)
data_y = data_X.pop("target")

model = load_xgmodel("xgboost_rand_reg_split_model.zip")

print(model)

predictions = model.predict(data_X)
# data_y is a single column (a Series), so to_numpy() yields the 1-D array of
# true target values.
actual = data_y.to_numpy()
# Coefficient of determination (R^2) -- a regression score, not an accuracy.
r2 = r2_score(actual, predictions)
print("Split model:")
print(f"Predictions: {predictions}")
print(f"Actual:      {actual}")
print(f"r2_score:    {r2}")
