#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

"""
This example script combines the training results with both parties' training data
to form a visualization. Note that in typical use cases, this visualization is not
available since the vertically partitioned data will not be brought together onto
a single, central server to form the visualization.
"""

from pathlib import Path

import helpers
import pandas as pd

import tripleblind as tb


tb.util.set_script_dir_current()

# If True, displays the image in a window.
# If False, saves the image to a file.
# NOTE(review): savefig=True is passed to helpers.plot_data unconditionally
# below -- confirm how that interacts with `show` inside the helper.
SHOW_IMAGES = True

# Directory holding both parties' CSV files and the downloaded training results.
data_dir = Path("example_data")

# Read each party's complete training data from disk.
left_dataset = pd.read_csv(f"{data_dir}/alice_training.csv")
right_dataset = pd.read_csv(f"{data_dir}/bob_training.csv")

# Number of columns held by the left party, not counting its "ID0" column.
# It is used further down as the index of the second plotted column.
left_without_id = left_dataset.drop(columns=["ID0"])
num_columns_on_left = left_without_id.shape[1]

# Open the downloaded asset file and pull out the pieces of the training result.
package = tb.Package.load(f"{data_dir}/training_results.zip")
result = package.model()

model_section = result["model"]
cluster_means = model_section["cluster_means"]
is_cluster_empty = model_section["is_cluster_empty"]
zmuv_mean = model_section["zmuv_mean"]
zmuv_linear = model_section["zmuv_linear"]

data_section = result["data"]
labels = data_section["labels"]
inertia = data_section["inertia"]

# Join both parties' data with the label table so that every row carries its label.
frames_to_join = [left_dataset, right_dataset, labels]
join_columns = ["ID0", "ID1", "match_column"]
full_dataset = helpers.join(frames_to_join, join_columns)

# Pull the labels out of the joined frame (row order is kept), then remove the
# non-feature columns. Both are converted to numpy arrays for convenience.
labels = full_dataset["labels"].to_numpy()
full_dataset = full_dataset.drop(columns=["match_column", "labels"]).to_numpy()

# Apply the scaling parameters stored with the model (shift by zmuv_mean, then
# multiply by zmuv_linear) so the data is scaled the same way as during training
# with TripleBlind.
full_dataset_transformed = (full_dataset - zmuv_mean) * zmuv_linear

# Report the inertia stored in the training results (it is not recomputed here).
print("Training Inertia:", inertia)

# Plot a 2D projection of the data, where the two dimensions come from separate parties.
# The two positional indices select the plotted columns: 0 and num_columns_on_left.
# NOTE(review): this presumes the joined array lists the left party's feature
# columns first -- verify against helpers.join.
plot_options = {
    "title": "Results of TripleBlind Training",
    "is_cluster_empty": is_cluster_empty,
    "xlabel": "First Column from First Party",
    "ylabel": "First Column from Second Party",
    "show": SHOW_IMAGES,
    "savefig": True,
}
helpers.plot_data(
    full_dataset_transformed,
    cluster_means,
    labels,
    0,
    num_columns_on_left,
    **plot_options,
)
