#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.


from sklearn.metrics import classification_report

import tripleblind as tb


# Authenticate against the TripleBlind example environment as example user 1.
user_token = tb.config.example_user1["token"]
tb.initialize(api_token=user_token, example=True)

# The 2_model_train.py step saved the trained model's Asset ID to this file;
# wrap that ID in a ModelAsset handle used for inference below.
trained_asset_id = tb.util.load_from("model_asset_id.out")
alg_cifar = tb.ModelAsset(trained_asset_id)


###########################################################################
# Image to be classified.
# - Any image file (.png, .jpg, .bmp) can potentially be used
# - The image should contain one of the things listed in class_names below
# - An image of an airplane is used by default for convenience
###########################################################################
inference_image = tb.util.download_tripleblind_resource(
    "Flying-airplane.jpg", save_to_dir="example_data", cache_dir="../../.cache"
)

# Human-readable labels; the predicted class number is used as an index here.
class_names = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

# Ground truth for the default sample image, which shows an airplane.
expected = [class_names.index("airplane")]


# This variation of SMPC inference sends the data for the inference as a
# single image file, via the job's dataset.  This file is used to create a
# temporary Image Package, since the data is an image.  This implies an
# automatic ImagePreprocessor.
result = alg_cifar.infer(
    inference_image,
    params={"security": "smpc", "data_type": "image"},
)
if not result:
    raise SystemError("SMPC Inference failed")

# Defaults reported when the result cannot be turned into a valid class.
# Binding them before the `try` guarantees the prints below never hit an
# unbound name, even if loading the result table itself fails.
answer = None
name = "unrecognized"
inferred = []
try:
    # The first cell of the loaded table is treated as the predicted class
    # index (header=True presumably makes the first row the column header).
    answer = int(result.table.load(header=True).values[0][0])
    if not 0 <= answer < len(class_names):
        # A negative index would otherwise silently wrap around the list
        # and report the wrong label.
        raise IndexError(f"class index {answer} out of range")
    name = class_names[answer]
    inferred = [answer]
except Exception as e:
    # Best effort: report the problem and fall through to the defaults.
    print(f"Could not interpret the inference result: {e}")

print("\nSMPC Inference result:")
print(f"      {answer} - {name}")
print("Truth:")
print("    ", expected)
print()
if inferred:
    print(classification_report(expected, inferred, zero_division=0))
else:
    # classification_report requires y_true and y_pred of equal length, so
    # there is nothing meaningful to score without a prediction.
    print("No valid prediction; skipping classification report.")
