#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import os
import warnings

import numpy as np
from sklearn.metrics import classification_report

import tripleblind as tb


# Suppress tensorflow optimization and TensorRT warnings.
# The tensorflow import is deliberately placed below this assignment (not with
# the other imports at the top) so the setting is already in the environment
# when tensorflow loads.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
# Suppress the SkLearn "UndefinedMetricWarning".
# NOTE(review): filterwarnings("ignore") with no category silences *all* Python
# warnings for the rest of this script, not only the sklearn one.
warnings.filterwarnings("ignore")

# Late import: must come after TF_CPP_MIN_LOG_LEVEL is set above.
from tensorflow.keras.datasets import cifar10


tb.initialize(api_token=tb.config.example_user1["token"], example=True)

# Load the trained model's Asset ID, saved by a previous run of 2_model_train.py
trained_asset_id = tb.util.load_from("model_asset_id.out")
alg_cifar = tb.ModelAsset(trained_asset_id)

# Retrieve a small zipped set of CIFAR test images to run inference against
infer_data = tb.util.download_tripleblind_resource(
    "test_cifar.zip",
    save_to_dir="example_data",
    expand=False,  # download, but keep as a .zip
    cache_dir="../../.cache",
)

# Load the ground-truth labels. Only the test labels are needed, so the training
# split and the test images returned by load_data() are discarded immediately
# rather than kept alive as module globals for the rest of the script.
_, (_, test_labels) = cifar10.load_data()
# The first 10 test labels match the images in test_cifar.zip.
# flatten() already returns a fresh 1-D copy, so no extra np.array() copy is needed.
y = np.asarray(test_labels[:10]).flatten()


# Run the local test dataset through this model
print("Running Inference against trained CIFAR model...")
result = alg_cifar.infer(
    infer_data,
    params={"batch_size": 32, "security": "fed", "data_type": "image"},
)

if result:
    # Flatten the returned table into a 1-D array of predictions so it can be
    # compared element-wise against the ground-truth labels in `y`.
    answer = result.table.load(header=True).to_numpy().flatten()

    print("\nInference results:")
    print(f"     [{' '.join(str(x) for x in answer)}]")
    print("Truth:")
    print("    ", y)
    print()

    print("Classification report:")
    # zero_division=0 yields the same values as the default ("warn") but does not
    # emit UndefinedMetricWarning for classes absent from this tiny sample, so
    # the report no longer depends on the global warnings filter to stay quiet.
    print(classification_report(y, answer, zero_division=0))
else:
    print("Inference failed")
