#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.
from pathlib import Path

import tripleblind as tb


tb.util.set_script_dir_current()
tb.initialize(api_token=tb.config.example_user1["token"])

# Model files this example knows how to position. Each entry is:
#   (local model filename, framework label used in messages and the asset name,
#    prefix of the .out file that records the resulting asset id)
# The .out files are read later by 2_remote_inference.py.
MODEL_FILES = [
    ("model.h5", "Keras", "keras"),
    ("model.pth", "PyTorch", "pytorch"),
    ("model.onnx", "ONNX", "onnx"),
]


def _position_model(filename, label, out_prefix):
    """Place a pre-trained neural network file on the Access Point.

    Creates a NeuralNetwork asset from ``filename``, attaches an EXECUTE
    agreement for example_user2's team, and saves the new asset's uuid to
    ``<out_prefix>_asset_id.out``.

    Be sure to run a test inference on the resulting asset to validate it is
    working properly. The 2_remote_inference.py script provides an example of
    this.

    Args:
        filename: Path to the local model file (e.g. "model.h5").
        label: Framework name used in console output and the asset name.
        out_prefix: Prefix of the file the asset uuid is written to.
    """
    print(f"\nPositioning {label} model on Access Point...")
    model = tb.asset.NeuralNetwork.create(
        filename, allow_overwrite=True, name=f"EXAMPLE {label} model", desc=""
    )

    # Attach an Agreement to the model. This agreement makes the model
    # available to the other team, which means the inference step will not
    # require the model owner to explicitly grant permission for use. But
    # the usage is still logged for auditing.
    model.add_agreement(
        with_team=tb.config.example_user2["team_id"], operation=tb.Operation.EXECUTE
    )

    # Save for use in 2_remote_inference.py
    out_file = f"{out_prefix}_asset_id.out"
    tb.util.save_to(out_file, model.uuid)
    print(f"Done. Asset id in {out_file} {model.uuid}")


found_any_model = False
for model_file, framework_label, id_file_prefix in MODEL_FILES:
    if Path(model_file).exists():
        _position_model(model_file, framework_label, id_file_prefix)
        found_any_model = True

if not found_any_model:
    # Without this, the script would exit silently and leave no .out files,
    # which only surfaces later as a confusing failure in the next step.
    expected = ", ".join(name for name, _, _ in MODEL_FILES)
    print(
        f"\nNo model file found (looked for {expected}) in the current "
        "directory. Place a model file there and re-run this script."
    )
