#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

"""
For easy comparison, this script illustrates K-Means clustering in a non-private
method using SKLearn, where all the data from all parties for both training and
inference resides on a single, central server.
"""

from pathlib import Path

import helpers
import pandas as pd
from sklearn.cluster import KMeans

import tripleblind as tb


# NOTE(review): presumably switches the working directory to this script's
# folder so the relative paths below resolve -- confirm against tb.util.
tb.util.set_script_dir_current()

# True  -> display the image in a window.
# False -> save the image to a file.
SHOW_IMAGES = True

# Directory holding the example CSV files and the downloaded training results.
data_dir = Path("example_data")

# Read each party's training and inference tables, in the order
# alice/training, bob/training, alice/inference, bob/inference.
(
    left_training_dataset,
    right_training_dataset,
    left_inference_dataset,
    right_inference_dataset,
) = (
    pd.read_csv(f"{data_dir}/{party}_{phase}.csv")
    for phase in ("training", "inference")
    for party in ("alice", "bob")
)

# Count the columns of the left-hand table once its "ID0" column is removed.
left_without_id = left_training_dataset.drop(columns=["ID0"])
num_columns_on_left = left_without_id.shape[1]

# Pull the clustering model out of the downloaded training asset file.
package = tb.Package.load(f"{data_dir}/training_results.zip")
result = package.model()
trained_model = result["model"]
cluster_means, zmuv_mean, zmuv_linear = (
    trained_model[key] for key in ("cluster_means", "zmuv_mean", "zmuv_linear")
)

# For each phase: join both parties' tables locally on their ID columns,
# discard the "match_column" column, and convert the resulting pandas
# dataframe to a numpy array for the sake of convenience.
full_training_dataset = (
    helpers.join([left_training_dataset, right_training_dataset], ["ID0", "ID1"])
    .drop(columns=["match_column"])
    .to_numpy()
)
full_inference_dataset = (
    helpers.join([left_inference_dataset, right_inference_dataset], ["ID0", "ID1"])
    .drop(columns=["match_column"])
    .to_numpy()
)

# Apply the same feature scaling that was used during training with
# TripleBlind: subtract zmuv_mean first, then multiply by zmuv_linear.
full_training_dataset_transformed = (full_training_dataset - zmuv_mean) * zmuv_linear
full_inference_dataset_transformed = (
    full_inference_dataset - zmuv_mean
) * zmuv_linear

# Cluster the full training dataset in a non-private way with SKLearn, using
# as many clusters as there are rows in the loaded cluster means.
num_clusters = cluster_means.shape[0]

# NOTE: The fit below might produce the following error message, which can be
# safely ignored:
#    AttributeError: 'NoneType' object has no attribute 'split'
# To get rid of it, upgrade threadpoolctl to above version 3.  For more info see:
#    https://stackoverflow.com/a/72840515
kmeans = KMeans(n_clusters=num_clusters).fit(full_training_dataset_transformed)

# Assign each inference sample to one of the SKLearn-trained clusters.
labels = kmeans.predict(full_inference_dataset_transformed)

# Inertia of the inference data with respect to the SKLearn cluster centers.
inference_inertia = helpers.locally_compute_inertia(
    full_inference_dataset_transformed,
    kmeans.cluster_centers_,
    labels,
)

# Report both inertia values.
print("\nTraining Inertia:", kmeans.inertia_)
print("\nInference Inertia:", inference_inertia)

# Plot the training data first, then the inference data. Both figures use the
# SKLearn cluster centers and show column 0 against column num_columns_on_left
# of the joined, scaled data.
# NOTE(review): savefig=True is passed regardless of SHOW_IMAGES -- confirm
# against helpers.plot_data that this matches the intended show/save behavior.
plot_jobs = (
    (
        full_training_dataset_transformed,
        kmeans.labels_,
        "Results of SKLearn Training",
    ),
    (
        full_inference_dataset_transformed,
        labels,
        "Results of SKLearn Inference with SKLearn Trained Model",
    ),
)
for plot_points, plot_labels, plot_title in plot_jobs:
    helpers.plot_data(
        plot_points,
        kmeans.cluster_centers_,
        plot_labels,
        0,
        num_columns_on_left,
        title=plot_title,
        xlabel="First Column from First Party",
        ylabel="First Column from Second Party",
        show=SHOW_IMAGES,
        savefig=True,
    )
