#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.
from pathlib import Path

import sklearn.datasets
import numpy as np
import pandas as pd

import tripleblind as tb


# NOTE(review): presumably makes this script's own directory the working
# directory, so the relative output folder below lands next to the script --
# confirm against the tripleblind documentation.
tb.util.set_script_dir_current()

# Output folder for the generated CSV files; created on first run.
data_dir = Path("example_data")
data_dir.mkdir(exist_ok=True)

print("Generating data...")

# Synthetic classification problem with 1000 samples; `target` holds one
# label per row of `dataset`.
dataset, target = sklearn.datasets.make_classification(n_samples=1000)

# Split the rows into two halves (train / infer) and slice the labels at the
# same boundary so they stay aligned with their rows.
train_dataset, infer_dataset = np.array_split(dataset, 2, axis=0)
n_train = len(train_dataset)
train_target, infer_target = target[:n_train], target[n_train:]

# Split each half column-wise into two feature partitions.
train_dataset_partitions = np.array_split(train_dataset, 2, axis=1)
infer_dataset_partitions = np.array_split(infer_dataset, 2, axis=1)

# Create PSI keys.
# Both key sets contain the multiples of 4 (0, 4, ..., 996). The first set
# additionally holds each of those + 1 and the second set each of those + 2,
# so exactly the 250 multiples of 4 appear in both sets.
shared_keys = np.arange(0, 1000, 4)
t1_k = np.concatenate((shared_keys, shared_keys + 1)).astype(np.float64)
t2_k = np.concatenate((shared_keys, shared_keys + 2)).astype(np.float64)

# Assemble columns into pandas DataFrames.
def _feature_names(features):
    """Return ["f0", "f1", ...], one name per column of the 2-D array *features*."""
    return [f"f{idx}" for idx in range(features.shape[1])]


train_part_a, train_part_b = train_dataset_partitions
infer_part_a, infer_part_b = infer_dataset_partitions

# Training data, first partition: key column, label column, then features.
dft1 = pd.DataFrame(
    np.column_stack((t1_k, train_target, train_part_a)),
    columns=["identifier", "target", *_feature_names(train_part_a)],
)

# Training data, second partition: key column and features only (no label).
# The feature column names intentionally overlap with the first dataset's
# (both start at "f0"); the key column is named "id" here, not "identifier".
dft2 = pd.DataFrame(
    np.column_stack((t2_k, train_part_b)),
    columns=["id", *_feature_names(train_part_b)],
)

# Inference data, first partition: key column and features, no label.
dfi1 = pd.DataFrame(
    np.column_stack((t1_k, infer_part_a)),
    columns=["identifier", *_feature_names(infer_part_a)],
)

# Inference data, second partition. As above, the feature column names
# intentionally overlap with the first dataset's and the key column is "id".
dfi2 = pd.DataFrame(
    np.column_stack((t2_k, infer_part_b)),
    columns=["id", *_feature_names(infer_part_b)],
)

# Write each frame to its own CSV file in the output folder, without the
# pandas row index.
for csv_name, frame in (
    ("datasett0.csv", dft1),
    ("datasett1.csv", dft2),
    ("dataseti0.csv", dfi1),
    ("dataseti1.csv", dfi2),
):
    frame.to_csv(data_dir / csv_name, index=False)
print(f"Data saved under '{data_dir}'")
