#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import os
import sys
from datetime import datetime

import torch

import tripleblind as tb


# Session setup: the script authenticates with the example_user3 token.
tb.util.set_script_dir_current()
tb.initialize(api_token=tb.config.example_user3["token"])

#############################################################################
# Find the customer account datasets in the Router index for training

# Asset names start with "TEST" when TB_TEST_SMALL is present in the
# environment (its value is not inspected), otherwise with "EXAMPLE".
if "TB_TEST_SMALL" in os.environ:
    prefix = "TEST"
else:
    prefix = "EXAMPLE"

# One training dataset per example team; each lookup is restricted to the
# team that owns the asset.
hope_valley_name = f"{prefix} - Hope Valley Hospital Database"
cms_name = f"{prefix} - CMS Hospital Database"
nhs_name = f"{prefix} - NHS Hospital Database"

dataset_train0 = tb.Asset.find(
    hope_valley_name, owned_by=tb.config.example_user1["team_id"]
)
dataset_train1 = tb.Asset.find(
    cms_name, owned_by=tb.config.example_user2["team_id"]
)
dataset_train2 = tb.Asset.find(
    nhs_name, owned_by=tb.config.example_user3["team_id"]
)

# Training needs all three datasets; stop if any lookup came back empty.
if not all((dataset_train0, dataset_train1, dataset_train2)):
    print("Datasets not found.")
    sys.exit(1)


#############################################################################
# Define the neural network we want to train
#############################################################################

training_model_name = "example-network-trainer"

# The network is a stack of dense layers: 26 inputs, widening to 200 units
# before the split layer, then narrowing to a single output.  Every dense
# layer except the last is followed by a ReLU.
builder = tb.NetworkBuilder()

# Layers ahead of the dropout.
for n_in, n_out in ((26, 120), (120, 160)):
    builder.add_dense_layer(n_in, n_out)
    builder.add_relu()
builder.add_dropout(0.25)
builder.add_dense_layer(160, 200)
builder.add_relu()

builder.add_split()  # required split layer

# Layers after the split; the final dense layer has no activation.
for n_in, n_out in ((200, 160), (160, 10)):
    builder.add_dense_layer(n_in, n_out)
    builder.add_relu()
builder.add_dense_layer(10, 1)

training_model = tb.create_network(training_model_name, builder)

#############################################################################
# Designate the files to use and train the network
#

# Loss function names must be consistent with PyTorch.
#   See: https://pytorch.org/docs/stable/nn.html#loss-functions
# Currently tested: 'BCEWithLogitsLoss', 'NLLLoss', 'CrossEntropyLoss'
loss_name = "BCEWithLogitsLoss"
# torch.arange(17, 18) yields the one-element int32 tensor [17]; it is encoded
# with tb.TorchEncoder before being placed in the loss parameters.
pos_weight = tb.TorchEncoder.encode(torch.arange(17, 18, dtype=torch.int32))
loss_meta = {"name": loss_name, "params": {"pos_weight": pos_weight}}

# Optimizer names must be consistent with PyTorch.
#   See: https://pytorch.org/docs/stable/optim.html
# Currently tested: 'SGD', 'Adam', 'Adadelta'
optimizer_name = "Adam"
optimizer_params = {"lr": 0.001}
optimizer_meta = {"name": optimizer_name, "params": optimizer_params}

# Use the tabular preprocessor to specify what data to use for training and
# which column to treat as the classification label.  All three hospitals use
# the same format for their data, so no further preprocessing is necessary.
csv_pre = tb.TabularPreprocessor.builder()
csv_pre = csv_pre.add_column("target", target=True)
csv_pre = csv_pre.all_columns(True)
csv_pre = csv_pre.dtype("float32")

training_params = {
    "epochs": 1,
    "loss_meta": loss_meta,
    "optimizer_meta": optimizer_meta,
    "data_type": "table",
    "data_shape": [26],  # number of columns of data in table
    "model_output": "binary",  # binary/multiclass/regression
    "test_size": 0.2,
}
training_datasets = [dataset_train0, dataset_train1, dataset_train2]

# The timestamp in the job name distinguishes one run from another.
job_label = f"Tabular_Data_Example - {datetime.now()}"
job = tb.create_job(
    job_name=job_label,
    operation=training_model,
    dataset=training_datasets,
    preprocessor=csv_pre,
    params=training_params,
)
print("Training network")

###########################################################################
# Submit the training job, then download the trained network asset as a local
# package file for validation

# Submit the training job.  A falsy result from submit() means nothing was
# started, so report it and exit non-zero like the other failure paths in
# this script instead of silently finishing with a success status.
if not job.submit():
    print("Job submission failed")
    sys.exit(1)

print(f"Creating network asset under name: {training_model_name}")
job.wait_for_completion()

# Throw away this network definition (no longer needed).  This happens
# whether or not training succeeded.
training_model.archive()

if not job.success:
    print("Training failed")
    sys.exit(1)

trained_network = job.result.asset

print()
print("Trained Network Asset ID:")
print("    ===============================================")
print(f"    ===>  {trained_network.uuid} <===")
print("    ===============================================")
print("    Algorithm: Deep Learning Model")
print(f"    Job ID:    {job.job_name}")
print()

# Pull down the model for local validation
local_filename = trained_network.retrieve(save_as="local.zip", overwrite=True)
print("Trained network has been downloaded as:")
print(f"   {local_filename}")

pack = tb.Package.load(local_filename)

# Print any misclassified test cases reported by the package; each entry is
# rendered with its to_string() method.
classification_errors = pack.get_model_misclassifications()
if classification_errors:
    print("\nUnrestricted information about misclassified test cases:")
    for df in classification_errors:
        print(df.to_string(show_dimensions=True, index=False))
    print()

# Save the trained asset's ID for use in 3b_aes_inference.py /
# 3c_smpc_inference.py
tb.util.save_to("model_asset_id.out", trained_network.uuid)

print("Ready to run local inference.")
print()


############################################################################
# The 'local_filename' variable holds the value returned by
# trained_network.retrieve() for the locally downloaded trained network. It
# could easily be passed to an additional step to run the local inference.
