#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import os
import pickle
import sys
import warnings
from datetime import datetime
from zipfile import ZipFile

import numpy as np
from lstm_utils import convert_data_for_inference, convert_int_to_char

import tripleblind as tb


# Ignore every warning category (not just scikit-learn's
# "UndefinedMetricWarning") so the per-character inference output stays clean.
warnings.simplefilter("ignore")

# Resolve relative paths (e.g. "model_asset_id.out", "indices.pkl") against
# this script's own directory.
tb.util.set_script_dir_current()

##########################################################################
# GET AUTHENTICATION TOKENS AND ESTABLISH CONNECTION TO THE ROUTER
#
tb.initialize(api_token=tb.config.example_user2["token"])

# Look for a model Asset ID from a previous run of 2_model_train.py.
# Only a missing/unreadable file (OSError) means "not trained yet"; a bare
# `except:` would also swallow KeyboardInterrupt and unrelated errors.
try:
    with open("model_asset_id.out", "r") as f:
        asset_id = f.readline().strip()
except OSError:
    asset_id = ""

# An empty ID file is treated the same as a missing one, rather than letting
# tb.Asset("") fail later with a less helpful error.
if not asset_id:
    print("You must run 2_model_train.py first.")
    sys.exit(1)
alg_lstm = tb.Asset(asset_id)

# Generate characters until the string reaches max_chars (seed included);
# use a shorter run when TB_TEST_SMALL is set in the environment.
max_chars = 5 if "TB_TEST_SMALL" in os.environ else 10
inference_string = "arm"

# Extend the seed one character at a time: each iteration runs one remote
# inference job on the current string and appends the predicted character.
while len(inference_string) < max_chars:
    inference_ary = convert_data_for_inference(inference_string, "indices.pkl")
    np.save("sample_inference.npy", inference_ary)

    # Define a job using this model
    job = tb.create_job(
        job_name="LSTM Inference - " + str(datetime.now()),
        operation=alg_lstm,
        dataset=["sample_inference.npy"],
        params={"security": "smpc"},  # fed or smpc
    )
    if not job:
        print(
            "ERROR: Failed to create the job -- do you have an Agreement to run this?"
        )
        print()
        print(
            f"NOTE: Remote inference requires the user '{tb.config.example_user1['login']}' create an"
        )
        print(
            f"      Agreement on their algorithm asset with user '{tb.config.example_user2['login']}'"
        )
        print(
            f"      ({tb.config.example_user2['name']}) before they can use it to infer.  You can do"
        )
        print("      this on the Router at:")
        print(f"      {tb.config.gui_url}/dashboard/algorithm/{alg_lstm.uuid}")
        sys.exit(1)

    # Submit the remote inference job. If submission fails we must stop:
    # inference_string would never grow and this loop would resubmit forever.
    if not job.submit():
        print("ERROR: Failed to submit the inference job")
        break

    job.wait_for_completion()
    if not job.success:
        print("Inference failed")
        break

    filename = job.result.asset.retrieve(
        save_as="lstm_remote_predictions.zip", overwrite=True
    )
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # payload; this trusts the result archive produced by the Router job.
    with ZipFile(filename, "r") as result_zip:
        result_zip.extractall()
        with result_zip.open("result.pkl") as f:
            result = pickle.load(f)

    # Take the highest-scoring index from the last entry of the result.
    # Presumably this is the prediction for the final input position;
    # TODO confirm the result's shape.
    index = np.argmax(result[-1])
    inference_string += convert_int_to_char(index, "indices.pkl")

# Written even after a failure, so any partial generation is kept.
with open("final_smpc_results.txt", "w") as f:
    f.write(inference_string)

print("\nInference results:")
print("    ", inference_string)
