#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import tripleblind as tb


# Establish a connection to the Router as `example_user3`.
# Authentication details must be configured in `tripleblind.yaml`.
tb.initialize(api_token=tb.config.example_user3["token"], example=True)

# Find the training data in the Router index; the asset is owned by
# `example_user2`'s team.
train_dataset_1 = tb.Asset.find(
    "EXAMPLE - Natural Language Processing training data",
    owned_by=tb.config.example_user2["team_id"],
)

# Identify the columns in your dataset.
# The protocol expects the data columns to be named `text` and `entities`;
# `entities` is the target column. Refer to `README.txt` for the data format.
bert_pre = tb.TabularPreprocessor.builder().add_column("text").add_column("entities")

# Define the training job and its hyperparameters.
job = tb.create_job(
    job_name="NLP Training",
    operation=tb.Operation.NLP_TRAIN,
    dataset=[train_dataset_1],
    preprocessor=bert_pre,
    params={
        "learning_rate": 3e-4,
        "epochs": 5,
        "batch_size": 16,
        "weight_decay": 1e-2,
    },
)

# Start the training. A falsy return from submit() means the job never
# started; exit with an error instead of silently finishing with status 0.
if not job.submit():
    raise SystemExit("NLP Training job submission failed.")

job.wait_for_completion()

if not job.success:
    raise SystemExit("NLP Training failed.")

# Print the output report
print(job.result)
# Download the trained model locally
job.result.asset.download("output-model.zip", overwrite=True)
