#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.
import os
import warnings

import tripleblind as tb


# Ignore every warning category. This is broader than strictly necessary: the
# intent is to hide PyTorch's "SourceChangeWarning" raised when the saved model
# is loaded, but other warnings are silenced too.
warnings.simplefilter("ignore")

# Authenticate as example user 3, the organization that owns the model.
user3_token = tb.config.example_user3["token"]
tb.initialize(api_token=user3_token, example=True)

# Place a pre-trained neural network on your Access Point.  This example shows
# a PyTorch .pth model file.
model_path = "original.pth"
asset_name = "EXAMPLE - Tabular transfer model"
try:
    model = tb.asset.NeuralNetwork.create(
        model_path,
        name=asset_name,
        desc="Fraud detection model used in the Transfer_Learning example.",
    )
except tb.TripleblindAssetAlreadyExists:
    # An earlier run already registered this model; reuse the existing asset.
    model = tb.asset.NeuralNetwork.find(
        asset_name, owned_by=tb.config.example_user3["team_id"]
    )
    # find() can come back empty (mirrors the dataset lookups below). Fail
    # with a clear message instead of an AttributeError on `model.uuid`.
    if not model:
        print(f"Model asset '{asset_name}' already exists but could not be found.")
        raise SystemExit(1)

# Save for use in 2_transfer_train.py
tb.util.save_to("model_asset_id.out", model.uuid)


# Grant organization-three permission to train its model against the training
# datasets owned by the other two example teams. Each agreement must be created
# by the dataset owner, so every owner gets its own session.
session1 = tb.Session(api_token=tb.config.example_user1["token"], from_default=True)
session2 = tb.Session(api_token=tb.config.example_user2["token"], from_default=True)

# Setting TB_TEST_SMALL in the environment switches to the "TEST" datasets.
if "TB_TEST_SMALL" in os.environ:
    prefix = "TEST"
else:
    prefix = "EXAMPLE"

dataset_train0 = tb.TableAsset.find(
    f"{prefix} - SAN Customer Database",
    owned_by=tb.config.example_user1["team_id"],
)
dataset_train1 = tb.TableAsset.find(
    f"{prefix} - JPM Customer Database",
    owned_by=tb.config.example_user2["team_id"],
)
if not dataset_train0 or not dataset_train1:
    print("Datasets not found.")
    raise SystemExit(1)

# Each dataset is paired with the session of the team that owns it.
trainer_team = tb.config.example_user3["team_id"]
for dataset, owner_session in (
    (dataset_train0, session1),
    (dataset_train1, session2),
):
    dataset.add_agreement(
        with_team=trainer_team,
        operation=model,
        session=owner_session,
    )

print("Model is in position.")
