#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import os

import torch

import tripleblind as tb


# Start the SDK session using example_user3's API token.
# NOTE(review): example=True presumably selects the SDK's example
# configuration -- confirm against the tripleblind documentation.
example_token = tb.config.example_user3["token"]
tb.initialize(api_token=example_token, example=True)

#############################################################################
# Find the customer account datasets in the Router index for training
#############################################################################

# Dataset names carry a "TEST" prefix when the TB_TEST_SMALL environment
# variable is set (to any value) and an "EXAMPLE" prefix otherwise.
prefix = "TEST" if "TB_TEST_SMALL" in os.environ else "EXAMPLE"

# One customer database per bank, each owned by a different example team.
dataset_train0 = tb.Asset.find(
    f"{prefix} - SAN Customer Database", owned_by=tb.config.example_user1["team_id"]
)
dataset_train1 = tb.Asset.find(
    f"{prefix} - JPM Customer Database", owned_by=tb.config.example_user2["team_id"]
)
dataset_train2 = tb.Asset.find(
    f"{prefix} - PNB Customer Database", owned_by=tb.config.example_user3["team_id"]
)

# Fail fast with a readable message instead of handing a missing asset to the
# training call further down.
# NOTE(review): assumes tb.Asset.find returns None when nothing matches --
# confirm against the SDK.  If it raises instead, this check is simply inert.
missing_datasets = [
    bank
    for bank, asset in (
        ("SAN", dataset_train0),
        ("JPM", dataset_train1),
        ("PNB", dataset_train2),
    )
    if asset is None
]
if missing_datasets:
    raise SystemExit(
        f"Could not find the '{prefix}' customer database(s) for: "
        f"{', '.join(missing_datasets)}. "
        "Make sure the example datasets have been positioned before training."
    )


#############################################################################
# Define the neural network we want to train
#############################################################################

training_model_name = "example-network-santandar-trainer"


def _add_dense_relu(net, n_in, n_out):
    """Append a dense layer of shape (n_in, n_out) followed by a ReLU to *net*."""
    net.add_dense_layer(n_in, n_out)
    net.add_relu()


# Layer stack: 26 inputs -> 120 -> 160 -> (dropout) -> 200 -> split ->
# 160 -> 10 -> 1 output.  The final dense layer has no activation after it.
builder = tb.NetworkBuilder()

# Layers added before the split.
_add_dense_relu(builder, 26, 120)
_add_dense_relu(builder, 120, 160)
builder.add_dropout(0.25)
_add_dense_relu(builder, 160, 200)

builder.add_split()  # required split layer

# Layers added after the split.
_add_dense_relu(builder, 200, 160)
_add_dense_relu(builder, 160, 10)
builder.add_dense_layer(10, 1)

# Turn the layer description into a network object under the chosen name.
training_model = tb.create_network(training_model_name, builder)

#############################################################################
# Train the network
#############################################################################

# Use a TabularPreprocessor to specify what data to use for training and
# which column ("target") to treat as the classification label.  All three
# banks use the same format for their data, so no further preprocessing is
# necessary.
csv_pre = tb.TabularPreprocessor.builder()
csv_pre = csv_pre.add_column("target", target=True)
csv_pre = csv_pre.all_columns(True)
csv_pre = csv_pre.dtype("float32")

# Single-element int32 tensor holding 17, encoded for transport and passed to
# the loss function as its "pos_weight" parameter.
pos_weight = tb.TorchEncoder.encode(torch.tensor([17], dtype=torch.int32))

# Loss function and optimizer settings, named here to keep the call readable.
loss_settings = {"pos_weight": pos_weight}
optimizer_settings = {"lr": 0.001}

# Timestamped job name so each run is labelled distinctly.
job_label = f"Tabular_Data_Example - {tb.util.timestamp()}"

result = training_model.train(
    # Input data
    data=[dataset_train0, dataset_train1, dataset_train2],
    data_type="table",
    data_shape=[26],  # number of columns of data in table
    preprocessor=csv_pre,
    # Loss function
    loss_name="BCEWithLogitsLoss",
    loss_params=loss_settings,
    # Optimizer and parameters
    optimizer_name="Adam",
    optimizer_params=optimizer_settings,
    # Training options
    epochs=1,
    test_size=0.2,
    model_output="binary",
    delete_trainer=True,
    job_name=job_label,
)


###########################################################################
# Report the trained network asset and record its ID on disk
trained_network = result.asset

# Emit the whole banner with a single print; the leading and trailing empty
# strings produce the blank lines around it.
asset_banner = [
    "",
    "Trained Network Asset ID:",
    "    ===============================================",
    f"    ===>  {trained_network.uuid} <===",
    "    ===============================================",
    "    Algorithm: Deep Learning Model",
    "",
]
print("\n".join(asset_banner))

# Save trained asset for use in 2b_fed_inference.py / 2c_smpc_inference.py
tb.util.save_to("model_asset_id.out", trained_network.uuid)


# Grant example_user1's team EXECUTE permission on the trained model so it
# can be used by that team in later inferences.
inference_team_id = tb.config.example_user1["team_id"]
trained_network.add_agreement(
    with_team=inference_team_id,
    operation=tb.Operation.EXECUTE,
)

# Download the trained model to "local.zip" (replacing any existing file) for
# local validation.
local_filename = trained_network.retrieve(save_as="local.zip", overwrite=True)
print("Trained network has been downloaded as:", f"   {local_filename}", sep="\n")

# Open the downloaded package and list any misclassified test cases it holds.
pack = tb.Package.load(local_filename)
classification_errors = pack.get_model_misclassifications()

if classification_errors:
    print("\nUnrestricted information about misclassified test cases:")
    for error_frame in classification_errors:
        # Print every row and column, with a trailing dimensions summary and
        # without the index column.
        print(error_frame.to_string(show_dimensions=True, index=False))
    print()


print("Ready to run local inference.\n")
