# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import math
import os
import random
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def generate_datasets(
    num_features,
    num_clusters,
    min_points_per_cluster=10,
    max_points_per_cluster=1000,
    mean_mean=0,
    mean_std=10,
    std_rate=1,
):
    """
    Build a pair of synthetic datasets whose rows fall into Gaussian clusters,
    which makes them convenient for demonstrating K Means Clustering.

    Both datasets are sampled from the same per-cluster feature means and
    standard deviations, so cluster ``k`` in the training data corresponds to
    cluster ``k`` in the inference data. Only the number of rows drawn per
    cluster differs between the two. Nothing prevents clusters from
    overlapping.

    All randomness comes from the global ``np.random`` state.

    Args:
        num_features:
            Number of columns in each generated dataset.
        num_clusters:
            Number of clusters to sample.
        min_points_per_cluster:
            Smallest number of rows a cluster may contribute (inclusive).
        max_points_per_cluster:
            Largest number of rows a cluster may contribute (inclusive).
        mean_mean:
            The mean of the normal distribution that the per-cluster feature
            means are drawn from.
        mean_std:
            The standard deviation of the normal distribution that the
            per-cluster feature means are drawn from.
        std_rate:
            Parameter of the exponential distribution used for the
            per-cluster spread. NOTE: the value is handed to
            ``np.random.exponential`` unchanged, and numpy interprets that
            argument as the scale of the distribution.

    Return:
        training_dataset:
            A 2D array with ``num_features`` columns and shuffled rows,
            suitable for use in training.
        inference_dataset:
            A 2D array with ``num_features`` columns and shuffled rows,
            suitable for use in inference.
    """

    # Per-cluster row counts; randint's upper bound is exclusive, hence +1.
    size_bounds = (min_points_per_cluster, max_points_per_cluster + 1)
    train_counts = np.random.randint(*size_bounds, size=(num_clusters,))
    inference_counts = np.random.randint(*size_bounds, size=(num_clusters,))

    # Per-cluster, per-feature centres.
    centers = np.random.normal(mean_mean, mean_std, size=(num_clusters, num_features))

    # Each cluster gets one overall spread, which is then scaled per feature
    # by a factor in [exp(-1), 1).
    base_spread = np.random.exponential(std_rate, size=(num_clusters, 1))
    feature_factor = np.exp(-1 + np.random.random(size=(num_clusters, num_features)))
    spreads = base_spread * feature_factor

    # Sample the rows of every cluster, training first and inference second.
    train_parts = []
    inference_parts = []
    for center, spread, n_train, n_inference in zip(
        centers, spreads, train_counts, inference_counts
    ):
        train_parts.append(np.random.normal(center, spread, (n_train, num_features)))
        inference_parts.append(
            np.random.normal(center, spread, (n_inference, num_features))
        )

    # Stack the clusters and shuffle the rows so clusters are interleaved.
    training_dataset = np.concatenate(train_parts)
    inference_dataset = np.concatenate(inference_parts)
    np.random.shuffle(training_dataset)
    np.random.shuffle(inference_dataset)

    return training_dataset, inference_dataset


def _generate_unique_ids(count, max_attempts=10):
    """Return ``count`` distinct int64 IDs from [0, 2**32) in random order."""
    for _ in range(max_attempts):
        # Oversample so that dropping duplicates still leaves enough IDs.
        candidates = np.unique(
            np.random.randint(0, 1 << 32, size=(2 * count,), dtype=np.int64)
        )
        if candidates.shape[0] >= count:
            # np.unique returns its result sorted. Shuffle before truncating
            # so the magnitude of an ID says nothing about which group of
            # rows it is later assigned to.
            np.random.shuffle(candidates)
            return candidates[:count]
    raise RuntimeError("Failed to generate unique IDs for the dataset.")


