#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import os
import sys
import warnings

import numpy as np
import torch
from sklearn.metrics import classification_report

import tripleblind as tb


# Suppress tensorflow optimization and TensorRT warnings by raising the
# TensorFlow C++ log threshold via the TF_CPP_MIN_LOG_LEVEL environment variable.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
# Suppress the PyTorch "SourceChangeWarning" and tensorflow optimization warnings.
# NOTE: with no category/module arguments this ignores *all* Python warnings
# for the rest of the process, not just the ones named above.
warnings.filterwarnings("ignore")

# Deliberately imported only after TF_CPP_MIN_LOG_LEVEL and the warning filter
# are set, presumably so TensorFlow picks them up when it initializes.
# Keep this ordering if the imports above are reorganized.
from tensorflow.keras.datasets import cifar10


tb.util.set_script_dir_current()

# Look for the locally saved model package written by a previous run of
# 2_model_train.py. This is a plain file path, not a remote Asset ID; the
# call above presumably makes relative paths resolve against this script's
# directory.
trained_model = "local.zip"
# Use isfile rather than exists so a directory with this name is rejected too.
if not os.path.isfile(trained_model):
    print(
        f"ERROR: Unable to find the trained model '{trained_model}'. "
        "Run 2_model_train.py first.",
        file=sys.stderr,
    )
    sys.exit(1)

############################################################################
# Load the locally stored trained model
#

pack = tb.Package.load(trained_model)
model = pack.model()
# Put the model in evaluation mode before running inference.
model.eval()

# Number of test images to run through the model. The CIFAR-10 test split
# holds 10,000 images, so this must be between 1 and 10,000 inclusive
# (checked below against the actual size of the loaded split).
num_images_to_test = 10

# Retrieve the standard CIFAR-10 dataset of labeled color images. Only the
# test split is needed here, so the training split is discarded.
_, (test_images, test_labels) = cifar10.load_data()

if not 0 < num_images_to_test <= len(test_images):
    print(
        f"ERROR: num_images_to_test must be between 1 and {len(test_images)}, "
        f"got {num_images_to_test}.",
        file=sys.stderr,
    )
    sys.exit(1)

# Prep the selected test images to match the model: reorder axes from
# (N, H, W, C) to (N, C, H, W) and cast to float32. Pixel values are not
# rescaled -- NOTE(review): confirm this matches the preprocessing used in
# 2_model_train.py.
X = np.transpose(test_images[:num_images_to_test], (0, 3, 1, 2)).astype(np.float32)
# Flatten the labels to a 1-D int64 vector. Build them from the same slice as
# X so predictions and ground truth always have the same length.
y = test_labels[:num_images_to_test].reshape(-1).astype(np.int64)

with torch.no_grad():
    # The predicted class is the index of the highest model output per image.
    preds = model(torch.from_numpy(X)).argmax(1).numpy()

print("\nInference results against test set:")
print("    ", preds)
print("Actual values:")
print("    ", y)
print()
print("\nClassification report:")
print(classification_report(y, preds))
print()
print("NOTE: Retrain for more epochs to increase accuracy")
