#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.
from pathlib import Path

import numpy as np
import pandas as pd

import tripleblind as tb


# NOTE(review): presumably makes this script's own directory the working
# directory, so the relative path below resolves next to the script -- confirm
# against the tripleblind util documentation.
tb.util.set_script_dir_current()
# Relative output folder for the generated CSV files; created if it does not
# exist yet (an already existing folder is not an error).
data_dir = Path("example_data")
data_dir.mkdir(exist_ok=True)

# Create 3 datasets, each holding these columns:
#      value - a floating point value (NaN in a few randomly chosen rows)
#        sex - either 'M' or 'F' for gender
#  ethnicity - ethnicity ("latino", "asian", "white", "black", "hawaiian")

print("Generating data...")
n = 1000
n_missing = 5
data = pd.DataFrame(np.random.normal(loc=50, scale=25, size=n), columns=["value"])
# Sample row positions without replacement so that exactly n_missing distinct
# rows are blanked out and every row, including the last one, is eligible.
missing_idx = np.random.choice(n, size=n_missing, replace=False)
# np.nan is the portable spelling; the np.NAN alias was removed in NumPy 2.0.
data.iloc[missing_idx, :] = np.nan  # Stats currently drops rows with NaN values.
cat = pd.DataFrame(np.random.choice(["M", "F"], size=(n,)), columns=["sex"])


eth = pd.DataFrame(
    np.random.choice(["latino", "asian", "white", "black"], size=(n,)),
    columns=["ethnicity"],
)
# Force exactly 4 "hawaiian" values, located in two of the three datasets once
# the frame is split into thirds (for the k-grouping demo).
# Assign through .loc instead of writing into eth.values: the latter only works
# when .values is a writable view, which pandas Copy-on-Write does not allow.
hawaiian_rows = [4, 700, 800, 900]
eth.loc[hawaiian_rows, "ethnicity"] = "hawaiian"

df = pd.concat([data, cat, eth], axis=1)

# Split the frame into three consecutive, nearly equal chunks and write each
# one to its own CSV file.
# The row *positions* are split rather than the DataFrame itself, because
# np.array_split on a DataFrame goes through DataFrame.swapaxes, which is
# deprecated in recent pandas releases. The chunk boundaries are the same.
row_groups = np.array_split(np.arange(len(df)), 3)
ds0, ds1, ds2 = (df.iloc[rows] for rows in row_groups)
for i, part in enumerate((ds0, ds1, ds2)):
    part.to_csv(data_dir / f"dataset{i}.csv", index=False)
print(f"Data saved under '{data_dir}'")