def separate_dataset(dataset, num_columns_on_left, intersection_fraction=None):
    """
    Convert a full dataset into two separate datasets that together form a PSI
    Vertically Partitioned dataset. The match column is named "match_column"
    and its values have type int.

    The first ``intersection_fraction`` of the rows become the intersection:
    their left columns go to the left dataset, their right columns go to the
    right dataset, and both halves share the same ID. The remaining rows are
    used as a pool of nonmatching rows: a random share of them is kept on the
    left, a random share on the right, and the rest is discarded. Every kept
    nonmatching row receives its own ID, so it never matches across datasets.

    Randomness comes from the global ``random`` and ``np.random`` states.

    Args:
        dataset: np.ndarray
            2D array holding the full data, one row per record.
        num_columns_on_left: int
            Number of leading columns assigned to the left dataset; the
            remaining columns go to the right dataset.
        intersection_fraction: float or None
            Fraction of rows placed in the intersection, in [0, 1]. When
            None, a value is drawn uniformly from [0.3, 0.7).

    Return:
        left_dataset: pd.DataFrame
            Columns "data_0" ... plus "match_column", rows shuffled.
        right_dataset: pd.DataFrame
            The remaining "data_i" columns plus "match_column", rows shuffled.

    Raises:
        ValueError: If ``intersection_fraction`` is outside [0, 1].
        RuntimeError: If unique IDs could not be generated.
    """

    num_rows, num_columns = dataset.shape

    # Determine the PSI intersection size.
    if intersection_fraction is None:
        min_intersection_fraction = 0.3
        max_intersection_fraction = 0.7
        intersection_fraction = (
            min_intersection_fraction
            + (max_intersection_fraction - min_intersection_fraction) * random.random()
        )
    if not 0 <= intersection_fraction <= 1:
        # Anything else yields negative sizes and misaligned rows and IDs.
        raise ValueError(
            f"intersection_fraction must be in [0, 1], got {intersection_fraction}"
        )
    intersection_size = int(math.floor(intersection_fraction * num_rows))
    nonintersection_size = num_rows - intersection_size

    # Split the nonmatching rows into "keep on left", "keep on right" and
    # "delete" shares. The small epsilon keeps the divisor positive.
    left_nonintersection_keep_fraction = random.random()
    right_nonintersection_keep_fraction = random.random()
    nonintersection_delete_fraction = random.random()
    sum_fractions = (
        left_nonintersection_keep_fraction
        + right_nonintersection_keep_fraction
        + nonintersection_delete_fraction
        + 1e-3
    )
    left_nonintersection_keep_size = int(
        math.floor(
            left_nonintersection_keep_fraction / sum_fractions * nonintersection_size
        )
    )
    right_nonintersection_keep_size = int(
        math.floor(
            right_nonintersection_keep_fraction / sum_fractions * nonintersection_size
        )
    )

    # Generate PSI match column values: one ID per intersection row plus one
    # per kept nonmatching row on either side.
    final_size = (
        intersection_size
        + left_nonintersection_keep_size
        + right_nonintersection_keep_size
    )
    ids = _generate_unique_ids(final_size)
    left_ids_end = intersection_size + left_nonintersection_keep_size
    real_ids = ids[:intersection_size]
    left_fake_ids = ids[intersection_size:left_ids_end]
    right_fake_ids = ids[left_ids_end:]

    # Collect matching row data.
    left_real = dataset[:intersection_size, :num_columns_on_left]
    right_real = dataset[:intersection_size, num_columns_on_left:]

    # Collect nonmatching row data. Each side is copied before shuffling so
    # the caller's array is left untouched, and shuffled independently.
    left_nonintersection = np.copy(dataset[intersection_size:, :num_columns_on_left])
    right_nonintersection = np.copy(dataset[intersection_size:, num_columns_on_left:])
    np.random.shuffle(left_nonintersection)
    np.random.shuffle(right_nonintersection)
    left_fake = left_nonintersection[:left_nonintersection_keep_size]
    right_fake = right_nonintersection[:right_nonintersection_keep_size]

    # Convert from numpy array to pandas dataframe.
    left_dataset = pd.DataFrame(
        np.concatenate((left_real, left_fake)),
        columns=[f"data_{i}" for i in range(num_columns_on_left)],
    )
    right_dataset = pd.DataFrame(
        np.concatenate((right_real, right_fake)),
        columns=[f"data_{i}" for i in range(num_columns_on_left, num_columns)],
    )

    # Add PSI match column to dataset.
    left_dataset["match_column"] = np.concatenate((real_ids, left_fake_ids))
    right_dataset["match_column"] = np.concatenate((real_ids, right_fake_ids))

    # Shuffle the data.
    left_dataset = left_dataset.sample(frac=1)
    right_dataset = right_dataset.sample(frac=1)

    return left_dataset, right_dataset


def join(datasets: List[pd.DataFrame], match_column: List[str]) -> pd.DataFrame:
    """
    Inner-join several datasets on their match columns.

    The match column of every dataset is converted to ``str`` before joining
    so that values compare equal regardless of their original dtype. The
    names of the match columns are renamed to "match_column" in the output.

    Neither the ``datasets`` list nor the DataFrames in it are modified; the
    conversion is done on copies.

    Args:
        datasets:
            The DataFrames to join, in join order.
        match_column:
            ``match_column[i]`` is the name of the match column of
            ``datasets[i]``.

    Return:
        The inner join of all datasets on "match_column".

    Raises:
        IndexError: If ``datasets`` is empty or ``match_column`` has fewer
            entries than ``datasets``.
    """

    # Prepare converted copies of the datasets for joining. astype() and
    # rename() both return new frames, so the caller's data stays untouched.
    prepared = []
    for index, frame in enumerate(datasets):
        column = match_column[index]
        prepared.append(
            frame.astype({column: str}).rename(columns={column: "match_column"})
        )

    # Perform the join.
    out = prepared[0]
    for frame in prepared[1:]:
        out = out.merge(frame, on="match_column", how="inner")

    return out


