#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import os
import shutil
import sys
import tempfile
from multiprocessing import Pool, cpu_count
from pathlib import Path

import pandas as pd
from PIL import Image
from tensorflow.keras import datasets
from tripleblind.util.timer import Timer

import tripleblind as tb


# Presumably makes the current working directory the script's own directory
# (TODO confirm against tb.util), so the relative paths used below and in
# download_images() resolve next to this script. Must run before data_dir
# is created.
tb.util.set_script_dir_current()
# Destination for the downloaded sample images and the packaged .zip assets.
data_dir = Path("example_data")
data_dir.mkdir(exist_ok=True)


def main():
    """Build two MNIST training datasets and package each as a .zip asset.

    Steps:
      1. Download the sample inference images via ``download_images()``.
      2. Load the Keras MNIST training split and take the first ``count``
         and the last ``count`` samples as two separate datasets.
      3. Lay each dataset out on disk (``build_data_in_folder``) inside a
         temporary working directory.
      4. Package each one into ``data_dir`` as ``train_mnist_0.zip`` and
         ``train_mnist_1.zip`` (``create_package``).
      5. Delete the temporary working directory, even if an earlier step
         failed.

    ``count`` is 300 when the ``TB_TEST_SMALL`` environment variable is set,
    otherwise 30000. The process exits with status 1 if ``count`` falls
    outside 1..30000.
    """
    if "TB_TEST_SMALL" in os.environ:
        count = 300
    else:
        count = 30000  # Update to change the size of the training dataset
    # The Keras MNIST training split holds 60,000 images. Capping each half
    # at 30000 keeps the first-count and last-count slices below disjoint.
    if count < 1 or count > 30000:
        print("ERROR: Count must be between 1 and 30000")
        sys.exit(1)

    # Download the sample inference images (cached on TripleBlind's Google Drive)
    download_images()

    print(f"Creating MNIST image training dataset sized: {count}...")
    # Scratch directory for the unpackaged dataset trees. It is removed in
    # the finally clause so a failed run does not leave up to 2 * count PNG
    # files behind in the system temp directory.
    work_dir = Path(tempfile.mkdtemp())
    try:
        with Timer("Preprocessing data"):
            # Retrieve the standard MNIST datasets of handwritten images and labels
            (train_images, train_labels), (_, _) = datasets.mnist.load_data()

            # Two equal-sized datasets: one from the front of the training
            # split and one from the back.
            train_images0 = train_images[:count]
            train_labels0 = train_labels[:count]

            train_images1 = train_images[-count:]
            train_labels1 = train_labels[-count:]

            build_data_in_folder("train_mnist_0", train_images0, train_labels0, work_dir)
            build_data_in_folder("train_mnist_1", train_images1, train_labels1, work_dir)

        print()
        with Timer("Creating assets"):
            create_package("train_mnist_0", work_dir, data_dir)
            create_package("train_mnist_1", work_dir, data_dir)
    finally:
        # Clean up working files. On success all data is in the .zips now.
        print()
        with Timer("Cleaning up"):
            shutil.rmtree(work_dir)


def build_data_in_folder(name, images, labels, work_dir):
    """Write one image dataset (PNG files plus a CSV index) under ``work_dir``.

    The dataset is laid out as::

        <work_dir>/<name>/
            records.csv
            images/
                0.png
                1.png
                ...

    ``records.csv`` has the header row ``label,paths`` followed by one row per
    image, e.g. ``5,images/0.png``. Paths are relative to ``<work_dir>/<name>``
    and always use forward slashes, whatever the host OS.

    Args:
        name: Name of the dataset folder to create inside ``work_dir``.
        images: Sequence of image arrays accepted by ``PIL.Image.fromarray``.
        labels: 1-D sequence of labels, one per image, in the same order.
        work_dir: ``pathlib.Path`` of the directory to build the dataset in.

    Raises:
        ValueError: If ``images`` and ``labels`` have different lengths.
            Raised before anything is written to disk.
    """
    if len(images) != len(labels):
        raise ValueError(
            f"Got {len(images)} images but {len(labels)} labels for '{name}'"
        )

    root = work_dir / name
    image_root = root / "images"
    record_data = root / "records.csv"

    # parents=True on the innermost directory also creates ``root``.
    image_root.mkdir(parents=True, exist_ok=True)

    # Relative POSIX paths are what the CSV stores. The absolute paths used for
    # writing are derived from them so the two can never disagree.
    rel_file_names = [
        (Path("images") / f"{i}.png").as_posix() for i in range(len(images))
    ]
    label_frame = pd.DataFrame({"label": labels, "paths": rel_file_names})
    label_frame.to_csv(record_data, index=False)

    # Encode and save the PNGs in parallel; save_image takes (array, path).
    file_names = [root / rel_name for rel_name in rel_file_names]
    with Pool(cpu_count()) as p:
        p.map(save_image, zip(images, file_names))


def create_package(name, work_dir, dest_dir):
    """Bundle the dataset folder ``work_dir / name`` into a .zip asset.

    The archive is written to ``dest_dir / (name + ".zip")``. The dataset's
    ``records.csv`` is handed to ``tb.Package.create`` as the record index,
    and its ``paths`` column identifies the per-record files.

    Returns:
        Whatever ``tb.Package.create`` returns (presumably the created
        package handle; confirm against the tripleblind API).
    """
    archive_path = dest_dir / (name + ".zip")
    dataset_root = work_dir / name
    index_csv = dataset_root / "records.csv"
    return tb.Package.create(
        filename=archive_path,
        root=dataset_root,
        record_data=index_csv,
        path_column="paths",
    )


def save_image(img_name_pair):
    """Write a single image array to disk.

    Takes one ``(array, path)`` pair so it can be used directly with
    ``Pool.map`` over ``zip(images, paths)``. It is kept at module level
    because multiprocessing pickles worker functions by reference.

    Args:
        img_name_pair: Tuple whose first item is passed to
            ``PIL.Image.fromarray`` and whose second item is the destination
            path. PIL picks the output format from the file extension.
    """
    pixels = img_name_pair[0]
    destination = img_name_pair[1]
    Image.fromarray(pixels).save(destination)


def download_images(
    names=("three.jpg", "four.jpg", "seven.jpg", "big_eight.jpg"),
    *,
    save_to_dir=None,
    cache_dir="../../.cache",
):
    """Download the sample images used for inference.

    Each resource is fetched with ``tb.util.download_tripleblind_resource``.
    Calling with no arguments reproduces the original behavior: the four
    default images are saved into the module-level ``data_dir`` with
    ``../../.cache`` as the cache directory.

    Args:
        names: Resource file names to download, in order.
        save_to_dir: Destination directory. ``None`` (the default) means the
            module-level ``data_dir``, looked up at call time.
        cache_dir: Cache directory passed through to the downloader. A
            relative path is resolved against the current working directory.
    """
    if save_to_dir is None:
        save_to_dir = data_dir
    for resource in names:
        tb.util.download_tripleblind_resource(
            resource,
            save_to_dir=save_to_dir,
            cache_dir=cache_dir,
        )


if __name__ == "__main__":
    # The guard is required, not just conventional. Under the "spawn"
    # start method, multiprocessing.Pool workers (used in
    # build_data_in_folder) re-import this module, and without the guard
    # each worker would re-run main().
    main()
