#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import shutil
import tempfile
from multiprocessing import Pool, cpu_count
from pathlib import Path

import pandas as pd
from PIL import Image
from tensorflow.keras.datasets import cifar10
from tripleblind.util.timer import Timer

import tripleblind as tb


# NOTE(review): presumably makes this script's own directory the working
# directory so the relative path below resolves next to the script -- confirm
# against the tripleblind.util documentation.
tb.util.set_script_dir_current()
# Folder that receives the generated package (.zip) files; created if absent.
data_dir = Path("example_data")
data_dir.mkdir(exist_ok=True)


def create_datasets(group, count):
    """Build two CIFAR-10 training packages and one small test package.

    The first ``count`` and the last ``count`` training samples are written
    out as ``{group}_train_cifar_0`` and ``{group}_train_cifar_1``, and the
    first 10 test samples as ``test_cifar``.  Each dataset is laid out in a
    temporary working directory, converted into a ``.zip`` package inside
    ``data_dir``, and the working directory is then removed.

    Args:
        group: Prefix used in the names of the two training packages.
        count: Number of samples in each training subset.  Must be at least
            1.  If it exceeds half of the training set, the two subsets
            overlap.

    Raises:
        ValueError: If ``count`` is smaller than 1.
    """
    if count < 1:
        # A slice of ``[-0:]`` selects the whole array instead of nothing, so
        # a non-positive count would silently produce mismatched subsets.
        raise ValueError(f"count must be a positive integer, got {count!r}")

    # Create temporary dir to build dataset structures.  mkdtemp() leaves the
    # removal to the caller, hence the try/finally below.
    work_dir = Path(tempfile.mkdtemp())
    try:
        with Timer("Preprocessing data"):
            # Retrieve the standard CIFAR-10 dataset of labeled color images
            (train_images, train_labels), (test_images, test_labels) = cifar10.load_data()
            print(Timer.indent + f"Images: {train_images.shape} | {train_images.dtype}")
            print(Timer.indent + f"Labels: {train_labels.shape} | {train_labels.dtype}")

            # Create two equal sized sets of images and labels
            train_images_0 = train_images[:count]
            train_labels_0 = train_labels[:count]

            train_images_1 = train_images[-count:]
            train_labels_1 = train_labels[-count:]

            # Create folders with data layout
            pkg_name0 = f"{group}_train_cifar_0"
            pkg_name1 = f"{group}_train_cifar_1"
            build_data_in_folder(pkg_name0, train_images_0, train_labels_0, work_dir)
            build_data_in_folder(pkg_name1, train_images_1, train_labels_1, work_dir)
            build_data_in_folder("test_cifar", test_images[:10], test_labels[:10], work_dir)

        print()
        with Timer("Creating package files"):
            # Convert folders into Packages (.zips)
            create_package(pkg_name0, work_dir, data_dir)
            create_package(pkg_name1, work_dir, data_dir)
            create_package("test_cifar", work_dir, data_dir)
    finally:
        # Clean up working files even when a step above failed, so the
        # temporary directory is never left behind (on success all data is
        # in .zips by now).
        print()
        with Timer("Cleaning up"):
            shutil.rmtree(work_dir)


def build_data_in_folder(name, images, labels, work_dir):
    """Write one dataset (images plus an index CSV) under ``work_dir / name``.

    The resulting folder looks like this:

        name /
               records.csv
               images /
                        0.png
                        1.png
                        ...

    ``records.csv`` has one row per image with two columns: ``label`` (the
    entry from ``labels``) and ``paths`` (the image path relative to the
    dataset folder, e.g. ``images/0.png``).

    Args:
        name: Name of the dataset folder to create inside ``work_dir``.
        images: Sequence of image arrays; image ``i`` is saved as ``i.png``.
        labels: Labels matching ``images`` one-to-one, used to build the CSV.
        work_dir: Directory (a ``Path``) in which the dataset folder is made.
    """
    with Timer(f"Creating '{name}'"):
        dataset_dir = work_dir / name
        png_dir = dataset_dir / "images"

        # Make directories
        dataset_dir.mkdir(parents=True, exist_ok=True)
        png_dir.mkdir(parents=True, exist_ok=True)

        # One file name per image, shared by the CSV index and the writer
        png_names = [f"{idx}.png" for idx in range(len(images))]

        # Build CSV with labels and the dataset-relative path of every image
        records = pd.DataFrame(labels)
        records.columns = ["label"]
        records["paths"] = [
            (Path("images") / png_name).as_posix() for png_name in png_names
        ]
        records.to_csv(dataset_dir / "records.csv", index=False)

        # Save images in parallel, one worker process per CPU
        targets = [(png_dir / png_name).as_posix() for png_name in png_names]
        with Pool(cpu_count()) as workers:
            workers.map(save_image, zip(images, targets))


def create_package(name, work_dir, dest_dir):
    """Turn the dataset folder ``work_dir / name`` into ``dest_dir / name.zip``.

    Args:
        name: Dataset folder name; also the stem of the resulting ``.zip``.
        work_dir: Directory (a ``Path``) containing the dataset folder.
        dest_dir: Directory (a ``Path``) that receives the package file.

    Returns:
        Whatever ``tb.Package.create`` returns for the new package.
    """
    with Timer(f"Creating package '{name}'"):
        archive_path = dest_dir / (name + ".zip")
        dataset_dir = work_dir / name
        index_csv = dataset_dir / "records.csv"
        return tb.Package.create(
            filename=archive_path,
            root=dataset_dir,
            record_data=index_csv,
            path_column="paths",
        )


def save_image(img_name_pair):
    """Save one image to disk.

    Takes a single argument so it can be used directly with ``Pool.map``.

    Args:
        img_name_pair: Indexable pair whose first item is the image array
            and whose second item is the destination file name.
    """
    Image.fromarray(img_name_pair[0]).save(img_name_pair[1])


if __name__ == "__main__":
    # (timer caption, package name prefix, samples per training subset)
    build_plan = (
        ("Building datasets", "EXAMPLE", 25000),
        ("Building small test datasets", "TEST", 10),
    )
    for caption, prefix, subset_size in build_plan:
        with Timer(caption):
            create_datasets(prefix, subset_size)
