#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import warnings

import pandas as pd
from sklearn.metrics import classification_report

import tripleblind as tb


# Suppress all Python warnings for the rest of the run.  The intent is to hide
# SkLearn's "UndefinedMetricWarning", but because no category is passed to the
# filter, every other warning is silenced as well.
warnings.filterwarnings("ignore")

# NOTE(review): presumably switches the working directory to this script's
# folder so the relative paths used below resolve -- confirm in tb.util.
tb.util.set_script_dir_current()

##########################################################################
# GET AUTHENTICATION TOKENS AND ESTABLISH CONNECTION TO THE ROUTER
#
tb.initialize(api_token=tb.config.example_user3["token"], example=True)
# Folder next to this script that holds the example data files; it is created
# if it does not exist yet.
data_dir = tb.util.script_dir() / "example_data"
data_dir.mkdir(exist_ok=True)

# Look for a model Asset ID from a previous run of 2_transfer_train.py
trained_asset_id = tb.util.load_from("model_asset_id.out")
alg_santandar = tb.Asset(trained_asset_id)
if not alg_santandar.is_valid:
    # Nothing can be inferred without a usable model asset, so stop right away.
    raise SystemExit("ERROR: Unable to find the specified model.")

# Fetch the small demo test set used for "batch" testing; the file is saved
# into data_dir.
# NOTE(review): cache_dir is a relative path, so it presumably depends on the
# current working directory -- confirm against the tb.util helper.
data_file = tb.util.download_tripleblind_resource(
    "test_small_demo.csv", save_to_dir=data_dir, cache_dir="../../.cache"
)

# Separate the ground truth from the features: pop() removes the "target"
# column from data_X (the independent data) and hands it back as y.
data_X = pd.read_csv(data_file)
y = data_X.pop("target")

# Persist the target-free features as a CSV.  A plain file is simple to pass
# in to the inference and easy to examine by hand.
dataset = data_dir / "test_small_notarget.csv"
data_X.to_csv(dataset, index=False)


# Define a job that runs the trained model against the local CSV
job = tb.create_job(
    job_name=f"Santandar Inference - {tb.util.timestamp()}",
    operation=alg_santandar,
    dataset=dataset,
    params={"security": "smpc"},  # fed or smpc
)
if not job:
    raise SystemExit("ERROR: Failed to create the job")


# Use the remote model to infer against the local dataset.  A job that cannot
# be submitted stops the script with an error, the same way a failed job
# creation does, instead of ending silently with a success exit status.
if not job.submit():
    raise SystemExit("ERROR: Failed to submit the job")

job.wait_for_completion()

if not job.success:
    # Keep the original message on stdout, but signal the failure to the
    # caller through a non-zero exit status.
    print("Inference failed")
    raise SystemExit(1)

# Save a local copy of the result asset.  The call is made for that side
# effect only; its return value is not needed here.
job.result.asset.retrieve(
    save_as="tabular_remote_predictions.zip", overwrite=True
)
result = job.result.dataframe

# Show predictions and ground truth as flat integer arrays for easy comparison
print("\nInference results:")
print("    ", result.values.astype(int).flatten())
print("\nTruth:")
print("    ", y.values.astype(int))

# NOTE(review): the report is given the raw result values, not the
# integer-cast ones printed above -- confirm the model emits class labels.
print("\nClassification report:")
print(classification_report(y, result.values))