def plot_data(
    full_dataset_transformed,
    cluster_means,
    labels,
    column_one,
    column_two,
    title=None,
    is_cluster_empty=None,
    xlabel=None,
    ylabel=None,
    show=True,
    savefig=False,
):
    """
    Draw a 2D scatter plot of clustered data, one scatter series per cluster.

    Args:
        full_dataset_transformed:
            2D array of points, one row per point.
        cluster_means:
            2D array of cluster centers, one row per cluster, or None to
            skip drawing the centers.
        labels:
            Cluster index of every row of ``full_dataset_transformed``.
        column_one:
            Index of the column plotted on the x axis.
        column_two:
            Index of the column plotted on the y axis.
        title:
            Plot title; defaults to "2D Projection of Clustering Data". Also
            used to build the file name when ``savefig`` is set.
        is_cluster_empty:
            Per-cluster flags, nonzero for clusters that must not have any
            points assigned. Defaults to "no cluster is empty". Centers of
            flagged clusters are not drawn.
        xlabel:
            Optional x axis label.
        ylabel:
            Optional y axis label.
        show:
            Display the figure. Ignored when the TB_TEST_SMALL environment
            variable is set.
        savefig:
            Save the figure to "<title>.png" (spaces replaced by
            underscores), relative to the current working directory.

    Raises:
        RuntimeError: If points are labelled with a cluster that is flagged
            as empty.
    """
    if title is None:
        title = "2D Projection of Clustering Data"

    # Coerce to arrays so that the element-wise comparisons below behave the
    # same for plain Python sequences as for numpy arrays.
    labels = np.asarray(labels)
    num_labels = np.max(labels) + 1
    if is_cluster_empty is None:
        is_cluster_empty = np.zeros((num_labels,), dtype=np.int64)
    else:
        is_cluster_empty = np.asarray(is_cluster_empty)

    # Plot data grouped by their clusters.
    for label in range(num_labels):
        data = full_dataset_transformed[labels == label]
        if data.shape[0] == 0:
            continue
        if is_cluster_empty[label]:
            raise RuntimeError("Some data was labelled for an empty cluster.")
        plt.scatter(
            data[:, column_one],
            data[:, column_two],
        )

    # Plot cluster means.
    if cluster_means is not None:
        non_empty = is_cluster_empty == 0
        plt.scatter(
            cluster_means[non_empty, column_one],
            cluster_means[non_empty, column_two],
            s=100,
            marker="P",
            c="black",
            label="Cluster Centers",
        )
        # The centers are the only labelled series, so a legend is drawn
        # only when they exist; otherwise matplotlib warns that there is
        # nothing to put in it.
        plt.legend()

    plt.title(title)
    if xlabel is not None:
        plt.xlabel(xlabel)
    if ylabel is not None:
        plt.ylabel(ylabel)
    if savefig:
        filename = f"{title}.png"
        filename = filename.replace(" ", "_")
        plt.savefig(filename)
        print()
        print(f"Saved image to {filename}")
        print("NOTE: Set SHOW_IMAGES=True in the script to display the image instead.")
    if show and "TB_TEST_SMALL" not in os.environ:
        plt.show()

    # Reset the current figure so the next call starts from a blank canvas.
    plt.clf()


def locally_compute_labels(full_dataset_transformed, cluster_means, is_cluster_empty):
    """
    Assign every point to its nearest non-empty cluster mean.

    Args:
        full_dataset_transformed:
            2D array of points, one row per point.
        cluster_means:
            2D array of cluster means, one row per cluster.
        is_cluster_empty:
            1D sequence with one entry per cluster; a nonzero entry marks a
            cluster that was empty during training and must not be chosen.

    Return:
        1D array with the index of the selected cluster for every point. If
        every cluster is marked empty, all labels are 0.
    """
    # ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2. The ||x||^2 term is the same
    # for every cluster of a given point, so it does not affect the argmin
    # and is left out.
    squared_norms = np.sum(cluster_means * cluster_means, 1)
    cross_terms = np.matmul(full_dataset_transformed, np.transpose(cluster_means))
    dists = np.expand_dims(squared_norms, 0) - 2 * cross_terms

    # Mask all clusters that were marked as empty during training. Replacing
    # their column with +inf works for any input dtype and any magnitude of
    # the distances.
    empty_mask = np.expand_dims(np.asarray(is_cluster_empty), 0) != 0
    dists = np.where(empty_mask, np.inf, dists)

    # Determine the closest cluster mean to each point.
    labels = np.argmin(dists, 1)

    return labels


def locally_compute_inertia(full_dataset_transformed, cluster_means, labels):
    """
    Compute the K Means inertia: the sum, over all points, of the squared
    Euclidean distance between the point and its assigned cluster mean.

    Args:
        full_dataset_transformed:
            2D array of points, one row per point.
        cluster_means:
            2D array of cluster means, one row per cluster.
        labels:
            Index into ``cluster_means`` for every point.

    Return:
        The total inertia as a scalar.
    """
    # Offset of every point from the mean of the cluster it is assigned to.
    residuals = full_dataset_transformed - cluster_means[labels]

    # Squared distance per point first, then the total across all points.
    per_point = np.sum(residuals * residuals, axis=1)
    return np.sum(per_point, axis=0)
