#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import sys
import warnings
from pathlib import Path

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from sklearn.metrics import classification_report

import tripleblind as tb


# Judging by its name, this makes the script's own folder the working
# directory so the relative paths used below resolve (NOTE(review): confirm
# against tb.util).
tb.util.set_script_dir_current()
data_dir = Path("example_data")

# Meant to hide PyTorch's "SourceChangeWarning" when the saved model is
# loaded. With no category or message given, this filter ignores *every*
# warning for the rest of the script.
warnings.filterwarnings("ignore")


# The trained model is a local package file produced by 2_model_train.py.
# Stop early with a hint when that step has not been run yet.
trained_model = Path("local_model.zip")
model_missing = not trained_model.exists()
if model_missing:
    for message in (
        "ERROR: Unable to find the trained model.",
        "You must run 2_model_train.py first.",
    ):
        print(message)
    sys.exit(1)

# Unpack the network from the package and put it in evaluation mode
# (presumably a torch.nn.Module, given the .eval()/.max() usage below).
pack = tb.Package.load(trained_model)
model = pack.model()
model.eval()

# Sample digit images from example_data and the digit each one shows.
# The two lists are index-aligned.
test_images = ["three.jpg", "four.jpg", "seven.jpg", "big_eight.jpg"]
test_truth = [3, 4, 7, 8]

# Build the transform once instead of once per image.
to_tensor = transforms.ToTensor()

inference_predictions = []
# Inference never needs gradients. no_grad() skips building the autograd
# graph on every forward pass and does not change the outputs.
with torch.no_grad():
    for img_file_name in test_images:
        # The context manager closes the file handle deterministically.
        # Converting to single-channel grayscale ("L") guarantees a 2-D
        # (H, W) array. Without it, an RGB JPEG would give a 4-D array after
        # the channel axis is added (rejected by ToTensor), and a palette
        # image would give palette indices instead of intensities.
        with Image.open(data_dir / img_file_name) as pil_img:
            np_image = np.array(pil_img.convert("L").resize((28, 28)))

        # Append a trailing channel axis -> (28, 28, 1) in HWC layout.
        # ToTensor turns this into a (1, 28, 28) CHW tensor.
        image = np.expand_dims(np_image, np_image.ndim).astype(np.float32)
        # NOTE(review): ToTensor only rescales uint8 input to [0, 1]. Because
        # the array is already float32, pixel values stay in [0, 255].
        # Presumably this matches the preprocessing in 2_model_train.py;
        # confirm.
        image_tensor = to_tensor(image)

        # Add a batch dimension -> (1, 1, 28, 28). The predicted digit is the
        # index of the largest model output along dim 1.
        _, prediction = model(image_tensor.unsqueeze(0)).max(1)
        predicted_label = prediction.item()

        print(f"Image file name: {img_file_name} --> prediction: {predicted_label}")
        inference_predictions.append(predicted_label)

print()
print("Classification report:")
print(classification_report(test_truth, inference_predictions))

# Write the predictions as a single-column CSV with no header: one label
# per line, in the same order as test_images was processed.
csv_text = "".join(f"{pred}\n" for pred in inference_predictions)
Path("local_inferences.csv").write_text(csv_text)
