#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import os

import tripleblind as tb


# Authenticate as example user 3; example=True is passed straight through to
# tb.initialize (presumably selects the bundled example configuration —
# confirm against the SDK docs).
tb.initialize(example=True, api_token=tb.config.example_user3["token"])

#############################################################################
# Find the datasets for training
#############################################################################

# The two halves of CIFAR-10 are owned by different example teams, so each one
# is looked up in the Router index by both name and owning team.  Exporting
# TB_TEST_SMALL switches the lookup from the "EXAMPLE" to the "TEST" datasets.
if "TB_TEST_SMALL" in os.environ:
    prefix = "TEST"
else:
    prefix = "EXAMPLE"

dataset_train0 = tb.Asset.find(
    f"{prefix} - CIFAR-10, part 1",
    owned_by=tb.config.example_user1["team_id"],
)
dataset_train1 = tb.Asset.find(
    f"{prefix} - CIFAR-10, part 2",
    owned_by=tb.config.example_user2["team_id"],
)

################################################################################
# Define the neural network architecture
#
# Architectures can be built explicitly, or standard architectures can be
# generated from a template via the model factory.  Alternatively, a prebuilt
# model could be used, as shown in the Transfer_Learning example.
# ##############################################################################
# Name under which the trainable network is created on the Router.
training_model_name = "example-cifar-network-trainer"

# Selects the VGG template (True) or the small hand-built network (False).
use_VGG = False
if use_VGG:
    # A complex VGG architecture can be created with just a few lines of code.
    # Use models.list_vgg_types() to see all options.
    # NOTE(review): `models` is not defined in this script; the helper is
    # presumably reachable through tb.ModelFactory — confirm the exact name.
    builder = tb.ModelFactory.vgg(
        vgg_type="vgg11",
        num_classes=10,
        batch_norm=False,
        dropout=0.0,
    )
else:
    # By default this example uses the alternative explicit network builder
    # mechanism to create the architecture layer by layer.  This architecture is
    # much smaller than the VGG, even though it takes more lines of code.
    #
    # Shape flow, assuming add_conv2d_layer(in_ch, out_ch, kernel, stride) with
    # no padding (TODO confirm the argument order against the SDK):
    #   32x32x3 -> conv 30x30x32 -> conv 28x28x32 -> pool 14x14x32
    #           -> conv 12x12x64 -> conv 10x10x64 -> pool 5x5x64
    # so the flattened feature vector has 64 * 5 * 5 = 1600 elements, which is
    # why the first dense layer below takes 1600 inputs.
    builder = tb.NetworkBuilder()
    builder.add_conv2d_layer(3, 32, 3, 1)
    builder.add_relu()
    builder.add_conv2d_layer(32, 32, 3, 1)
    builder.add_relu()
    builder.add_max_pool2d_layer(kernel_size=(2, 2), stride=2)
    builder.add_dropout(0.25)
    builder.add_conv2d_layer(32, 64, 3, 1)
    builder.add_relu()
    builder.add_conv2d_layer(64, 64, 3, 1)
    builder.add_relu()
    builder.add_max_pool2d_layer(kernel_size=(2, 2), stride=2)
    builder.add_dropout(0.25)
    builder.add_flatten_layer()
    # Required split layer; presumably marks where the network is divided
    # between the data-owner and model-owner sides for split learning — confirm.
    builder.add_split()  # required split layer
    builder.add_dense_layer(1600, 512)
    builder.add_relu()
    # 10 outputs: one per CIFAR-10 class (matches num_classes=10 above).
    builder.add_dense_layer(512, 10)

# Create the trainable network from whichever builder was selected above.
training_model = tb.create_network(training_model_name, builder)


#############################################################################
# Perform the training
#############################################################################

# Bring every training image to a common format before it reaches the network:
# "label" is the target column, pixels become float32, images are resized to
# 32x32 and reordered channels-first.
normalize_images = (
    tb.ImagePreprocessor.builder()
    .target_column("label")
    .dtype("float32")
    .resize(32, 32)
    .channels_first()
)

# Adam optimizer settings.
adam_params = {"lr": 0.001}

# Cyclic learning-rate schedule oscillating between base_lr and max_lr.
# Assuming PyTorch CyclicLR semantics: step_size is iterations per half cycle,
# and "triangular2" halves the cycle amplitude after each cycle.
cyclic_lr_params = dict(
    step_size=30,
    base_lr=0.0001,
    max_lr=0.001,
    mode="triangular2",
)

result = training_model.train(
    # Data: one CIFAR-10 partition from each owning team.
    data=[dataset_train0, dataset_train1],
    data_type="image",
    data_shape=[32, 32, 3],  # 32 x 32 pixels x 3 color channels (RGB)
    preprocessor=normalize_images,
    # Objective, optimizer and learning-rate schedule.
    loss_name="CrossEntropyLoss",
    optimizer_name="Adam",
    optimizer_params=adam_params,
    lr_scheduler_name="CyclicLR",
    lr_scheduler_params=cyclic_lr_params,
    # Run settings.  test_size is presumably the fraction of data held out for
    # evaluation; delete_trainer=True presumably removes the intermediate
    # trainer asset after training — confirm both against the SDK docs.
    epochs=2,
    test_size=0.1,
    batch_size=64,
    model_output="multiclass",
    delete_trainer=True,
)

# The asset produced by training; its UUID identifies the model from now on.
trained_network = result.asset

# Print the asset ID in a banner so it is easy to spot in the console output.
# Only the ID line interpolates a value, so only it needs to be an f-string.
print()
print("Trained Network Asset ID:")
print("    ===============================================")
print(f"    ===>  {trained_network.uuid} <===")
print("    ===============================================")
print("    Algorithm: Deep Learning Model")
print()
# Save for use in 2b_fed_inference.py / 2c_smpc_inference.py
tb.util.save_to("model_asset_id.out", trained_network.uuid)

# Add an agreement granting example user 1's team EXECUTE access to the model,
# so that team can run inferences with it later.
trained_network.add_agreement(
    with_team=tb.config.example_user1["team_id"], operation=tb.Operation.EXECUTE
)

# Pull down the model for local validation; saved as local.zip, replacing any
# existing file of that name.
local_filename = trained_network.retrieve(
    save_as="local.zip", overwrite=True, show_progress=True
)
print("Trained network has been downloaded as:")
print(f"   {local_filename}")

print("Ready to run local inference.")

# OPTIONAL: Create an agreement which allows another team to
# use this trained model in subsequent steps.  This can also be done via
# the Router's web interface later.
#
# TEAM_ID = tb.config.example_user1["team_id"]
# trained_network.add_agreement(
#     with_team=TEAM_ID, operation=tb.Operation.EXECUTE
# )

# Alternatively, instead of using add_agreement(), which automatically
# accepts all uses of your algorithm by the given party, you can use
# the Asset.publish_to_team() mechanism.  This allows the other party to _see_
# the algorithm and start a job, but in order to use it an explicit
# Access Request must still be granted by the owner.
#
# trained_network.publish_to_team(TEAM_ID)

# Of course, you can use the following to allow _everyone_ to see the
# algorithm on the Router's index and request to use it.
#
# trained_network.is_discoverable = True
