#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

"""
This example script joins both parties' training and inference data locally
to form a visualization of the raw data. Note that in typical use cases, this
visualization is not available since the vertically partitioned data will not
be brought together onto a single, central server to form the visualization.
"""
from pathlib import Path

import helpers
import numpy as np
import pandas as pd

import tripleblind as tb


# Presumably makes this script's directory the working directory so that the
# relative data path below resolves -- confirm against tripleblind's docs.
tb.util.set_script_dir_current()

# If True, displays the image in a window.
# If False, saves the image to a file.
# NOTE(review): both plot calls further down also pass savefig=True
# unconditionally, so whether a file is written when SHOW_IMAGES is True
# depends on helpers.plot_data -- confirm.
SHOW_IMAGES = True

# Example data path
data_dir = Path("example_data")

# Load both parties' datasets from file. The paths are built with pathlib's
# "/" operator rather than string formatting; pandas accepts path objects.
left_training_dataset = pd.read_csv(data_dir / "alice_training.csv")
right_training_dataset = pd.read_csv(data_dir / "bob_training.csv")
left_inference_dataset = pd.read_csv(data_dir / "alice_inference.csv")
right_inference_dataset = pd.read_csv(data_dir / "bob_inference.csv")

# Number of columns the first party contributes once its "ID0" key column is
# excluded; used further down when plotting.
num_columns_on_left = len(left_training_dataset.drop(columns=["ID0"]).columns)


def _joined_matrix(left, right):
    """Join one party pair locally and return the result as a numpy array.

    The two frames are joined with helpers.join on the "ID0"/"ID1" key
    columns, the "match_column" column is dropped from the joined frame,
    and the remainder is converted from a pandas dataframe to a numpy array
    for the sake of convenience.
    """
    joined = helpers.join([left, right], ["ID0", "ID1"])
    return joined.drop(columns=["match_column"]).to_numpy()


# Join the datasets together locally.
full_training_dataset = _joined_matrix(
    left_training_dataset, right_training_dataset
)
full_inference_dataset = _joined_matrix(
    left_inference_dataset, right_inference_dataset
)

# Print status.
print(f"There are {full_training_dataset.shape[0]} rows in the training dataset.")
print(f"There are {full_inference_dataset.shape[0]} rows in the inference dataset.")

# Plot the training data first, then the inference data. Both plots use the
# same settings and differ only in the matrix plotted and the title.
for _matrix, _title in (
    (full_training_dataset, "2D Projection of Training Data"),
    (full_inference_dataset, "2D Projection of Inference Data"),
):
    # One all-zero int64 entry per row of the matrix being plotted.
    _labels = np.zeros(_matrix.shape[0], dtype=np.int64)
    helpers.plot_data(
        _matrix,
        None,
        _labels,
        # Presumably the x-axis and y-axis column indices (column 0 and the
        # column at index num_columns_on_left), going by the axis labels
        # below -- confirm against helpers.plot_data.
        0,
        num_columns_on_left,
        title=_title,
        xlabel="First Column from First Party",
        ylabel="First Column from Second Party",
        show=SHOW_IMAGES,
        savefig=True,
    )
