#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import sys
import warnings
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.metrics import classification_report

import tripleblind as tb


# Silence all Python warnings for this demo run.  The motivating case is
# SkLearn's "UndefinedMetricWarning", but note that this filter is global and
# hides every other warning category as well.
warnings.filterwarnings("ignore")

# NOTE(review): presumably makes the script's own directory the working
# directory so the relative paths used below resolve -- confirm in tb.util.
tb.util.set_script_dir_current()
data_dir = Path("BankTransactionData")

##########################################################################
# GET AUTHENTICATION TOKENS AND ESTABLISH CONNECTION TO THE ROUTER
#
tb.initialize(api_token=tb.config.example_user2["token"])

# Look for a model Asset ID from a previous run of 2_model_train.py
try:
    with open("model_asset_id.out", "r") as f:
        asset_id = f.readline().strip()
except (OSError, UnicodeDecodeError):
    # Only a missing or unreadable file means "training has not been run".
    # A bare `except:` would also trap KeyboardInterrupt and SystemExit.
    asset_id = ""

# An empty file is as useless as a missing one: there is no ID to look up.
if not asset_id:
    print("You must run 2_model_train.py first.")
    sys.exit(1)

alg_santandar = tb.Asset(asset_id)

# Fetch the demo test set from TripleBlind's demo data server (a local cache
# directory is supplied alongside the destination directory).
data_file = tb.util.download_tripleblind_resource(
    "test_small_demo.csv", save_to_dir=data_dir, cache_dir="../../.cache"
)

# The inference input must not carry the label, so write out a copy of the
# data with the "target" column removed.
notarget_path = data_dir / "test_small_notarget_demo.csv"
testdf = pd.read_csv(data_file).drop(columns="target")
testdf.to_csv(notarget_path, index=False)
datasets = [notarget_path]

# Preprocessor: use all columns, cast to float32.
preproc = tb.TabularPreprocessor.builder()
preproc = preproc.all_columns(True)
preproc = preproc.dtype("float32")

# Build the inference job around the trained model asset.
job = tb.create_job(
    job_name=f"Santandar Inference - {datetime.now()}",
    operation=alg_santandar,
    dataset=datasets,
    preprocessor=preproc,
    params={"security": "smpc"},  # fed or smpc
)

# A falsy job means creation was refused; explain the Agreement that the
# algorithm owner has to set up, then stop.
if not job:
    owner_login = tb.config.example_user1["login"]
    user_login = tb.config.example_user2["login"]
    user_name = tb.config.example_user2["name"]
    help_lines = (
        "ERROR: Failed to create the job -- do you have an Agreement to run this?",
        "",
        f"NOTE: Remote inference requires the user '{owner_login}' create an",
        f"      Agreement on their algorithm asset with user '{user_login}'",
        f"      ({user_name}) before they can use it to infer.  You can do",
        "      this on the Router at:",
        f"      {tb.config.gui_url}/dashboard/algorithm/{alg_santandar.uuid}",
    )
    for help_line in help_lines:
        print(help_line)
    sys.exit(1)

# Load the labels
# NOTE(review): this file is not written by this script; presumably an earlier
# step of the example produces it -- confirm before running standalone.
data_y = pd.read_csv(data_dir / "test_small_target_demo.csv")
y = data_y["target"]
# Reshape the labels from (n,) to (n, 1) before comparing with the predictions.
y = np.expand_dims(y.values, axis=1)

# Run against the local test dataset
if not job.submit():
    # Previously a rejected submission ended the script silently with exit
    # status 0; report it like the other failure paths in this script.
    print("ERROR: Failed to submit the inference job")
    sys.exit(1)

job.wait_for_completion()

if not job.success:
    print("Inference failed")
    sys.exit(1)

# Download the result package and read the predictions (CSV without a header).
filename = job.result.asset.retrieve(
    save_as="tabular_remote_predictions.zip", overwrite=True
)
pack = tb.Package.load(filename)
result = pd.read_csv(pack.record_data_as_file(), header=None)

print("\nInference results:")
print("    ", result.values.astype(int).flatten())
print("\nTruth:")
print("    ", y.astype(int).flatten())

print("\nClassification report:")
print(classification_report(y, result.values))
