#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import os
import random
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import List, Set

from faker import Faker

import tripleblind as tb


# NOTE(review): presumably this makes the script's own folder the working
# directory so the relative paths below resolve next to it -- confirm
# against the tripleblind documentation.
tb.util.set_script_dir_current()

# All generated datasets are written beneath this folder.
data_dir = Path("example_data")
data_dir.mkdir(exist_ok=True)

# Record count for each of the two large datasets.
#
# Defaults to 100,000.  The first command-line argument, when present,
# overrides the default; a non-empty TB_TEST_SMALL environment variable
# then forces a tiny 10-record run regardless of the argument.
if len(sys.argv) > 1:
    total_count = int(sys.argv[1])
else:
    total_count = 100000
if os.environ.get("TB_TEST_SMALL"):
    total_count = 10
print(f"Generating {total_count} records...")


@dataclass
class Person:
    """One synthetic individual; each output CSV carries a subset of these fields."""

    first_name: str
    last_name: str
    ssn: str  # Drawn through SSNGen, so no two people share one
    credit_card: str  # Only written to psi_data1.csv
    account_number: str  # Only written to psi_data0.csv
    license_number: str  # Only written to psi_data2.csv


# Fixed seeds for both Faker and the stdlib random module.
seed_value = 423675
Faker.seed(seed_value)
fake = Faker()
random.seed(seed_value)

people: List[Person] = []

# Every dataset is a [low, high) slice of the generated ``people`` list.
overlap_percent = 0.50

# First dataset: the first total_count people (all of their SSNs).
data0_low, data0_high = 0, total_count

# Second dataset: total_count people, starting far enough into the first
# dataset that overlap_percent of the first dataset's SSNs are shared.
data1_low = int(total_count - (total_count * overlap_percent))
data1_high = data1_low + total_count

# Third dataset: a 100-wide range beginning where the second dataset starts.
data2_low = data1_low
data2_high = data2_low + 100

# Report the chosen parameters; labels are padded so the values line up.
for label, value in (
    ("Total Count: ", total_count),
    ("Overlap %:   ", overlap_percent),
    ("D0 Low:      ", data0_low),
    ("D0 High:     ", data0_high),
    ("D1 Low:      ", data1_low),
    ("D1 High:     ", data1_high),
    ("D2 Low:      ", data2_low),
    ("D2 High:     ", data2_high),
):
    print(f"{label}{value}")


class SSNGen:
    """Hand out fake SSNs, never returning the same value twice.

    Candidates come from the supplied Faker instance; every value already
    returned is remembered in ``ssns`` and rejected if drawn again.
    """

    ssns: Set[str]  # Every SSN returned so far
    fake: Faker  # Source of candidate SSNs

    def __init__(self, local_fake: Faker):
        """Start with an empty history, drawing candidates from ``local_fake``."""
        self.ssns = set()
        self.fake = local_fake

    def next(self) -> str:
        """Return a fake SSN that this generator has not returned before."""
        # Retry in a loop rather than by recursion: a long run of duplicate
        # draws would otherwise exhaust the interpreter's recursion limit.
        while True:
            ssn = self.fake.ssn()
            if ssn not in self.ssns:
                self.ssns.add(ssn)
                return ssn


# Generate every person needed to fill the first two datasets.
ssngen = SSNGen(fake)
for _ in range(data1_high):
    people.append(
        Person(
            first_name=fake.first_name(),
            last_name=fake.last_name(),
            ssn=ssngen.next(),
            credit_card=fake.credit_card_number(),
            # Person declares account_number as str, so convert the random
            # 8-digit integer; the text written to the CSV is unchanged.
            account_number=str(random.randint(10000000, 99999999)),
            license_number=fake.isbn10(),
        )
    )


def _write_shuffled_csv(path, header, records, format_fields):
    """Write ``records`` to ``path`` as CSV rows in shuffled order.

    Args:
        path: Destination file; it is created or overwritten.
        header: Header row text, without the trailing newline.
        records: Sequence of Person objects to write; it is copied before
            shuffling, so the caller's sequence is left untouched.
        format_fields: Callable that returns the comma-separated fields
            following the id column for one Person.

    Each row is prefixed with a 1-based ``id`` reflecting the shuffled order.
    """
    rows = list(records)
    random.shuffle(rows)
    with open(path, "w", encoding="utf-8") as out:
        out.write(f"{header}\n")
        for identifier, person in enumerate(rows, start=1):
            out.write(f"{identifier},{format_fields(person)}\n")


# Create first dataset: example_data/psi_data0.csv
_write_shuffled_csv(
    data_dir / "psi_data0.csv",
    "id,first_name,last_name,ssn,account_number",
    people[data0_low:data0_high],
    lambda p: f"{p.first_name},{p.last_name},{p.ssn},{p.account_number}",
)

# Create second dataset: example_data/psi_data1.csv
_write_shuffled_csv(
    data_dir / "psi_data1.csv",
    "id,full_name,ssn,credit_card",
    people[data1_low:data1_high],
    lambda p: f"{p.first_name} {p.last_name},{p.ssn},{p.credit_card}",
)

# Create a third dataset: example_data/psi_data2.csv
_write_shuffled_csv(
    data_dir / "psi_data2.csv",
    "id,full_name,license_number,ssn",
    people[data2_low:data2_high],
    lambda p: f"{p.first_name} {p.last_name},{p.license_number},{p.ssn}",
)

# Create "expected.csv" for validation: the SSNs at the start of the second
# dataset's range that also fall inside the first dataset's range, capped
# at 100 and written in generation (unshuffled) order.
overlap_size = min(data0_high - data1_low, 100)
with open("expected.csv", "w", encoding="utf-8") as out:
    out.writelines(
        f"{person.ssn}\n" for person in people[data1_low : data1_low + overlap_size]
    )
