#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import sys

import tripleblind as tb


# Session setup and dataset lookup.
#
# The run id read here is embedded in every asset name searched for below.
tb.util.set_script_dir_current()
run_id = tb.util.read_run_id()

# tb.initialize() is given example_user1's token; example_user2 and
# example_user3 get explicit Session objects, which are used later in this
# script when those users grant access to their inference datasets.
tb.initialize(api_token=tb.config.example_user1["token"])
session2 = tb.Session(api_token=tb.config.example_user2["token"], from_default=True)
session3 = tb.Session(api_token=tb.config.example_user3["token"], from_default=True)

# Teams that own the datasets: user2's team for the "random multimodal"
# assets, user3's team for the CT assets.
_team2_id = tb.config.example_user2["team_id"]
_team3_id = tb.config.example_user3["team_id"]

# Training datasets.
asset0 = tb.Asset.find(f"Random Multimodal-{run_id}", owned_by=_team2_id)
asset1 = tb.Asset.find(f"train_ct_multimodal-{run_id}", owned_by=_team3_id)

# Inference (test) datasets.
inference_asset0 = tb.Asset.find(
    f"random-multimodal-test-{run_id}", owned_by=_team2_id
)
inference_asset1 = tb.Asset.find(f"test_ct_multimodal-{run_id}", owned_by=_team3_id)

# Every one of the four lookups must have produced a truthy result,
# otherwise the script stops with a non-zero exit status.
if not all((asset0, asset1, inference_asset0, inference_asset1)):
    print("Datasets not found.")
    print("You must run 1_position_data_on_accesspoint.py first")
    sys.exit(1)


# First client network: a dense stack 100 -> 120 -> 160 -> 200 with ReLU
# activations and a 0.25 dropout before the last dense layer.
# NOTE(review): presumably paired by position with the first dataset and
# preprocessor (the tabular one) passed to tb.create_job -- confirm.
builder0 = tb.NetworkBuilder()
builder0.add_dense_layer(100, 120)
builder0.add_relu()
builder0.add_dense_layer(120, 160)
builder0.add_relu()
builder0.add_dropout(0.25)
builder0.add_dense_layer(160, 200)
builder0.add_relu()
builder0.add_split()  # required split layer

# Second client network: two conv -> batchnorm -> ReLU -> max-pool stages
# (1 -> 32 -> 64 channels) followed by a flatten.
# NOTE(review): presumably paired by position with the second dataset and
# preprocessor (the image one) passed to tb.create_job -- confirm.
builder1 = tb.NetworkBuilder()
builder1.add_conv2d_layer(1, 32, 3, 1)
builder1.add_batchnorm2d(32)
builder1.add_relu()
builder1.add_max_pool2d_layer(2, 2)
builder1.add_conv2d_layer(32, 64, 3, 1)
builder1.add_batchnorm2d(64)
builder1.add_relu()
builder1.add_max_pool2d_layer(2, 2)
builder1.add_flatten_layer()
builder1.add_split()  # required split layer


# Server network: dense stack 12744 -> 1000 -> 128 -> 1.
# NOTE(review): 12744 looks like the concatenated client outputs:
#   200 (builder0) + 64 * 14 * 14 = 12544 (builder1 on a 1x64x64 image,
#   assuming unpadded 3x3 convolutions with stride 1 and 2x2 pooling:
#   64 -> 62 -> 31 -> 29 -> 14).
# Confirm against the SDK's argument order and the image preprocessor's
# resize before changing any layer above.
server_builder = tb.NetworkBuilder()
server_builder.add_dense_layer(12744, 1000)
server_builder.add_relu()
server_builder.add_dense_layer(1000, 128)
server_builder.add_relu()
server_builder.add_dense_layer(128, 1)


# Combine the server network and the two client networks into one
# vertical network definition.
model = tb.create_vertical_network(
    "vertical_network", server=server_builder, clients=[builder0, builder1]
)


# Training hyper-parameters.
#
# Loss function and optimizer names must be consistent with PyTorch:
#   losses:     https://pytorch.org/docs/stable/nn.html#loss-functions
#   optimizers: https://pytorch.org/docs/stable/optim.html
# Currently tested losses:     'BCEWithLogitsLoss', 'NLLLoss', 'CrossEntropyLoss'
# Currently tested optimizers: 'SGD', 'Adam', 'Adadelta'
loss_name = "SmoothL1Loss"
optimizer_name = "Adam"
optimizer_params = {"lr": 0.001}

# Preprocessor for the tabular dataset: use all columns, as float32.
csv_pre = tb.TabularPreprocessor.builder()
csv_pre = csv_pre.all_columns(True)
csv_pre = csv_pre.dtype("float32")

# Preprocessor for the image dataset. The builder calls are applied in the
# same order as before: target column, 64x64 resize, grayscale conversion,
# channels-first layout, float32 data, DICOM input, float32 target.
image_pre = tb.ImagePreprocessor.builder()
image_pre = image_pre.target_column("label").resize(64, 64)
image_pre = image_pre.convert("L").channels_first()  # "L": use grayscale
image_pre = image_pre.dtype("float32").dicom(True).target_dtype("float32")

# Parameters handed to the training job.
training_params = {
    "epochs": 10,
    "batchsize": 32,
    "test_size": 0.10,
    "loss_meta": {"name": loss_name},
    "optimizer_meta": {"name": optimizer_name, "params": optimizer_params},
    "data_type": "image",
    # number of columns of data in table
    # NOTE(review): builder0's first dense layer takes 100 inputs while this
    # says 200, and data_type is "image" -- confirm what data_shape refers to.
    "data_shape": [200],
    "model_output": "regression",  # binary/multiclass/regression
}

# Datasets and preprocessors are given as parallel lists.
job = tb.create_job(
    job_name="CT Vertical Training",
    operation=model,
    dataset=[asset0, asset1],
    preprocessor=[csv_pre, image_pre],
    params=training_params,
)

# Submit the training job, wait for it, and share the results.
#
# Everything below the guards needs the trained asset, so a failed submit
# or a failed job stops the script with a non-zero exit status instead of
# falling through to code that dereferences job.result.asset.
if not job.submit():
    print("Failed to submit the training job.")
    sys.exit(1)

job.wait_for_completion()
# Archive the network definition once the job has finished; as before,
# this runs whether or not training succeeded.
model.archive()

if not job.success:
    print("Training failed.")
    sys.exit(1)

trained_asset = job.result.asset
print(f"Asset created: {trained_asset.uuid}")

# Save for use in 3b_aes_inference.py
tb.util.save_to("asset_id.out", trained_asset.uuid)

# Create an agreement which allows the other team to use this trained
# model in subsequent steps.
agreement = trained_asset.add_agreement(
    with_team=tb.config.example_user2["team_id"], operation=tb.Operation.EXECUTE
)
if agreement:
    print("Created Agreement for use of trained Asset.")

# Create agreements to allow the other team to use the
# inference datasets to run inferences against this model.
# Each agreement is created through the session of the user whose team
# owns that dataset, and names the trained asset as the permitted operation.
inference_asset0.add_agreement(
    with_team=tb.config.example_user1["team_id"],
    operation=trained_asset,
    session=session2,
)
inference_asset1.add_agreement(
    with_team=tb.config.example_user1["team_id"],
    operation=trained_asset,
    session=session3,
)
