#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import sys
import warnings
from datetime import datetime
from pathlib import Path

import pandas as pd
from sklearn.exceptions import UndefinedMetricWarning
from sklearn.metrics import classification_report

import tripleblind as tb


tb.initialize(api_token=tb.config.example_user2["token"])
tb.util.set_script_dir_current()
data_dir = Path("example_data")

# Look for a model Asset ID written by a previous run of 2_model_train.py.
# Only a missing/unreadable file means "not trained yet"; a bare ``except``
# would also swallow KeyboardInterrupt and unrelated errors.
try:
    with open("model_asset_id.out", "r") as f:
        asset_id = f.readline().strip()
except OSError:
    asset_id = ""

# An empty file is treated the same as a missing one rather than passing an
# empty ID on to tb.Asset.
if not asset_id:
    print("You must first run 2_model_train.py.")
    sys.exit(1)
trained_network = tb.Asset(asset_id)


# Run the MNIST trained model against several examples of handwritten digits.
# Each image is paired with its true label so that predictions and truth stay
# aligned even when an individual inference job fails.
inference_predictions = []
matched_truth = []  # true labels of the images whose inference succeeded
list_of_files = ["three.jpg", "four.jpg", "seven.jpg", "big_eight.jpg"]
test_truth = [3, 4, 7, 8]

for name, truth in zip(list_of_files, test_truth):
    job = tb.create_job(
        job_name="Test trained network - " + str(datetime.now()),
        operation=trained_network,
        params={"security": "smpc"},
        dataset=data_dir / name,
    )
    if not job:
        print("ERROR: Failed to create the job")
        sys.exit(1)

    if not job.submit():
        print(f"ERROR: Failed to submit the job for {name}")
        continue

    job.wait_for_completion()
    if not job.success:
        print(f"SMPC Inference failed for {name}")
        continue

    # Download the raw result archive to mnist_result.zip (overwriting any
    # previous copy), then read the prediction from the result table.
    job.result.asset.retrieve(save_as="mnist_result.zip", overwrite=True)
    result = job.result.table.load(header=True)
    inference_predictions.append(result.values[0][0])
    matched_truth.append(truth)

# classification_report cannot score an empty set of predictions.
if not inference_predictions:
    print("ERROR: No inference jobs succeeded; nothing to report.")
    sys.exit(1)

skipped = len(test_truth) - len(matched_truth)
if skipped:
    print(f"\nWARNING: {skipped} of {len(test_truth)} inferences failed and are excluded.")

print("\n\nInference results:")
print("    ", inference_predictions)
print("Truth:")
print("    ", matched_truth)

print("\nClassification report:")
# Suppress only sklearn's UndefinedMetricWarning (e.g. ill-defined precision
# for a label that was never predicted), and only for this call, so other
# warnings elsewhere in the process are not hidden.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=UndefinedMetricWarning)
    print(classification_report(matched_truth, inference_predictions))

# Save results to a CSV file, one prediction per line
with open("smpc_inferences.csv", "w") as out:
    out.writelines(f"{pred}\n" for pred in inference_predictions)
