#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

from pathlib import Path

import numpy as np
import pandas as pd

import tripleblind as tb


# Notes on params:
#
# This example can illustrate how different parameters impact the output
# results.  Two specific parameters control the format of the final output,
# model_output and final_layer_softmax.
#
#   model_output: Controls how the inference value is reported.
#
#       "" or undefined - reports the raw inference value
#
#         or specify an output type, depending on the model
#
#       "binary"        - rounds to a simple 0.0 or 1.0 value
#       "multiclass"    - rounds to the output class values
#       "regression"    - returns raw output
#
#   final_layer_softmax: If True, applies a softmax as a final layer.  This
#                        results in a probability distribution. Default if not
#                        specified is False.
#
#                        SMPC does not support softmax and it will be skipped
#                        in SMPC evaluation.  However the data provider can
#                        perform this as a final layer after evaluating a model
#                        via SMPC.  This retains privacy and IP protections.
#
# Here are examples of output from changing these values:
#    model_output: "", final_layer_softmax: True
#         0.999969   0.000031
#
#    model_output: "", final_layer_softmax: False
#         6.28534   -4.10021
#
#    model_output: "binary", final_layer_softmax: False
#         1.0        0.0
#
#    model_output: "binary", final_layer_softmax: True
#         1.0        1.0.1
#    NOTE(review): the second value in the last example looks garbled in the
#    original text; confirm the expected output.


def _run_inference(model_asset, preprocessor, security, save_as):
    """Run one inference job on ``inference_test_data.npy`` and return its records.

    The result asset is downloaded to ``save_as``, overwriting any existing
    file. It is then loaded with ``tb.Package.load`` and its ``records()``
    are returned.

    Args:
        model_asset: The ``tb.ModelAsset`` to run inference with.
        preprocessor: Preprocessor builder forwarded to ``model_asset.infer``.
        security: Value for the "security" param; this script passes
            "smpc" or "fed".
        save_as: Local filename for the downloaded result package.

    Returns:
        The value of ``tb.Package.load(save_as).records()``.
    """
    job = model_asset.infer(
        "inference_test_data.npy",
        preprocessor=preprocessor,
        params={
            "security": security,
            "data_type": "numpy",
            "model_output": "",  # "binary" for 1/0; nothing or "" for raw outputs
            "final_layer_softmax": False,
            # "batch_size": 2, # Only used for FED inference.
        },
        silent=True,
    )
    job.asset.retrieve(save_as=save_as, overwrite=True)
    return tb.Package.load(save_as).records()


tb.initialize(api_token=tb.config.example_user2["token"], example=True)

for model_type in ("Keras", "PyTorch", "ONNX"):
    model_key = model_type.lower()
    asset_id_filename = f"{model_key}_asset_id.out"
    # Models that were never positioned are skipped rather than treated as errors.
    if not Path(asset_id_filename).exists():
        print(f"{model_key} model not found, skipping.")
        continue

    asset_id = tb.util.load_from(asset_id_filename)
    try:
        model_asset = tb.ModelAsset(asset_id)
    except Exception as err:
        # Narrowed from a bare `except:` so Ctrl-C is not rewritten as this
        # message; the underlying SDK error is kept as the cause.
        raise SystemExit(
            "You must run 1_position_model_on_accesspoint.py first."
        ) from err

    # Build a simple .npy test data file. The Keras model gets a
    # (1, 32, 32, 3) array, the others (1, 3, 32, 32) -- presumably
    # channels-last vs channels-first; confirm against the model definitions.
    if model_key == "keras":
        test_input = np.random.random((1, 32, 32, 3))
    else:
        test_input = np.random.random((1, 3, 32, 32))
    np.save("inference_test_data.npy", test_input)
    pre = tb.preprocessor.numpy_input.NumpyInputPreprocessor.builder().dtype("float32")

    print(f"\n{model_asset.name} Model Inference...")

    # Run the same input through both security modes so the outputs can be
    # compared side by side.
    smpc_records = _run_inference(model_asset, pre, "smpc", "smpc_out.zip")
    fed_records = _run_inference(model_asset, pre, "fed", "fed_out.zip")

    df = pd.DataFrame(
        np.vstack([smpc_records, fed_records]),
        index=["SMPC", "FED"],
    )
    print(df)
