#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

from pathlib import Path

import helpers

import tripleblind as tb


# Presumably makes this script's own folder the working directory, so the
# relative output path below resolves next to the script rather than wherever
# it was launched from (see tb.util.set_script_dir_current).
tb.util.set_script_dir_current()

# Every generated CSV is written into this folder; it is created if missing
# and silently reused if it already exists.
data_dir = Path("example_data")
data_dir.mkdir(exist_ok=True)

# Size knobs for the synthetic data: how many columns each party ends up
# with, how many clusters to draw, and the per-cluster point-count range.
num_columns_on_left = 1
num_columns_on_right = 1
num_clusters = 6
min_points_per_cluster = 100
max_points_per_cluster = 1000

# Controls how far apart the clusters are; presumably the spread of the
# cluster centers (larger -> more separated) — confirm in helpers.
mean_std = 10

# Build the complete, not-yet-partitioned training and inference datasets
# over the combined column count of both parties.
num_columns_total = num_columns_on_right + num_columns_on_left
training_dataset, inference_dataset = helpers.generate_datasets(
    num_columns_total,
    num_clusters,
    mean_std=mean_std,
    min_points_per_cluster=min_points_per_cluster,
    max_points_per_cluster=max_points_per_cluster,
)

# Split each full dataset column-wise into a left/right pair, producing the
# vertically partitioned datasets used by PSI. The split point is
# num_columns_on_left.
left_training_dataset, right_training_dataset = helpers.separate_dataset(
    training_dataset, num_columns_on_left
)
left_inference_dataset, right_inference_dataset = helpers.separate_dataset(
    inference_dataset, num_columns_on_left
)

# Give the PSI match column a different name on each side ("ID0" on the left,
# "ID1" on the right). This shows that TripleBlind PSI operations do not need
# the match columns to share a name, although identical names work as well.
left_training_dataset.rename(columns=dict(match_column="ID0"), inplace=True)
right_training_dataset.rename(columns=dict(match_column="ID1"), inplace=True)
left_inference_dataset.rename(columns=dict(match_column="ID0"), inplace=True)
right_inference_dataset.rename(columns=dict(match_column="ID1"), inplace=True)

# Write each dataset to its own CSV file in data_dir, without the row index.
# The left party's files are named "alice_*" and the right party's "bob_*".
for _frame, _csv_name in (
    (left_training_dataset, "alice_training.csv"),
    (right_training_dataset, "bob_training.csv"),
    (left_inference_dataset, "alice_inference.csv"),
    (right_inference_dataset, "bob_inference.csv"),
):
    _frame.to_csv(data_dir / _csv_name, index=False)

# Print a short summary of what was generated. "Party 1" is the left party
# (written to the alice_* files) and "Party 2" is the right party (bob_* files).
print(f"Party 1 has {left_training_dataset.shape[0]} rows in the training dataset.")
print(f"Party 2 has {right_training_dataset.shape[0]} rows in the training dataset.")
print(f"Party 1 has {left_inference_dataset.shape[0]} rows in the inference dataset.")
print(f"Party 2 has {right_inference_dataset.shape[0]} rows in the inference dataset.")
# The column counts are tunable parameters, so pluralize "column" by count
# instead of hard-coding the singular form.
print(f"Party 1 has {num_columns_on_left} column{'' if num_columns_on_left == 1 else 's'}.")
print(f"Party 2 has {num_columns_on_right} column{'' if num_columns_on_right == 1 else 's'}.")
