#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import sys

import torch

import tripleblind as tb


# Script setup: switch to the script directory (per the helper's name) and
# load the run id that is appended to every dataset name below.
tb.util.set_script_dir_current()
run_id = tb.util.read_run_id()

# Authenticate with the token of example user 1.
tb.initialize(api_token=tb.config.example_user1["token"])

# Look up the three training datasets. Each entry pairs a dataset name with
# the example-user config whose team is expected to own that dataset.
asset_0, asset_1, asset_2 = [
    tb.Asset.find(dataset_name, owned_by=owner["team_id"])
    for dataset_name, owner in (
        (f"sant psi vertical 0-40-{run_id}", tb.config.example_user1),
        (f"sant psi vertical 41-100-{run_id}", tb.config.example_user2),
        (f"sant psi vertical 101-200-{run_id}", tb.config.example_user3),
    )
]

# All three datasets are required; stop early if any lookup came back empty.
if not (asset_0 and asset_1 and asset_2):
    print("Datasets not found.")
    print("You must run 1_position_data_on_accesspoint.py first")
    sys.exit(1)

# Client-side networks share one layer stack; only the first layer differs.
def _build_client_network(input_width):
    """Return a NetworkBuilder holding the layer stack used by every client.

    The stack is dense(input_width, 120) -> ReLU -> dense(120, 160) -> ReLU
    -> dropout(0.25) -> dense(160, 200) -> ReLU.

    Args:
        input_width: First argument of the first dense layer (presumably the
            number of input features held by that client -- confirm against
            the NetworkBuilder documentation).
    """
    client = tb.NetworkBuilder()
    client.add_dense_layer(input_width, 120)
    client.add_relu()
    client.add_dense_layer(120, 160)
    client.add_relu()
    client.add_dropout(0.25)
    client.add_dense_layer(160, 200)
    client.add_relu()
    return client


# One client network per dataset, in the same order as the datasets.
builder0 = _build_client_network(40)
builder1 = _build_client_network(60)
builder2 = _build_client_network(100)

# Server-side network. Its first layer takes 600 values, which matches the
# three 200-wide client outputs (presumably concatenated -- TODO confirm).
server_builder = tb.NetworkBuilder()
server_builder.add_dense_layer(600, 160)
server_builder.add_relu()
server_builder.add_dense_layer(160, 10)
server_builder.add_relu()
server_builder.add_dense_layer(10, 1)

# Combine the server and client definitions into one PSI-enabled vertical
# network.
client_builders = [builder0, builder1, builder2]
model = tb.create_vertical_network(
    "vertical_network",
    server=server_builder,
    clients=client_builders,
    is_psi=True,
)

# Preprocessor for the tabular inputs, configured with all columns.
csv_pre = tb.TabularPreprocessor.builder().all_columns()

# Optimizer names must be consistent with PyTorch.
#   See: https://pytorch.org/docs/stable/optim.html
# Currently tested: 'SGD', 'Adam', 'Adadelta'
optimizer_name = "Adam"
optimizer_params = {"lr": 0.001}

# Learning-rate scheduler and its settings.
lr_scheduler_name = "CyclicCosineDecayLR"
lr_scheduler_params = {
    "init_decay_epochs": 10,
    "min_decay_lr": 0.0001,
    "restart_interval": 3,
    "restart_interval_multiplier": 1.5,
    "restart_lr": 0.01,
}

# Loss function names must be consistent with PyTorch.
#   See: https://pytorch.org/docs/stable/nn.html#loss-functions
# Currently tested: 'BCEWithLogitsLoss', 'NLLLoss', 'CrossEntropyLoss'
loss_name = "BCEWithLogitsLoss"
# One-element int32 tensor holding 17, encoded for use as the loss's
# `pos_weight` parameter.
pos_weight = tb.TorchEncoder.encode(torch.tensor([17], dtype=torch.int32))

# PSI step: "ID_code" is the column used for matching.
psi_params = {"match_column": "ID_code"}

# Split-learning step: training length, loss/optimizer/scheduler selection
# and a description of the input data and model output.
loss_meta = {"name": loss_name, "params": {"pos_weight": pos_weight}}
optimizer_meta = {"name": optimizer_name, "params": optimizer_params}
lr_scheduler_meta = {"name": lr_scheduler_name, "params": lr_scheduler_params}
split_learning_params = {
    "epochs": 1,
    "loss_meta": loss_meta,
    "optimizer_meta": optimizer_meta,
    "lr_scheduler_meta": lr_scheduler_meta,
    "data_type": "table",
    "data_shape": [200],  # number of columns of data in table
    "model_output": "binary",  # one of: binary / multiclass / regression
    "test_size": 0.2,
}

# Create the training job over the three datasets.
job = tb.create_job(
    job_name="PSI + vertical partition Santander",
    operation=model,
    dataset=[asset_0, asset_1, asset_2],
    preprocessor=csv_pre,
    params={"psi": psi_params, "split_learning": split_learning_params},
)

# Submit the training job. A failed submission must not look like success to
# whatever runs this script, so report it and exit with a non-zero status.
if not job.submit():
    print("Job submission failed")
    sys.exit(1)

print("Performing PSI + vertical partition training (organization-one)...")
# Two of the three datasets are looked up under other owners' teams
# (example_user2 and example_user3), so both are named here.
print("Permission is needed from organization-two and organization-three")
job.wait_for_completion()

if not job.success:
    # This job trains a model; report a training failure, not a PSI overlap
    # failure.
    print("Training failed")
    sys.exit(1)

# Download the trained model next to this script.
filename = job.result.asset.retrieve(save_as="model.zip", overwrite=True)
print("Model in file:")
print("    ", filename)

# Record the model's path in a marker file (presumably read by a follow-up
# script -- verify against the other example scripts).
with open("local_model_filename.out", "w") as output:
    output.write(str(filename))
print(f"Saved trained model: {filename}")
