#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import datetime
import os
import sys

import tripleblind as tb


# Identifier shared by every script in this folder; it is read from
# "run_id.txt", so edit that file to change it.
run_id = tb.util.read_run_id()

tb.util.set_script_dir_current()

##########################################################################
# AUTHENTICATE AND CONNECT TO THE MARKETPLACE
#
# Open the default session against the TripleBlind market instance with the
# first example user's API token.  Unless stated otherwise, every operation
# below runs through this session as the user 'organization_one'.
first_user = tb.config.example_user1
tb.initialize(api_token=first_user["token"])


# Look up the two training datasets by name: one owned by the first example
# user's team, the other by the second example user's team.
dataset_train0 = tb.Asset.find(f"train0_mnist-{run_id}", owned_by=first_user["team_id"])
dataset_train1 = tb.Asset.find(
    f"train1_mnist-{run_id}",
    owned_by=tb.config.example_user2["team_id"],
)

# Both datasets are required; stop with a hint if either lookup came back empty.
if not (dataset_train0 and dataset_train1):
    print(
        "Datasets not found.",
        "You must run 1_position_data_on_accesspoint.py first",
        sep="\n",
    )
    sys.exit(1)


# Define the neural network we want to use for training
training_model_name = "example-mnist-network-trainer"

# Layer stack: two (conv2d -> batchnorm2d -> ReLU -> max-pool) stages, then a
# flatten and a two-layer dense head ending in 10 outputs.
# NOTE(review): the conv arguments are presumably (in_channels, out_channels,
# kernel_size, stride); under that reading, with no padding, a 28x28 input
# reaches the flatten as 64 x 5 x 5 = 1600 values, which matches the first
# dense layer's input size -- confirm against the NetworkBuilder docs.
builder = tb.NetworkBuilder()
builder.add_conv2d_layer(1, 32, 3, 1)
builder.add_batchnorm2d(32)
builder.add_relu()
builder.add_max_pool2d_layer(2, 2)
builder.add_conv2d_layer(32, 64, 3, 1)
builder.add_batchnorm2d(64)
builder.add_relu()
builder.add_max_pool2d_layer(2, 2)
builder.add_flatten_layer()
# A stray trailing comma used to follow this call, wrapping its result in a
# discarded one-element tuple; it was harmless but clearly unintended.
builder.add_dense_layer(1600, 128)
builder.add_relu()
builder.add_dense_layer(128, 10)

# Register the definition as a network asset flagged for federated learning.
training_model = tb.create_network(
    training_model_name, builder, is_federated_learning_model=True
)

# The loss function is selected by its PyTorch class name.
#   See: https://pytorch.org/docs/stable/nn.html#loss-functions
# Currently tested: 'BCEWithLogitsLoss', 'NLLLoss', 'CrossEntropyLoss'
loss_name = "CrossEntropyLoss"

# The optimizer is likewise selected by its PyTorch class name.
#   See: https://pytorch.org/docs/stable/optim.html
# Currently tested: 'SGD', 'Adam', 'Adadelta'
optimizer_name = "Adam"
optimizer_params = dict(lr=0.001)

# Learning-rate schedule and its settings, passed along with the job.
lr_scheduler_name = "CyclicCosineDecayLR"
lr_scheduler_params = dict(
    init_decay_epochs=10,
    min_decay_lr=0.0001,
    restart_interval=3,
    restart_interval_multiplier=1.5,
    restart_lr=0.01,
)

# Setting TB_TEST_SMALL to any non-empty value shrinks the run to a single
# epoch and a single federated round.
_small_test_run = bool(os.environ.get("TB_TEST_SMALL"))
epochs = 1 if _small_test_run else 2  # Increase to improve model accuracy
federated_rounds = 1 if _small_test_run else 2

# Build the image preprocessor one step at a time.  Each call's return value
# is fed into the next, exactly as in a fluent chain.
image_pre = tb.ImagePreprocessor.builder()
image_pre = image_pre.target_column("label")
image_pre = image_pre.resize(28, 28)
image_pre = image_pre.convert("L")  # use grayscale
image_pre = image_pre.channels_first()
image_pre = image_pre.dtype("float32")

# Assemble the training job.  The job name embeds the current local time,
# with the space between date and time replaced by " @ ".
job_timestamp = str(datetime.datetime.now()).replace(" ", " @ ")

# Everything the training operation needs besides the data and the model.
training_params = {
    "epochs": epochs,
    "test_size": 0.2,
    "batchsize": 64,
    "loss_meta": {"name": loss_name},
    "optimizer_meta": {"name": optimizer_name, "params": optimizer_params},
    "lr_scheduler_meta": {"name": lr_scheduler_name, "params": lr_scheduler_params},
    "data_type": "image",
    # image data: width, height, color bytes (1 == grayscale)
    "data_shape": [28, 28, 1],
    "model_output": "multiclass",  # binary/multiclass/regression
    "federated_rounds": federated_rounds,
}

job = tb.create_job(
    job_name=f"MNIST - {job_timestamp}",
    operation=training_model,
    dataset=[dataset_train0, dataset_train1],
    preprocessor=image_pre,
    params=training_params,
)

# Submit the job, wait for it to finish, then publish and download the
# trained model.  Any failure ends the script with a non-zero exit status.
print("Training network...")
if not job.submit():
    # A rejected submission used to fall through silently with exit status 0,
    # so a failed run looked like a successful one.
    print("Job submission failed")
    sys.exit(1)

print(f"Creating network asset under name: {training_model_name}")
job.wait_for_completion()

# Throw away this network definition (no longer needed)
training_model.archive()

if not job.success:
    print("Training failed")
    sys.exit(1)

trained_network = job.result.asset

print()
print("Trained Network Asset ID:")
print("    ===============================================")
print(f"    ===>  {trained_network.uuid} <===")
print("    ===============================================")
print("    Algorithm: Deep Learning Model")
print(f"    Job ID:    {job.job_name}")
print()

# Save for use in 3b_marketplace_inference.py
with open(tb.config.script_dir / "asset_id.out", "w") as output:
    output.write(str(trained_network.uuid))

# Create an agreement which allows the other team to use this
# trained model in subsequent steps.
agreement = trained_network.add_agreement(
    with_team=tb.config.example_user2["team_id"], operation=tb.Operation.EXECUTE
)
if agreement:
    print("Created Agreement for use of trained model Asset.")

# Pull down the model for local validation
local_filename = trained_network.retrieve(save_as="local_model.zip", overwrite=True)
print("Trained network has been downloaded as:")
print(f"   {local_filename}")

# Save for use in 3b_marketplace/smpc_inference.py
with open(tb.config.script_dir / "model_asset_id.out", "w") as output:
    output.write(str(trained_network.uuid))

print("Ready to run local inference.")
print()
print("NOTE: Marketplace inference requires creating an Agreement between this")
print(
    f"      user ('{tb.config.example_user1['login']}') and the other user before the second"
)
print("      user is able to infer against this model.")
