#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import tripleblind as tb


# Authenticate as example user 3. The training dataset searched for below is
# owned by example user 2's team.
user3_token = tb.config.example_user3["token"]
tb.initialize(api_token=user3_token, example=True)


# Find the training dataset in the Router index
dataset_train0 = tb.Asset.find(
    "EXAMPLE - NASA Turbofan Failure Data (with Target)",
    owned_by=tb.config.example_user2["team_id"],
)
# Fail fast with a clear message instead of handing None to train() below.
# NOTE(review): assumes Asset.find returns None when nothing matches — confirm.
if dataset_train0 is None:
    raise SystemExit(
        "Training dataset 'EXAMPLE - NASA Turbofan Failure Data (with Target)' "
        "was not found."
    )

# Define the neural network to train.
# Before the split, the dense layers widen the 24 input columns up to 200.
# After the split, they narrow back down to a single regression output.
training_model_name = "example-NASA-network-trainer"
builder = tb.NetworkBuilder()
for layer_in, layer_out in ((24, 50), (50, 100), (100, 200)):
    builder.add_dense_layer(layer_in, layer_out)
    builder.add_relu()
builder.add_split()  # required split layer
for layer_in, layer_out in ((200, 50), (50, 20)):
    builder.add_dense_layer(layer_in, layer_out)
    builder.add_relu()
builder.add_dense_layer(20, 1)  # output layer: no activation after it
training_model = tb.create_network(training_model_name, builder)

# Build the tabular preprocessor step by step:
#   - flag "target" as the label column,
#   - enable all_columns (presumably keeps every remaining column),
#   - cast values to float32.
csv_pre = tb.TabularPreprocessor.builder()
csv_pre = csv_pre.add_column("target", target=True)
csv_pre = csv_pre.all_columns(True)
csv_pre = csv_pre.dtype("float32")


# Launch the training job; the timestamp keeps each job name unique.
job_name = f"NASA Neural Network Train - {tb.util.timestamp()}"
result = training_model.train(
    job_name=job_name,
    # --- training data ---
    data=dataset_train0,
    data_type="table",
    data_shape=[24],  # number of columns in the table
    preprocessor=csv_pre,
    # --- training parameters ---
    epochs=1,  # Change to ~200 to get better results
    test_size=0.1,  # reserve 10% of data for validation at each epoch
    model_output="regression",
    delete_trainer=True,  # set to destroy the training model when complete
    # --- loss function ---
    loss_name="MSELoss",
    # --- optimizer and its settings ---
    optimizer_name="Adam",
    optimizer_params={"lr": 0.001},
)

trained_network = result.asset

# Report the ID of the newly trained network asset. Each positional argument
# goes on its own line; the empty strings give a blank line before and after.
banner = "    ==============================================="
print(
    "",
    "Trained Network Asset ID:",
    banner,
    f"    ===>  {trained_network.uuid} <===",
    banner,
    "    Algorithm: Deep Learning Model",
    "",
    sep="\n",
)

# Record the asset ID for use in 3a_local_inference.py
tb.util.save_to("asset_id.out", trained_network.uuid)

# Download the trained model so it can be validated locally.
local_filename = trained_network.retrieve(save_as="local.zip", overwrite=True)
print("Trained network has been downloaded as:", f"   {local_filename}", sep="\n")

# Record the downloaded model's filename for use in 3a_local_inference.py
tb.util.save_to("local_model_filename.out", local_filename)

print("Ready to run local inference.")
