#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

from sklearn.metrics import mean_squared_error, r2_score

import tripleblind as tb


# Initialize the SDK with example user 1's API token (example mode enabled).
tb.initialize(api_token=tb.config.example_user1["token"], example=True)

# Look up the two test datasets in the Router's index.  Part 1 is searched
# for under user 1's team, part 2 under user 2's team.
part1_name = "EXAMPLE - Decision Tree regression, test data (part 1)"
part1_owner = tb.config.example_user1["team_id"]
asset0 = tb.TableAsset.find(part1_name, owned_by=part1_owner)

part2_name = "EXAMPLE - Decision Tree regression, test data (part 2)"
part2_owner = tb.config.example_user2["team_id"]
asset1 = tb.TableAsset.find(part2_name, owned_by=part2_owner)


# Find trained model.  The asset ID is read from a local file (presumably
# written by the training script named in the error message below) and then
# wrapped in an Asset handle.
try:
    asset_id = tb.util.load_from("reg_model_asset_id.out")
    model = tb.Asset(asset_id)
except Exception as err:
    # Catch Exception rather than using a bare ``except:`` so that Ctrl-C
    # (KeyboardInterrupt) and SystemExit are not misreported as a missing
    # model.  The original error is kept as the cause for debugging.
    raise SystemExit("No model found. You must run 1b_train_regression.py") from err


# To let the inference run without a manual approval step, this example
# creates an agreement between the data owner and the model owner up front.
# Outside of an example, the data owners would normally grant this themselves.
#
# The agreement is added using a session built from user 2's token, and it
# names user 1's team as the party allowed to run the trained model
# (identified by asset_id) against asset1.
data_owner_session = tb.Session(
    api_token=tb.config.example_user2["token"],
    from_default=True,
)
model_owner_team = tb.config.example_user1["team_id"]
asset1.add_agreement(
    with_team=model_owner_team,
    operation=asset_id,
    session=data_owner_session,
)


# Preprocessor applied to the tabular inputs (built with all_columns();
# presumably this selects every column -- confirm against the SDK docs).
csv_pre = tb.TabularPreprocessor.builder().all_columns()

# The job runs the trained model against both test datasets.  The "psi"
# settings name "ID" as the column used to match records across the datasets.
inference_datasets = [asset0, asset1]
inference_params = {"psi": {"match_column": "ID"}}

job = tb.create_job(
    job_name="Decision Tree Distributed Inference",
    operation=model,
    dataset=inference_datasets,
    preprocessor=csv_pre,
    params=inference_params,
)

# Submit the job, wait for it to finish, then display and score the results.
# Both failure paths (submission refused, job unsuccessful) report a message
# so the script never ends silently.
if job.submit():
    job.wait_for_completion()

    if job.success:
        # Save the result for later inspection, then display it.
        result = job.result
        result.asset.retrieve(save_as="fed_inf_result.zip", overwrite=True)

        # Flatten the predictions once and reuse them for display and scoring.
        predictions = result.dataframe.values.flatten()
        print("\nInference results:")
        print("    ", predictions)

        # Retrieve the "truth" to calculate accuracy scores
        test_df = tb.TableAsset.find(
            "EXAMPLE - Decision Tree regression, test truth",
            owned_by=tb.config.example_user1["team_id"],
        ).dataframe
        # Order the truth rows by ID compared as strings.
        # NOTE(review): this assumes the predictions come back in the same
        # string-sorted ID order -- confirm against the inference output.
        test_df["ID"] = test_df["ID"].astype(str)
        test_df = test_df.sort_values("ID")
        truth = test_df["y"].to_list()

        mse = mean_squared_error(truth, predictions)
        print("TripleBlind model:")
        print(f"    mse: {mse}")
        print(f"     r2: {r2_score(truth, predictions)}")
    else:
        print("Inference failed")
else:
    # Previously a rejected submission produced no output at all.
    print("Job submission failed")
