#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import shutil
import tempfile
from pathlib import Path

import pandas as pd
from sklearn.datasets import make_regression
from tripleblind.util.timer import Timer

import tripleblind as tb


# Total number of rows in the synthetic table. The CT slices below also use
# count + test_count == total_count images (presumably the number of scans in
# overview.csv -- TODO confirm if the Kaggle dataset changes).
total_count = 100
count = 99  # Size of the training dataset, max = 99
test_count = total_count - count  # Size of the test dataset

# Both splits must be non-empty: build_data_in_folder/create_package are run
# on the train and the test split alike.
if not 0 < count < total_count:
    raise ValueError(
        f"count must be between 1 and {total_count - 1} inclusive, got {count}"
    )

# Create a synthetic set of data using a random linear regression. It is split
# at the same `count` boundary as the CT images further down.
tabular, _ = make_regression(n_samples=total_count)
df = pd.DataFrame(tabular)
df[:count].to_csv(tb.config.data_dir / "random_multimodal_train.csv", index=False)
df[count:].to_csv(tb.config.data_dir / "random_multimodal_test.csv", index=False)

# This example uses CT Scans and other patient data found in this Kaggle
# notebook: https://www.kaggle.com/kmader/siim-medical-images
# The download is skipped when overview.csv is already present locally.
if not (tb.config.data_dir / "ct_dicom/overview.csv").exists():
    tb.util.kaggle.dataset_download_files(
        "kmader/siim-medical-images", tb.config.data_dir / "ct_dicom"
    )


def build_data_in_folder(name, image_paths, labels, work_dir):
    """Lay out one DICOM dataset folder so it can be packaged.

    The folder is created under ``work_dir`` with this structure::

        name/
            records.csv
            dicom_images/
                0.dcm
                1.dcm
                ...

    ``records.csv`` has one row per image, pairing its label with the image
    path relative to ``name/`` (always POSIX-style)::

        label,paths
        <first label>,dicom_images/0.dcm
        <second label>,dicom_images/1.dcm
        ...

    Args:
        name: Name of the dataset folder to create inside ``work_dir``.
        image_paths: Source files, copied in order and renamed ``0.dcm``,
            ``1.dcm``, ... May be empty.
        labels: One label per entry of ``image_paths``, in the same order.
        work_dir: ``Path`` of the parent directory; created if missing.

    Raises:
        ValueError: If ``labels`` and ``image_paths`` differ in length.
    """
    if len(labels) != len(image_paths):
        raise ValueError(
            f"got {len(labels)} labels for {len(image_paths)} images in {name!r}"
        )

    image_dir_name = "dicom_images"
    root = work_dir / name
    image_root = root / image_dir_name
    record_data = root / "records.csv"

    # parents=True also creates `root` (and `work_dir`) when they are missing.
    image_root.mkdir(parents=True, exist_ok=True)

    # Relative POSIX paths keep records.csv valid regardless of the OS or of
    # where the folder ends up being extracted.
    rel_file_names = [
        (Path(image_dir_name) / f"{i}.dcm").as_posix() for i in range(len(image_paths))
    ]

    # Building from a dict (rather than renaming columns afterwards) also
    # works for an empty split, where pd.DataFrame([]) would have no columns.
    label_frame = pd.DataFrame({"label": labels, "paths": rel_file_names})
    label_frame.to_csv(record_data, index=False)

    # Copy each image under the name recorded in records.csv.
    for i, source_path in enumerate(image_paths):
        shutil.copy(source_path, image_root / f"{i}.dcm")


def create_package(name, work_dir, dest_dir):
    """Create a TripleBlind package for the dataset folder ``work_dir / name``.

    The package file is written to ``dest_dir / (name + ".zip")``. The folder
    must contain a ``records.csv`` whose ``paths`` column lists the files to
    include, relative to the folder itself.

    Returns:
        The result of ``tb.Package.create`` (not inspected here).
    """
    dataset_root = work_dir / name
    archive_path = dest_dir / (name + ".zip")
    return tb.Package.create(
        filename=archive_path,
        root=dataset_root,
        record_data=dataset_root / "records.csv",
        path_column="paths",
    )


print(f"Creating CT image training dataset sized: {count}...")
# Temporary dir to build dataset structures in. It is removed in the `finally`
# below so that a failure part-way through does not leave copies of the DICOM
# files behind.
work_dir = Path(tempfile.mkdtemp())
try:
    with Timer("Preprocessing data"):
        image_df = pd.read_csv(tb.config.data_dir / "ct_dicom/overview.csv")

        image_paths = [
            tb.config.data_dir / "ct_dicom/dicom_dir" / x
            for x in image_df["dicom_name"].values
        ]
        labels = image_df["Age"].values

        build_data_in_folder(
            "train_ct_multimodal",
            image_paths=image_paths[:count],
            labels=labels[:count],
            work_dir=work_dir,
        )
        build_data_in_folder(
            "test_ct_multimodal",
            image_paths=image_paths[count : count + test_count],
            labels=labels[count : count + test_count],
            work_dir=work_dir,
        )
        # Save the held-out labels (presumably compared against inference
        # output elsewhere -- confirm). write() emits the whole string at once.
        with open("expected.out", "w") as f:
            f.write(str(labels[count : count + test_count]))

    print()
    with Timer("Creating assets"):
        create_package("train_ct_multimodal", work_dir, tb.config.data_dir)
        create_package("test_ct_multimodal", work_dir, tb.config.data_dir)
finally:
    # clean up working files (on success, all data is in the .zips now)
    print()
    with Timer("Cleaning up"):
        shutil.rmtree(work_dir)
