#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

"""
If training with TripleBlind has been performed and all the data for inference
resides on a single, central server, then one can follow this example script
to perform inference locally.
"""


from pathlib import Path

import helpers
import pandas as pd

import tripleblind as tb


# Display toggle for the figure produced at the end of this script
# (forwarded to helpers.plot_data as `show=`):
#   True  -> display the image in a window,
#   False -> save the image to a file.
SHOW_IMAGES = True

# Presumably switches the working directory to this script's folder so the
# relative "example_data" path below resolves consistently — confirm in tb.util.
tb.util.set_script_dir_current()


# Example data path: holds both parties' inference CSVs and the
# training-results package loaded further down.
data_dir = Path("example_data")

# Fail early with an actionable message. The training-results package is the
# artifact produced by training (presumably downloaded by 1_train.py), and it is
# loaded unconditionally below. Checking for it, rather than only for the
# directory, catches a directory that exists but was never populated by
# training. The file existing also implies the directory exists.
if not (data_dir / "training_results.zip").exists():
    raise SystemExit("Run 1_train.py before this script to train the model")


# Read each party's inference table from disk.
left_dataset = pd.read_csv(data_dir / "alice_inference.csv")
right_dataset = pd.read_csv(data_dir / "bob_inference.csv")

# Number of left-party columns other than its "ID0" key. The plot at the end
# uses it as the index of the right party's first column (assuming the joined
# array keeps the left party's columns first).
num_columns_on_left = left_dataset.drop(columns="ID0").shape[1]

# Pull the trained parameters out of the downloaded training-results package.
package = tb.Package.load(f"{data_dir}/training_results.zip")
model = package.model()
trained_params = model["model"]
cluster_means, is_cluster_empty, zmuv_mean, zmuv_linear = (
    trained_params[key]
    for key in ("cluster_means", "is_cluster_empty", "zmuv_mean", "zmuv_linear")
)

# Join both parties' tables locally on their ID columns ("ID0" on the left,
# "ID1" on the right). Then drop "match_column" (presumably the merged ID
# column that helpers.join adds — verify) and convert to a NumPy array for the
# arithmetic below.
full_dataset = (
    helpers.join([left_dataset, right_dataset], ["ID0", "ID1"])
    .drop(columns=["match_column"])
    .to_numpy()
)

# Re-apply the feature scaling learned during TripleBlind training: subtract
# the stored mean, then multiply by the stored linear factor (presumably a
# zero-mean / unit-variance transform, per the "zmuv" names).
full_dataset_transformed = (full_dataset - zmuv_mean) * zmuv_linear

# Inference: compute a cluster label for every row of the scaled data from the
# trained cluster means, taking the empty-cluster flags into account.
labels = helpers.locally_compute_labels(
    full_dataset_transformed, cluster_means, is_cluster_empty
)
# Header line, the labels, then a blank line.
print("Inference Labels:", labels, "", sep="\n")

# Inertia of the data under the trained clustering, computed as if all of the
# data lived on a single, central server.
inertia = helpers.locally_compute_inertia(full_dataset_transformed, cluster_means, labels)
print("Inference Inertia:", inertia)

# 2D projection of the clustered data. The x axis is column 0 (the first
# left-party feature). The y axis is column `num_columns_on_left`, which is the
# first right-party feature if the joined array keeps the left party's columns
# first.
# NOTE(review): savefig=True is passed regardless of SHOW_IMAGES, so the figure
# may be written to a file even when it is also shown in a window. Confirm this
# matches the intent of the SHOW_IMAGES comment at the top of the script.
x_column, y_column = 0, num_columns_on_left
helpers.plot_data(
    full_dataset_transformed,
    cluster_means,
    labels,
    x_column,
    y_column,
    is_cluster_empty=is_cluster_empty,
    title="Results of Local Inference with TripleBlind Trained Model",
    xlabel="First Column from First Party",
    ylabel="First Column from Second Party",
    savefig=True,
    show=SHOW_IMAGES,
)
