#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import tripleblind as tb


# Establish a connection to the Router as `example_user3`
# Authentication details must be configured in `tripleblind.yaml`
tb.initialize(api_token=tb.config.example_user3["token"], example=True)

# Find the training data in the Router index.
# Part one is looked up under example_user1's team and part two under
# example_user2's team, so the job trains on data from two other owners.
train_dataset_1 = tb.Asset.find(
    "EXAMPLE - Movie rating database - part one",
    owned_by=tb.config.example_user1["team_id"],
)
train_dataset_2 = tb.Asset.find(
    "EXAMPLE - Movie rating database - part two",
    owned_by=tb.config.example_user2["team_id"],
)

# Identify the columns in your dataset
# This script selects the `text` and `label` columns and registers
# `python_transform_script.py` as a Python transform on the preprocessor.
# Refer to `README.txt` for the data format.
# NOTE(review): confirm against `README.txt` which column names the protocol
# expects; only `text` and `label` are selected here.
bert_pre = (
    tb.TabularPreprocessor.builder()
    .add_column("text")
    .add_column("label")
    .python_transform("python_transform_script.py")
)

# Define the training hyperparameters
job = tb.create_job(
    job_name="SC Training",
    operation=tb.Operation.SC_TRAIN,
    dataset=[train_dataset_1, train_dataset_2],
    preprocessor=bert_pre,
    params={
        "learning_rate": 3e-4,
        "epochs": 5,
        "batch_size": 16,
        "weight_decay": 1e-2,
        "test_size": 0.3,
        "evaluation_metric": "accuracy",
    },
)

# Start the training and download the output model.
# A submission that is not accepted must not look like a successful run, so
# exit with an error instead of silently falling off the end of the script.
if not job.submit():
    raise SystemExit("SC Training job submission failed.")

job.wait_for_completion()

if not job.success:
    raise SystemExit("SC Training failed.")

# Print the output report
print(job.result)
# Download the trained model locally, replacing any previous download
job.result.asset.download("output-model.zip", overwrite=True)
