#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import sys

import pandas as pd

import tripleblind as tb


# Presumably switches the working directory to this script's folder so that
# relative files (e.g. the run-id and *.out files) resolve here — confirm.
tb.util.set_script_dir_current()
run_id = tb.util.read_run_id()

# Authenticate every subsequent SDK call as example user 1.
tb.initialize(api_token=tb.config.example_user1["token"])


# Look up the two datasets used for inference.  They are tagged with this
# run's id and owned by two different teams (example users 2 and 3).
asset0 = tb.Asset.find(
    f"random-multimodal-test-{run_id}",
    owned_by=tb.config.example_user2["team_id"],
)
asset1 = tb.Asset.find(
    f"test_ct_multimodal-{run_id}",
    owned_by=tb.config.example_user3["team_id"],
)
if not (asset0 and asset1):
    # Either lookup came back empty: the positioning step hasn't been run.
    print("Datasets not found.")
    print("You must run 1_position_data_on_accesspoint.py first")
    sys.exit(1)


# Find the trained model.  Its asset ID is read from asset_id.out, which the
# training step (2_model_train.py) presumably writes — confirm.
asset_id = tb.util.load_from("asset_id.out")
model = tb.Asset.find(asset_id, owned_by=tb.config.example_user1["team_id"])
if not model:
    print("No model found. You must run 2_model_train.py")
    # Stop here, like the missing-dataset check does: without a model there
    # is nothing to run, and tb.create_job would otherwise receive None.
    sys.exit(1)


# Build an inference job that runs the trained model over both datasets.
job = tb.create_job(
    job_name="CT Distributed Inference",
    operation=model,
    dataset=[asset0, asset1],
    params={},
)

# Guard clauses: report and exit non-zero on any failure, consistent with
# the missing-dataset check earlier in the script.
if not job.submit():
    print("Job submission failed")
    sys.exit(1)

job.wait_for_completion()

if not job.success:
    print("Inference failed")
    sys.exit(1)

# Save the result asset locally as fed_inf_result.zip (overwriting any
# previous copy); the call is kept for that side effect.
job.result.asset.retrieve(save_as="fed_inf_result.zip", overwrite=True)

# Flatten both the inference output and the reference values into 1-D arrays
# so they print in the same format and can be compared side by side.
result = job.result.table.dataframe.to_numpy().flatten()
expected = pd.read_csv("expected.out", header=None).to_numpy().flatten()

print("\nInference results:")
print("    ", result)
print("Expected output:")
print("    ", expected)
