#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import sys
import warnings
from datetime import datetime
from pathlib import Path

import pandas as pd
from sklearn.metrics import classification_report

import tripleblind as tb


# Authenticate with the example user 3 token, then (presumably) make this
# script's directory the working directory so relative paths resolve the same
# way regardless of where the script is launched from.
tb.initialize(api_token=tb.config.example_user3["token"])
tb.util.set_script_dir_current()

# Local directory for downloaded and generated data files.
data_dir = Path("example_data")

# Silence every warning for the rest of the run. The intent is to hide the
# sklearn "UndefinedMetricWarning" raised by classification_report, but this
# filter matches all warning categories, not just that one.
warnings.simplefilter("ignore")


# Look for the trained model's Asset ID saved by a previous run of
# 2_model_train.py; without it there is no model to run inference with.
try:
    trained_asset_id = tb.util.load_from("model_asset_id.out")
except Exception:
    # A bare `except:` would also trap SystemExit/KeyboardInterrupt (e.g. a
    # Ctrl-C during the load); only ordinary errors mean "not trained yet".
    print("You must first run 2_model_train.py.")
    sys.exit(1)
trained_model = tb.Asset(trained_asset_id)

# Fetch the small demo test set for "batch" testing, saving it into data_dir
# and passing ../../.cache as the download cache directory.
test_data = tb.util.download_tripleblind_resource(
    "test_small_demo.csv",
    save_to_dir=data_dir,
    cache_dir="../../.cache",
)

# Split the test set in one step: DataFrame.pop removes the "target" column
# from data_X (leaving only the features) and returns it as the label Series y.
data_X = pd.read_csv(test_data)
y = data_X.pop("target")

# Write the features (without the target or the index) to a CSV that is handed
# to the inference job; a file on disk is also easy to inspect afterwards.
dataset = data_dir / "test_small_notarget.csv"
data_X.to_csv(dataset, index=False)


# Use the remote model to infer against the local dataset.
job = tb.create_job(
    job_name=f"Hope Valley Inference - {datetime.now()}",
    operation=trained_model,
    dataset=dataset,
    params={"security": "fed"},
)

# Report failures and exit non-zero (matching the model-lookup failure path)
# instead of silently finishing when the job could not be submitted.
if not job.submit():
    print("Inference job submission failed")
    sys.exit(1)

job.wait_for_completion()
if not job.success:
    print("Inference failed")
    sys.exit(1)

# Download the result package and read its record data. With `names` given and
# no explicit header, pandas treats the file as header-less and labels the
# (presumably single) column "results".
filename = job.result.asset.retrieve(save_as="fed_out.zip", overwrite=True)
pack = tb.Package.load(filename)
result = pd.read_csv(pack.record_data_as_file(), names=["results"])
predictions = result["results"]

print("\nInference results:")
print("    ", predictions.to_numpy().astype(int))
print("\nTruth:")
print("    ", y.values.astype(int))

# Pass the 1-D predictions Series rather than the (n, 1) `result.values` array,
# so sklearn does not have to coerce a column vector into labels.
print("\nClassification report:")
print(classification_report(y, predictions))
