#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import tripleblind as tb


##########################################################################
# GET AUTHENTICATION TOKENS AND ESTABLISH CONNECTION TO THE ROUTER
#
# Establish the connection details to reach the TripleBlind instance.
# Unless explicitly specified, all operations will occur via this default
# session as the user 'organization_one'.
tb.initialize(api_token=tb.config.example_user1["token"], example=True)


# Find the training dataset in the Router index. The asset is looked up by
# name among those owned by the second example user's team.
dataset_train0 = tb.Asset.find(
    "Movie Rating Data", owned_by=tb.config.example_user2["team_id"]
)
if not dataset_train0:
    # Abort with a message and a non-zero exit status. SystemExit is raised
    # directly (as done for the "Training failed" case later in this script)
    # instead of calling the interactive exit() helper, which is only
    # injected by the `site` module and would report success (status 0).
    raise SystemExit("You must run 1_position_data_on_accesspoint.py first")

# Warn up front: this script passes a single dataset to the training job.
print("Disclaimer: This algorithm has not been tested with multiple datasets.")

# Optimizer selection. The name must be consistent with PyTorch's optimizers.
#   See: https://pytorch.org/docs/stable/optim.html
# Currently tested: 'SGD', 'Adam', 'Adadelta'
optimizer_name = "SGD"
optimizer_params = {"lr": 0.001}
optimizer_meta = {"name": optimizer_name, "params": optimizer_params}

# Describe the tabular input one column at a time: "comments" is a regular
# column and "ratings" is flagged as the target.
bert_pre = tb.TabularPreprocessor.builder()
bert_pre = bert_pre.add_column("comments")
bert_pre = bert_pre.add_column("ratings", target=True)

# This sentiment model training is built upon https://huggingface.co/bert-base-uncased.

# Parameters handed to the training operation. The column names here match
# the ones declared on the preprocessor above.
training_params = {
    "epochs": 1,
    "batchsize": 4,
    "optimizer_meta": optimizer_meta,
    "num_labels": 2,
    "target_column": "ratings",
    "data_column": "comments",
    "test_size": 0.2,
}

# Define the training job; it is not submitted until job.submit() is called.
job = tb.create_job(
    job_name="BERT Test",
    operation=tb.Operation.BERT_SEQ_CLF_TRAIN,
    dataset=[dataset_train0],
    preprocessor=bert_pre,
    params=training_params,
)


print("Training network...")
if not job.submit():
    # Without this guard a falsy submit() result would end the script
    # silently with a success exit status and no trained model.
    raise SystemExit("Failed to submit training job")

job.wait_for_completion()
if not job.success:
    raise SystemExit("Training failed")

# Handle to the trained model asset, reused for every step below.
trained_network = job.result.asset

print()
print("Trained Network Asset ID:")
print("    ===============================================")
print(f"    ===>  {trained_network.uuid} <===")
print("    ===============================================")
print("    Algorithm: Deep Learning Model")
# NOTE(review): the label says "Job ID" but the value printed is the job name.
print(f"    Job ID:    {job.job_name}")
print()

# Save for use in 3b_fed_inference.py
# NOTE(review): confirm the consumer script name; the second save below
# refers to 3_bert_inference.py.
with open("asset_id.out", "w") as output:
    output.write(str(trained_network.uuid))

# Pull down the model for local validation
local_filename = trained_network.retrieve(save_as="local.zip", overwrite=True)
print("Trained network has been downloaded as:")
print(f"   {local_filename}")

# Save for use in 3_bert_inference.py
tb.util.save_to("model_asset_id.out", trained_network.uuid)

print("Ready to run local inference.")
print()

# Create an agreement which allows the other team to use this
# trained model in subsequent steps.
agreement = trained_network.add_agreement(
    with_team=tb.config.example_user2["team_id"], operation=tb.Operation.EXECUTE
)
if agreement:
    print("Created Agreement for use of trained Asset.")
