#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import tripleblind as tb


# Identifier shared by every script in this folder; edit "run_id.txt" to change it.
run_id = tb.util.read_run_id()


##########################################################################
# AUTHENTICATE AND CONNECT TO THE ROUTER
#
# Create the default session for this script.  Every call below that does
# not name a different session runs as example user 1 ('organization_one').
user1_token = tb.config.example_user1["token"]
tb.initialize(api_token=user1_token, example=True)


# Locate both training datasets in the Router index.  They are expected to
# have been positioned by 1_position_data_on_accesspoint.py.
def _find_training_dataset(name, owner_team_id):
    """Return the TableAsset *name* owned by team *owner_team_id*.

    Prints a confirmation when found; otherwise stops the script with a
    message telling the user which setup script to run first.
    """
    asset = tb.TableAsset.find(name, owned_by=owner_team_id)
    if not asset:
        raise SystemExit("You must run 1_position_data_on_accesspoint.py first")
    print(f"Found dataset '{name}'")
    return asset


# 'ratings0-<run_id>' belongs to example user 3's team,
# 'ratings1-<run_id>' belongs to example user 2's team.
dataset_train0 = _find_training_dataset(
    f"ratings0-{run_id}", tb.config.example_user3["team_id"]
)
dataset_train1 = _find_training_dataset(
    f"ratings1-{run_id}", tb.config.example_user2["team_id"]
)

# Preprocessor configured with all_columns(True), i.e. use every column of
# the input datasets.
preprocessor = tb.TabularPreprocessor.builder().all_columns(True)

# Training hyperparameters plus the mapping from dataset columns to the
# user / item / rating roles the recommender expects.
training_params = {
    "learning_rate": 0.001,
    "hidden_dim": 10,
    "epochs": 10,
    "user_id_column": "userId",
    "item_id_column": "movieId",
    "rating_column": "rating",
}

# Create the recommender-training job over both training datasets.
job = tb.create_job(
    job_name=f"Recommendation model - {tb.util.timestamp()}",
    operation=tb.Operation.RECOMMENDER_TRAIN,
    dataset=[dataset_train0, dataset_train1],
    preprocessor=preprocessor,
    params=training_params,
)

# Submit the training job and wait for it.  Every failure path exits via
# SystemExit (non-zero status) so a failed run is never mistaken for success;
# previously a rejected submission fell through silently with exit status 0.
print("Training network...")
if not job.submit():
    raise SystemExit("Job submission failed")

job.wait_for_completion()
if not job.success:
    raise SystemExit("Training failed")

trained_network = job.result.asset

print()
print("Trained Network Asset ID:")
print("    ===============================================")
print(f"    ===>  {trained_network.uuid} <===")
print("    ===============================================")
print("    Algorithm: Deep Learning Model")
print(f"    Job ID:    {job.job_name}")
print()

# Save for use in 3b_fed_inference.py.  Explicit encoding keeps the file
# contents independent of the platform's default locale.
with open("asset_id.out", "w", encoding="utf-8") as output:
    output.write(str(trained_network.uuid))

# Grant example user 2's team EXECUTE permission on the trained model
# (presumably so that user can run inference against it -- confirm).
agreement = trained_network.add_agreement(
    with_team=tb.config.example_user2["team_id"],
    operation=tb.Operation.EXECUTE,
)
