#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import sys
import warnings
from datetime import datetime
from pathlib import Path

import pandas as pd
from sklearn.exceptions import UndefinedMetricWarning
from sklearn.metrics import classification_report

import tripleblind as tb


# Presumably makes relative paths below resolve against this script's
# directory — confirm against tb.util documentation.
tb.util.set_script_dir_current()
data_dir = Path("example_data")

##########################################################################
# GET AUTHENTICATION TOKENS AND ESTABLISH CONNECTION TO THE ROUTER
#
# Establish the connection details to reach the TripleBlind instance.
# Unless explicitly specified, all operations will occur via this default
# session, authenticated with the token of the user configured in
# tb.config.example_user2 (the user who runs inference against the model).
tb.initialize(api_token=tb.config.example_user2["token"])

# Look for a model Asset ID written by a previous run of 2_model_train.py.
# Only file-access errors mean "not trained yet"; catching OSError (rather
# than a bare except) keeps Ctrl-C and genuine bugs from being masked.
try:
    with open("model_asset_id.out", "r") as f:
        asset_id = f.readline().strip()
except OSError:
    asset_id = ""

# A missing, unreadable, or empty ID file all mean there is no usable
# trained asset yet, so stop with the same guidance in each case.
if not asset_id:
    print("You must run 2_model_train.py first.")
    sys.exit(1)

trained_network = tb.Asset(asset_id)


# Run the MNIST trained model against several examples of handwritten
# digits. Each image file is paired positionally with its true digit label.
list_of_files = ["three.jpg", "four.jpg", "seven.jpg", "big_eight.jpg"]
test_truth = [3, 4, 7, 8]

# Predictions and their matching true labels, recorded only for images whose
# inference job completed successfully so the two lists stay aligned for
# the classification report (a failed job must not shift later labels).
inference_predictions = []
scored_truth = []

for name, truth in zip(list_of_files, test_truth):
    job = tb.create_job(
        job_name="Test trained network - " + str(datetime.now()),
        operation=trained_network,
        params={"security": "smpc"},
        dataset=data_dir / name,
    )
    if not job:
        print(
            "ERROR: Failed to create the job -- do you have an Agreement to run this?"
        )
        print(
            f"NOTE: Remote inference requires user '{tb.config.example_user1['login']}' to create an"
        )
        print(
            f"      Agreement on their algorithm asset with user '{tb.config.example_user2['login']}'"
        )
        print(
            f"      ({tb.config.example_user2['name']}) before they can use it to infer.  You can do"
        )
        print("      this on the Router at:")
        print(f"{tb.config.gui_url}/dashboard/algorithm/{trained_network.uuid}")
        sys.exit(1)

    if not job.submit():
        print(f"SMPC Inference could not be submitted for {name}")
        continue

    job.wait_for_completion()
    if not job.success:
        print(f"SMPC Inference failed for {name}")
        continue

    filename = job.result.asset.retrieve(save_as="mnist_result.zip", overwrite=True)
    pack = tb.Package.load(filename)
    result = pd.read_csv(pack.record_data_as_file())
    # The script takes the first cell of the returned table as the predicted
    # digit for this image.
    inference_predictions.append(result.values[0][0])
    scored_truth.append(truth)


print("\n\nInference results:")
print("    ", inference_predictions)
print("Truth:")
print("    ", scored_truth)

if inference_predictions:
    print("\nClassification report:")
    # Suppress only the SkLearn "UndefinedMetricWarning" (e.g. a class with no
    # predicted samples), and only for this call, leaving other warnings intact.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UndefinedMetricWarning)
        print(classification_report(scored_truth, inference_predictions))
else:
    print("\nNo inference succeeded; skipping classification report.")

# Save results to a CSV file, one prediction per line.
with open("smpc_inferences.csv", "w") as out:
    out.writelines(f"{pred}\n" for pred in inference_predictions)
