#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import json
import os
import shutil
import tempfile
import xml.etree.ElementTree as ET
from pathlib import Path

import pandas as pd
import torch

import tripleblind as tb


# NOTE(review): judging by its name, this helper makes the script's own folder
# the working directory, so the relative paths below resolve next to this file
# -- confirm against the tripleblind SDK.
tb.util.set_script_dir_current()

# Folder that holds the downloaded dataset; created on first run.
data_dir = Path("example_data")
os.makedirs(data_dir, exist_ok=True)

# Pinned to the CPU.  The commented line is the GPU-aware alternative.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# Object class names, in the order that fixes their numeric ids
voc_labels = (
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
)

# Class name -> integer id.  Real classes are numbered from 1; id 0 is
# reserved for "background", which is added after the real classes.
label_map = {class_name: class_id for class_id, class_name in enumerate(voc_labels, start=1)}
label_map["background"] = 0

# Integer id -> class name (same entries as label_map, swapped)
rev_label_map = dict(zip(label_map.values(), label_map.keys()))


def parse_annotation(annotation_path):
    """Gather box information from xml doc

    Every ``<object>`` element in the document is inspected.  Objects whose
    lower-cased, stripped ``<name>`` is not a key of ``label_map`` are
    skipped.  For the rest, the four ``<bndbox>`` coordinates are read as
    integers and have 1 subtracted from them (presumably converting 1-based
    pixel indices to 0-based -- confirm against the dataset spec).

    Args:
        annotation_path (str or Path): Location of the xml document

    Returns:
        dict: three parallel lists under the keys "boxes"
            (``[xmin, ymin, xmax, ymax]`` per object), "labels" (integer ids
            from ``label_map``) and "difficulties" (1 when the object's
            ``<difficult>`` text is exactly "1", else 0).
    """
    root = ET.parse(annotation_path).getroot()

    boxes = []
    labels = []
    difficulties = []
    # ``obj`` rather than ``object`` so the builtin is not shadowed
    for obj in root.iter("object"):
        label = obj.find("name").text.lower().strip()
        if label not in label_map:
            # Unknown class: skip before touching the rest of the element
            continue

        # findtext() returns None when the tag is absent, so an object without
        # a <difficult> tag counts as "not difficult" instead of raising.
        difficult = int(obj.findtext("difficult") == "1")

        bbox = obj.find("bndbox")
        box = [
            int(bbox.find(corner).text) - 1
            for corner in ("xmin", "ymin", "xmax", "ymax")
        ]

        boxes.append(box)
        labels.append(label_map[label])
        difficulties.append(difficult)

    return {"boxes": boxes, "labels": labels, "difficulties": difficulties}


def build_data_zip(voc07_path, name, id_path, slice=None):
    """Create zip file which TripleBlind uses within preprocessor

    The image ids listed in ``ImageSets/Main/<id_path>`` are read, optionally
    narrowed to ``ids[slice[0]:slice[1]]``, and for every id whose annotation
    yields at least one box the image is copied to ``images/<id>.jpg`` and
    the parsed annotation is written to ``labels/<id>.json`` inside a
    temporary staging folder.  A ``records.csv`` index of those files is
    written alongside them and the folder is handed to ``tb.Package.create``
    to produce ``<name>.zip``.

    Args:
        voc07_path (str or Path): Path to folder containing VOC2007 data
        name (str): zipfile output name (without the ".zip" suffix)
        id_path (str): name of the text file under ``ImageSets/Main`` which
            contains one image id per line
        slice (tuple(int, int), optional): (start, stop) positions in the id
            list used to limit the data in the zip file.  A falsy value
            keeps every id.  Defaults to None.
    """
    print(f"Building {name}.zip")
    image_paths = []
    label_paths = []

    voc_root = Path(voc07_path)

    # Find IDs of the images to package
    with open(voc_root / f"ImageSets/Main/{id_path}") as f:
        ids = f.read().splitlines()
    if slice:
        ids = ids[slice[0] : slice[1]]

    # The context manager removes the staging tree even if a step below
    # raises, so a failed run does not leave a stray temp directory behind.
    with tempfile.TemporaryDirectory() as staging:
        root = Path(staging) / name
        image_root = root / "images"
        label_root = root / "labels"
        record_data = root / "records.csv"

        # parents=True also creates ``root`` itself
        image_root.mkdir(parents=True, exist_ok=True)
        label_root.mkdir(parents=True, exist_ok=True)

        for image_id in ids:
            # Parse annotation's XML file
            objects = parse_annotation(voc_root / "Annotations" / (image_id + ".xml"))
            if len(objects["boxes"]) == 0:
                # Nothing usable in this image; leave it out of the package
                continue

            with open(label_root / (image_id + ".json"), "w") as json_label_fp:
                json.dump(objects, json_label_fp)
            shutil.copy(
                voc_root / "JPEGImages" / (image_id + ".jpg"),
                image_root / (image_id + ".jpg"),
            )

            # Paths in the index are relative to the package root and always
            # use forward slashes.
            label_paths.append(Path("labels", image_id + ".json").as_posix())
            image_paths.append(Path("images", image_id + ".jpg").as_posix())

        path_df = pd.DataFrame(
            {
                "paths": image_paths,
                "label_path": label_paths,
                # Name of the key inside each label json that holds the boxes
                "polygon_key": ["boxes"] * len(label_paths),
                # Name of the key inside each label json that holds the ids
                "labels": ["labels"] * len(label_paths),
            }
        )
        path_df.to_csv(record_data, index=False)

        tb.Package.create(
            filename=name + ".zip",
            root=root,
            record_data=record_data,
            path_column="paths",
            label_column="label_path",
        )


###########################################################################

voc_root = data_dir / "VOC2007"

# Only fetch the archive when the expected index file is not already on disk
if not (voc_root / "ImageSets" / "Main" / "train.txt").exists():
    # Download data used in the PASCAL Visual Object Classes Challenge 2007
    # Originally from: http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
    tb.util.download_tripleblind_resource(
        "VOC2007.zip",
        save_to_dir=data_dir,
        expand=True,
        cache_dir="../../.cache",
    )

# Setting TB_TEST_SMALL (to any value) shrinks both training splits
use_small_splits = "TB_TEST_SMALL" in os.environ
split_one, split_two = (
    ((0, 10), (10, 20)) if use_small_splits else ((0, 100), (100, 150))
)

# Two training packages cut from trainval.txt, plus a test package that
# takes every id in val.txt.
for zip_name, id_file, id_range in (
    ("train", "trainval.txt", split_one),
    ("train2", "trainval.txt", split_two),
    ("test", "val.txt", None),
):
    build_data_zip(voc_root, zip_name, id_file, id_range)
