#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import tripleblind as tb


# Teams taking part in this example: the owner of the positioned training
# data, and the consumer that trains a model against it.
DATA_OWNER = tb.config.example_user1
CONSUMER = tb.config.example_user2


# The data consumer initiates the training, so authenticate as that user.
tb.initialize(api_token=CONSUMER["token"], example=True)

# Find the training data (previously positioned)
asset0 = tb.Asset.find("Energy forecast training data", owned_by=DATA_OWNER["team_id"])
if asset0 is None:
    # NOTE(review): assumes Asset.find() returns None when nothing matches --
    # confirm against the SDK. Stop here with a clear message instead of
    # failing later when the missing asset is handed to training.
    raise SystemExit(
        "Training data asset not found; it must be positioned by the data owner first."
    )


# Network geometry shared by the layer definitions in this script.
# T sizes the dense layer's input (T * D_out); presumably it is the number of
# time steps per training sample -- TODO confirm against the data.
T, kernel_size = 30, 2
# Channel counts: the first convolution maps D_in -> D_out channels and the
# later convolutions keep D_out channels.
D_in, D_out = 1, 5


# Assemble the small neural network that will analyze the training data:
# three convolution + ReLU stages whose dilation doubles each time (1, 2, 4),
# followed by a flatten layer and one dense layer with a single output.
builder0 = tb.NetworkBuilder()
in_channels = D_in  # only the first convolution reads D_in channels
for dilation in (1, 2, 4):
    builder0 = builder0.add_conv2d_layer(
        in_channels,
        D_out,
        kernel_size=(kernel_size, 1),
        dilation=dilation,
        padding="same",
    ).add_relu()
    in_channels = D_out
builder0 = builder0.add_flatten_layer().add_dense_layer(T * D_out, 1)


# Create the network from the builder. It is flagged as a federated-learning
# model and is not discoverable.
# NOTE(review): the first and third arguments are presumably the asset name
# and its description -- confirm against the SDK.
model = tb.create_network(
    "Dilated CNN",
    builder0,
    "Dilated CNN",
    is_federated_learning_model=True,
    is_discoverable=False,
)


# Loss function: the name must be consistent with PyTorch.
#   See: https://pytorch.org/docs/stable/nn.html#loss-functions
# Currently tested: 'BCEWithLogitsLoss', 'NLLLoss', 'CrossEntropyLoss'
# SmoothL1Loss behaves like MSE loss for small errors but is less sensitive
# to outliers.
loss_name = "SmoothL1Loss"

# Optimizer: the name must be consistent with PyTorch.
#   See: https://pytorch.org/docs/stable/optim.html
# Currently tested: 'SGD', 'Adam', 'Adadelta'
optimizer_name = "Adam"
optimizer_params = dict(lr=1e-3)  # learning rate setting for the optimizer

# Preprocessing recipe for the numpy input, built one step at a time:
# float32 data, expanded target dimensions, and a float32 target taken from
# the "target" column.
pre = tb.preprocessor.numpy_input.NumpyInputPreprocessor.builder()
pre = pre.dtype("float32")
pre = pre.expand_target_dims()
pre = pre.target_column("target")
pre = pre.target_dtype("float32")


# Train the network against the training data asset. The settings are
# collected first and then passed to train() as keyword arguments.
train_settings = {
    "data": asset0,
    "preprocessor": pre,
    "epochs": 2,
    "model_output": "regression",
    "batch_size": 32,
    "loss_name": loss_name,
    "optimizer_name": optimizer_name,
    "optimizer_params": optimizer_params,
    "data_type": "numpy",
}
result = model.train(**train_settings)

# Show the trained model's asset ID and store it in a file for later use
trained_asset = result.asset
print(f"Model trained, Asset created: {trained_asset.uuid}")
tb.util.save_to("trained_model_asset_id.out", trained_asset.uuid)


# Add an agreement that lets the data owner's team execute the trained
# model in subsequent steps.
agreement = trained_asset.add_agreement(
    with_team=DATA_OWNER["team_id"],
    operation=tb.Operation.EXECUTE,
)
if agreement:
    print("Created Agreement for use of trained Asset.")
