#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.
from pathlib import Path

import numpy as np
import pandas as pd

import tripleblind as tb


# NOTE(review): presumably set_script_dir_current() switches the working
# directory to the directory containing this script, so that the relative
# "example_data" path below resolves next to the script — confirm against the
# tripleblind documentation.
tb.util.set_script_dir_current()
# Output directory for the generated CSV files.  exist_ok=True makes re-running
# the script safe when the directory is already present.
data_dir = Path("example_data")
data_dir.mkdir(exist_ok=True)

print("Generating data...")

# Create PSI keys.
# All three datasets share the 60 multiples of 4 (0, 4, ..., 236).  Each
# dataset additionally receives keys from its own residue class modulo 4
# (1, 2 or 3 respectively), so those extra keys never appear in another dataset.
t1_k = np.concatenate([np.arange(0, 240, 4), np.arange(1, 200, 4)])  # 60 + 50 keys
t2_k = np.concatenate([np.arange(0, 240, 4), np.arange(2, 240, 4)])  # 60 + 60 keys
t3_k = np.concatenate([np.arange(0, 240, 4), np.arange(3, 280, 4)])  # 60 + 70 keys

# Randomize the row order in place so shared keys are not grouped at the top.
for _keys in (t1_k, t2_k, t3_k):
    np.random.shuffle(_keys)

# Store keys as float64 so they can sit in one matrix with the float features.
t1_k, t2_k, t3_k = (_keys.astype(np.float64) for _keys in (t1_k, t2_k, t3_k))

# Create feature columns.  The lengths (110, 120, 130) match the number of
# keys generated for dataset 1, 2 and 3 respectively.
t1_n = np.random.normal(size=(110,))
t1_u = np.random.uniform(size=(110,))
t2_u = np.random.uniform(size=(120,))
t2_d3 = np.random.randint(1, 4, size=(120,))
t3_d3 = np.random.randint(1, 4, size=(130,))
# randint's upper bound is exclusive, so the random draws are in 1..8; the
# value 9 only enters through the explicit override below.
t3_d10 = np.random.randint(1, 9, size=(130,))
# Ensure that t3_d10 has at least one group that is above the k-grouping threshold.
t3_d10[:10] = 1
# Ensure that t3_d10 has at least one group that is below the k-grouping threshold.
t3_d10[-4:] = 9

# Insert NaN values into numerical columns: every row is selected independently
# with probability 0.1.  The 0/1 Bernoulli draws must be turned into a boolean
# mask before indexing — used directly they are integer indices and would only
# ever overwrite rows 0 and 1.  `np.nan` is used because the `np.NAN` alias was
# removed in NumPy 2.0.
for column in (t1_n, t1_u, t2_u):
    nan_mask = np.random.binomial(1, 0.1, size=column.shape[0]).astype(bool)
    column[nan_mask] = np.nan

# Assemble columns into pandas DataFrames.
# np.column_stack turns each 1-D array into one column of a 2-D matrix.  The
# key columns are float64, so every matrix is promoted to float64 and the
# scaled discrete columns are stored as floats as well.
dataset0_values = np.column_stack((t1_n, t1_u, t1_k))
dataset1_values = np.column_stack((t2_u, t2_d3 * 10, t2_k))
dataset2_values = np.column_stack((t3_d3 * 100, t3_d10 * 1000, t3_k))

df1 = pd.DataFrame(dataset0_values, columns=["Normal", "Continuous", "identifier"])
# intentionally overlap column name with first dataset
df2 = pd.DataFrame(dataset1_values, columns=["Continuous", "DiscreteA", "id"])
df3 = pd.DataFrame(dataset2_values, columns=["DiscreteB", "DiscreteC", "id"])

# Write each DataFrame to its own CSV file inside data_dir, without the
# pandas row index.
for frame, filename in (
    (df1, "dataset0.csv"),
    (df2, "dataset1.csv"),
    (df3, "dataset2.csv"),
):
    frame.to_csv(data_dir / filename, index=False)
print(f"Data saved under '{data_dir}'")
