#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import os
import sys
from datetime import datetime

import tripleblind as tb


# Presumably switches the working directory to this script's folder so the
# relative files used below ("run_id.txt", the *.out files) resolve there.
tb.util.set_script_dir_current()
# Identifier shared by every script in this folder; edit "run_id.txt" to change it.
run_id = tb.util.read_run_id()


##########################################################################
# GET AUTHENTICATION TOKENS AND ESTABLISH CONNECTION TO THE ROUTER
#
# IN THIS INSTANCE WE ARE TREATING ORGANIZATION-ONE AS "SANTANDER"
#
# Open the default session with the first example user's token.  Operations
# below that do not name another session run through this one, i.e. as the
# user 'organization_one'.
tb.initialize(api_token=tb.config.example_user1["token"])

#############################################################################
# Validate that the datasets are available

# Look up the training dataset positioned by step 1.  It is owned by the
# second example user's team, not by the session user.
dataset_train0 = tb.Asset.find(
    f"train-lstm-{run_id}",
    owned_by=tb.config.example_user2["team_id"],
)
if not dataset_train0:
    # Nothing to train on yet: tell the user which step to run and abort.
    print("Datasets not found.\nYou must run 1_position_data_on_accesspoint.py first")
    sys.exit(1)


#############################################################################
# Define the neural network we want to train
#############################################################################

training_model_name = "example-lstm-trainer"

# Layer sizes.  Assumed to follow torch.nn.LSTM's (input_size, hidden_size)
# argument order and a (in_features, out_features) dense layer — TODO confirm
# against the NetworkBuilder API.
_LSTM_IN = 39
_LSTM_HIDDEN = 100
_DENSE_OUT = 39

builder = tb.NetworkBuilder()
builder.add_lstm_layer(_LSTM_IN, _LSTM_HIDDEN, batch_first=True)
# The builder requires a split layer between the LSTM and the dense head.
builder.add_split()
builder.add_dense_layer(_LSTM_HIDDEN, _DENSE_OUT)

# Register the network definition under the name above; the resulting asset
# is used as the training operation and archived after training.
training_model = tb.create_network(training_model_name, builder)

#############################################################################
# Designate the files to use and train the network
#

# Loss function names must be consistent with PyTorch.
#   See: https://pytorch.org/docs/stable/nn.html#loss-functions
# Currently tested: 'BCEWithLogitsLoss', 'NLLLoss', 'CrossEntropyLoss'
loss_name = "CrossEntropyLoss"

# Optimizer names must be consistent with PyTorch.
#   See: https://pytorch.org/docs/stable/optim.html
# Currently tested: 'SGD', 'Adam', 'Adadelta'
optimizer_name = "Adam"
# Optimizer constructor keyword arguments, sent as optimizer_meta["params"].
optimizer_params = {"lr": 0.01}

# Use the NumPy input preprocessor to describe the training data: features are
# read as float32 and the "label_path" column supplies the int64 targets.
# NOTE(review): target_is_path(True) presumably means that column holds paths
# to label files rather than label values — confirm.
b = tb.preprocessor.numpy_input.NumpyInputPreprocessor.builder()
b.target_column("label_path")
b.target_is_path(True)
b.dtype("float32")
b.target_dtype("int64")

# Setting TB_TEST_SMALL (any value) in the environment shrinks training to a
# single epoch for quick test runs.
epochs = 1 if "TB_TEST_SMALL" in os.environ else 1000

job = tb.create_job(
    # e.g. "LSTM - 2024-01-31 @ 12:00:00.000000"
    job_name=f"LSTM - {str(datetime.now()).replace(' ', ' @ ')}",
    operation=training_model,
    dataset=[dataset_train0],
    preprocessor=b,
    params={
        "epochs": epochs,
        "loss_meta": {"name": loss_name},
        # Pass the optimizer settings through; without "params" the learning
        # rate defined above would never reach the optimizer.
        "optimizer_meta": {"name": optimizer_name, "params": optimizer_params},
        "data_type": "numpy",
        # NOTE(review): documented as the number of data columns; confirm this
        # matches the per-sample shape of the NumPy input.
        "data_shape": [200],
        "batchsize": 1,
        "model_output": "multiclass",
    },
)
print("Training network...")

#############################################################################
# Submit the job, then download the trained model and record its identifiers
#
# Without a successful submission there is nothing to wait for; report it and
# exit non-zero rather than letting the script finish silently with status 0.
if not job.submit():
    print("Job submission failed")
    sys.exit(1)

print(f"Creating network asset under name: {training_model_name}")
job.wait_for_completion()

# Throw away this network definition (no longer needed)
training_model.archive()

if not job.success:
    print("Training failed")
    sys.exit(1)

# The asset holding the trained network produced by the job.
trained_network = job.result.asset

print()
print("Trained Network Asset ID:")
print("    ===============================================")
print(f"    ===>  {trained_network.uuid} <===")
print("    ===============================================")
print("    Algorithm: Deep Learning Model")
print(f"    Job ID:    {job.job_name}")
print()

# Pull down the model for local validation (saved as "local.zip", replacing
# any existing file of that name).
local_filename = trained_network.retrieve(save_as="local.zip", overwrite=True)
print("Trained network has been downloaded as:")
print(f"   {local_filename}")

# Save for use in 3a_local_inference.py
with open("local_model_filename.out", "w", encoding="utf-8") as output:
    output.write(str(local_filename))

# Save for use in 3b_fed_inference / 3c_smpc_inference.py
with open("model_asset_id.out", "w", encoding="utf-8") as output:
    output.write(str(trained_network.uuid))

print("Ready to run local inference.")
print()

# Create an agreement which allows the other team to use this
# trained model in subsequent steps.
agreement = trained_network.add_agreement(
    with_team=tb.config.example_user2["team_id"], operation=tb.Operation.EXECUTE
)
if agreement:
    print("Created Agreement for use of trained Asset.")


############################################################################
# The 'local_filename' variable holds the path of the trained model downloaded
# above (it is also written to "local_model_filename.out"). It could easily be
# passed to an additional step to run the local inference.
