#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import os

import tripleblind as tb


# Initialize the SDK session with example user 3's API token.
tb.initialize(api_token=tb.config.example_user3["token"], example=True)


# Find the training databases in the Router index.  Part 1 is looked up under
# example_user3's team and Part 3 under example_user1's team.
dataset_train1 = tb.Asset.find(
    "EXAMPLE - NASA Turbofan Failure Data (Part 1)",
    owned_by=tb.config.example_user3["team_id"],
)
dataset_train3 = tb.Asset.find(
    "EXAMPLE - NASA Turbofan Failure Data (Part 3)",
    owned_by=tb.config.example_user1["team_id"],
)

# Stop here with a clear message if either lookup came back empty, rather than
# failing later inside train() with a less obvious error.
# NOTE(review): assumes Asset.find() returns None when nothing matches --
# confirm against the SDK documentation.
if dataset_train1 is None or dataset_train3 is None:
    raise SystemExit(
        "Could not find the NASA Turbofan example training datasets."
    )

# Define the neural network to train: three conv2d + ReLU stages, a flatten,
# the split layer, and a final dense layer with a single output.
# NOTE(review): the conv2d arguments look like (in_channels, out_channels,
# kernel_size, stride) -- confirm against the NetworkBuilder documentation.
training_model_name = "example-NASA-network-trainer"
builder = tb.NetworkBuilder()
builder.add_conv2d_layer(1, 10, [10, 1], 1)
builder.add_relu()
builder.add_conv2d_layer(10, 10, [10, 1], 1)
builder.add_relu()
builder.add_conv2d_layer(10, 10, [10, 1], 1)
builder.add_relu()
builder.add_flatten_layer()
builder.add_split()  # required split layer
# NOTE(review): 720 is presumably the flattened feature count produced by the
# conv stack for the training data shape -- verify if the shape changes.
builder.add_dense_layer(720, 1)
training_model = tb.create_network(training_model_name, builder)

# Configure the input preprocessor one step at a time: mark "target" as the
# target column, expand the target dimensions, and set the dtype to float32.
pre = tb.preprocessor.numpy_input.NumpyInputPreprocessor.builder()
pre = pre.target_column("target")
pre = pre.expand_target_dims()
pre = pre.dtype("float32")


# Train the network on both datasets.
# When TB_TEST_SMALL is present in the environment a single epoch is run;
# otherwise 5 epochs are used (use 250 for high accuracy).
if "TB_TEST_SMALL" in os.environ:
    num_epochs = 1
else:
    num_epochs = 5

# Timestamped job name for this training run.
train_job_name = f"NASA CNN Model Train - {tb.util.timestamp()}"

result = training_model.train(
    data=[dataset_train1, dataset_train3],
    data_type="image",
    data_shape=[24, 30, 1],  # width x height x 1 for grayscale
    preprocessor=pre,
    # Training params
    epochs=num_epochs,
    batch_size=512,
    test_size=0.1,
    model_output="regression",
    delete_trainer=True,
    job_name=train_job_name,
    # Loss function
    loss_name="SmoothL1Loss",
    # Optimizer
    optimizer_name="Adam",
    optimizer_params=dict(lr=0.001),
)

# Report the trained network's asset ID.  Only the line that interpolates the
# UUID needs to be an f-string; the rest are plain string literals.
trained_network = result.asset
print()
print("Trained Network Asset ID:")
print("    ===============================================")
print(f"    ===>  {trained_network.uuid} <===")
print("    ===============================================")
print("    Algorithm: Deep Learning Model")
print()

# Save for use in 3b_fed_inference / 3c_smpc_inference.py
tb.util.save_to("asset_id.out", trained_network.uuid)

# Pull down the model for local validation, overwriting any existing
# local.zip.
local_filename = trained_network.retrieve(save_as="local.zip", overwrite=True)
print("Trained network has been downloaded as:")
print(f"   {local_filename}")

# Save for use in 3a_local_inference.py
tb.util.save_to("local_model_filename.out", local_filename)

print("Ready to run local inference.")
