#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

from sklearn.metrics import accuracy_score, classification_report

import tripleblind as tb


# Initialize the TripleBlind SDK using example user 1's API token.
tb.initialize(api_token=tb.config.example_user1["token"], example=True)

# Look up the three test datasets in the Router's index. Each part is
# searched for under a different example user's team.
asset0, asset1, asset2 = [
    tb.TableAsset.find(dataset_name, owned_by=owner["team_id"])
    for dataset_name, owner in (
        (
            "EXAMPLE - Decision Tree classification, test data (part 1)",
            tb.config.example_user1,
        ),
        (
            "EXAMPLE - Decision Tree classification, test data (part 2)",
            tb.config.example_user2,
        ),
        (
            "EXAMPLE - Decision Tree classification, test data (part 3)",
            tb.config.example_user3,
        ),
    )
]


# Find trained model: read the asset ID stored in clf_model_asset_id.out
# (presumably written by 1a_train_classification.py) and wrap it in an Asset.
try:
    asset_id = tb.util.load_from("clf_model_asset_id.out")
    model = tb.Asset(asset_id)
except Exception as err:
    # Catch Exception rather than using a bare "except:" so that Ctrl-C
    # (KeyboardInterrupt) and SystemExit are not misreported as a missing
    # model. The original error is chained for debugging.
    raise SystemExit(
        "No model found. You must run 1a_train_classification.py"
    ) from err

# In a real deployment each data owner would grant access themselves. For
# this example we create the agreements on their behalf, using each owner's
# own token, so the inference can run without anyone having to approve it.
for data_asset, owner in (
    (asset1, tb.config.example_user2),
    (asset2, tb.config.example_user3),
):
    data_asset.add_agreement(
        with_team=tb.config.example_user1["team_id"],
        operation=asset_id,
        session=tb.Session(api_token=owner["token"], from_default=True),
    )


# Build a tabular preprocessor configured via all_columns().
preprocessor_builder = tb.TabularPreprocessor.builder()
csv_pre = preprocessor_builder.all_columns()

# Describe the inference job: the trained model is the operation, run over
# the three test datasets, with "ID" given as the PSI match column.
job_settings = {
    "job_name": "Decision Tree Distributed Inference",
    "operation": model,
    "dataset": [asset0, asset1, asset2],
    "preprocessor": csv_pre,
    "params": {"psi": {"match_column": "ID"}},
}
job = tb.create_job(**job_settings)

# Submit the job, wait for it to finish, then compare the predictions with
# the known truth values and print the resulting metrics.
if job.submit():
    job.wait_for_completion()

    if job.success:
        # Save result for later inspection and display them
        result = job.result
        result.asset.retrieve(save_as="fed_inf_result.zip", overwrite=True)
        # Flatten once and reuse for display and for both metrics below.
        predictions = result.dataframe.values.flatten()
        print("\nInference results:")
        print("    ", predictions)

        # Retrieve the "truth" to calculate accuracy scores
        df = tb.TableAsset.find(
            "EXAMPLE - Decision Tree classification, test truth",
            owned_by=tb.config.example_user1["team_id"],
        ).dataframe
        # Order the truth rows by ID compared as a string.
        # NOTE(review): assumes the inference output rows come back in this
        # same string-sorted ID order -- confirm against the PSI output.
        df["ID"] = df["ID"].astype(str)
        df = df.sort_values("ID")
        truth = df["y"].to_list()
        # "\n" (not the invalid escape "\T") so the heading starts on a
        # fresh line like the "Inference results" heading above.
        print("\nTripleBlind model:")
        print(classification_report(truth, predictions))

        # Calculate simple accuracy score
        accuracy = accuracy_score(truth, predictions)
        print(f"Classification accuracy: {accuracy}")
    else:
        print("Inference failed")
else:
    # Report a failed submission instead of exiting silently.
    print("Job submission failed")
