#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import pickle
import shutil
import tempfile
from pathlib import Path

import numpy as np
import pandas as pd
from lstm_utils import one_hot_encode

import tripleblind as tb


tb.util.set_script_dir_current()


# Read the corpus: one document per line, lower-cased. Line terminators are
# not stripped, so a trailing "\n" stays part of a document (and of the
# vocabulary) unless truncation below cuts it off.
with open("sample.txt", "r", encoding="utf-8") as f:
    text = [line.lower() for line in f]

if not text:
    raise ValueError("sample.txt is empty; at least one line of text is required")

# Truncate every document to the length of the shortest one so that all
# sequences built from them share one fixed length.
shortest_doc = min(len(doc) for doc in text)
if shortest_doc < 2:
    # Input/target pairs are built by dropping one character from each end,
    # so anything shorter than 2 characters leaves empty sequences.
    raise ValueError(
        "every line of sample.txt must contain at least 2 characters to form "
        "input/target sequences"
    )
text = [doc[:shortest_doc] for doc in text]

# Collect the character vocabulary. Sorting makes the character <-> index
# mapping reproducible: iteration order of a set of str depends on the
# per-process hash seed, so an unsorted set produces a different encoding
# (and different saved artifacts) on every run.
full_doc = sorted(set("".join(text)))

# Map index -> character
integer_to_char_lookup = dict(enumerate(full_doc))

# Map character -> index (inverse of the mapping above)
char_to_integer_lookup = {char: ind for ind, char in integer_to_char_lookup.items()}

# Persist both lookups so the exact encoding used for this training data can
# be recovered later.
with open("indices.pkl", "wb") as f:
    pickle.dump(
        {
            "integer_to_char_lookup": integer_to_char_lookup,
            "char_to_integer_lookup": char_to_integer_lookup,
        },
        f,
    )

# Build next-character prediction pairs. For every document the input is all
# but its last character and the target is all but its first, so target[t]
# is the character that follows input[t]. Both are stored as integer indices
# from the character vocabulary.
input_seq = [[char_to_integer_lookup[ch] for ch in doc[:-1]] for doc in text]
target_seq = [[char_to_integer_lookup[ch] for ch in doc[1:]] for doc in text]

# Every document was truncated to `shortest_doc` characters, so each shifted
# sequence is exactly one character shorter than that.
dict_size = len(char_to_integer_lookup)
seq_len = shortest_doc - 1
batch_size = len(text)

input_seq = one_hot_encode(input_seq, dict_size, seq_len, batch_size)
print(
    f"Input shape: {input_seq.shape} --> "
    "(Batch Size, Sequence Length, One-Hot Encoding Size)"
)


# Package up all data into numpy arrays which will be within the zip file used
# in training. The data is staged in a temporary directory laid out as:
#   records.csv   - one row per record: "paths" (input) and "label_path" (target)
#   input/0.npy   - the one-hot encoded input sequences
#   target/0.npy  - the target character indices, flattened to int64
# TemporaryDirectory removes the staging area even if packaging fails part-way.
target_seq = np.array(target_seq).flatten().astype(np.int64)

with tempfile.TemporaryDirectory() as tmp_dir:
    root = Path(tmp_dir)

    input_seq_root = root / "input"
    target_root = root / "target"
    record_data = root / "records.csv"

    # `root` already exists (created by TemporaryDirectory) and is empty.
    input_seq_root.mkdir()
    target_root.mkdir()

    # All sequences are stored together as a single record with index 0.
    record_index = 0
    np.save(input_seq_root / f"{record_index}.npy", input_seq)
    np.save(target_root / f"{record_index}.npy", target_seq)

    # Paths written to records.csv are relative to `root`.
    record_df = pd.DataFrame(
        {
            "paths": [f"input/{record_index}.npy"],
            "label_path": [f"target/{record_index}.npy"],
        }
    )
    record_df.to_csv(record_data, index=False)

    tb.Package.create(
        filename="train.zip",
        root=root,
        record_data=record_data,
        path_column="paths",
        label_column="label_path",
    )
