#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import warnings
from pathlib import Path

import pandas as pd
from sklearn.exceptions import UndefinedMetricWarning
from sklearn.metrics import classification_report

import tripleblind as tb


# Suppress only the SkLearn "UndefinedMetricWarning" (emitted by
# classification_report when a class has no predicted or true samples).
# Other warnings stay visible so real problems are not hidden.
warnings.filterwarnings("ignore", category=UndefinedMetricWarning)


# Start a session as the first example user.
tb.initialize(api_token=tb.config.example_user1["token"], example=True)

# Look for a model Asset ID from a previous run of 1_model_train.py
try:
    trained_asset_id = tb.util.load_from("model_asset_id.out")
    alg_santandar = tb.ModelAsset(trained_asset_id)
except Exception as err:
    # Catch Exception rather than using a bare "except:" so that Ctrl-C and
    # SystemExit still propagate, and chain the original error so the real
    # cause remains visible in the traceback.
    raise SystemError("You must first run 1_model_train.py.") from err

# Fetch the small demo test set used for "batch" testing.
data_dir = Path("example_data")
test_data = tb.util.download_tripleblind_resource(
    "test_small_demo.csv", save_to_dir=data_dir, cache_dir="../../.cache"
)

# Separate the ground-truth "target" column (y) from the feature columns.
features = pd.read_csv(test_data)
y = features.pop("target")

# Write the target-free features to a CSV; passing a file to the inference
# keeps things simple and leaves the input easy to examine.
dataset = data_dir / "test_small_notarget.csv"
features.to_csv(dataset, index=False)


############################################################################
# Run the local dataset through the remote model
inference_params = {"security": "fed", "identifier_columns": ["amt"]}
job_label = f"Santandar Inference - {tb.util.timestamp()}"

result = alg_santandar.infer(
    data=dataset, params=inference_params, job_name=job_label
)

# A falsy result is treated as a failed job; stop before reading any output.
if not result:
    raise SystemError("Inference failed")

# Retrieve the output asset as fed_out.zip (replacing any earlier copy) and
# load the result table for display.
result.asset.retrieve(save_as="fed_out.zip", overwrite=True)
results = result.table.load(header=True)
# Alternative noted in the original example (requires Package to be imported):
#   pack = Package.load("fed_out.zip")
#   results = pack.records()  # Display the inference output

indent = "    "

print("\nInference results:")
print(indent, results)

print("\nTruth:")
truth_as_int = y.values.astype(int)
print(indent, truth_as_int)

# Compare the true labels against the integer-cast, flattened "results" column.
print("\nClassification report:")
predicted = results["results"].values.astype(int).flatten()
report = classification_report(y, predicted)
print(report)
