#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import sys
from datetime import datetime
from pathlib import Path

import numpy as np
from PIL import Image, ImageDraw, ImageFont

import tripleblind as tb


# Resolve relative paths (example_data, model_asset_id.out) against the
# directory containing this script.
tb.util.set_script_dir_current()
data_dir = Path("example_data")

###########################################################################
# Retrieve airplane image for inference from TripleBlind's demo data server.
img_path = tb.util.download_tripleblind_resource(
    "Flying-airplane.jpg",
    save_to_dir=data_dir,
    cache_dir="../../.cache",
)

##########################################################################
# GET AUTHENTICATION TOKENS AND ESTABLISH CONNECTION TO THE ROUTER
#
# Establish the connection details to reach the TripleBlind instance.
# Unless explicitly specified, all operations will occur via this default
# session, authenticated with the token of tb.config.example_user2.
tb.initialize(api_token=tb.config.example_user2["token"])

# Look for a model Asset ID from a previous run of 2_model_train.py
try:
    with open("model_asset_id.out", "r") as f:
        asset_id = f.readline().strip()
except OSError:
    # Only a missing/unreadable file means "training has not been run";
    # a bare except would also hide Ctrl-C and genuine programming errors.
    print("You must run 2_model_train.py first.")
    sys.exit(1)
if not asset_id:
    # An empty file cannot identify a trained model either.
    print("You must run 2_model_train.py first.")
    sys.exit(1)
trained_network = tb.Asset(asset_id)

# Run the trained ROI (region of interest) model against the downloaded
# example image.
list_of_files = [img_path]
inference_predictions = []
for name in list_of_files:
    job = tb.create_job(
        job_name="Test trained network - " + str(datetime.now()),
        operation=trained_network,
        params={"security": "fed"},
        dataset=name,
    )
    if not job:
        print("ERROR: Run 2_model_train.py first")
        sys.exit(1)

    if not job.submit():
        # Previously a rejected submission was ignored silently.
        print("Inference failed")
        continue

    job.wait_for_completion()

    if job.success:
        filename = job.result.asset.retrieve(
            save_as="roi_result.zip", overwrite=True
        )
        pack = tb.Package.load(filename)
        # NOTE: each iteration replaces the previous records, so only the
        # last file's predictions are kept.
        inference_predictions = pack.records()
    else:
        print("Inference failed")

# Without predictions the annotation step below cannot run; stop here with
# a clear message instead of failing later on an empty result list.
if not inference_predictions:
    print("ERROR: No inference results were produced.")
    sys.exit(1)

print("\n\nInference results:")
print("    ", inference_predictions)

# Load the image that was submitted for inference and convert it to RGB.
original_image = Image.open(data_dir / "Flying-airplane.jpg", mode="r").convert("RGB")

# Scale factors laid out as [[width, height, width, height]], matching the
# four coordinates of each detected box.
dims = [list(original_image.size) * 2]

# Only the first record of the inference output is visualised.
predictions = inference_predictions[0]
# presumably the model reports box coordinates as fractions of the image
# size (hence the multiplication by pixel dimensions) -- TODO confirm
relative_boxes = np.array(predictions["detected_polygons"][0])
det_boxes = relative_boxes * np.array(dims)

# Class names; each is assigned the id 1..20 in this order below.
voc_labels = (
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
)

# name -> id, with id 0 reserved for "background" (inserted last so the
# colour assignment below keeps the same order).
label_map = {class_name: class_id for class_id, class_name in enumerate(voc_labels, start=1)}
label_map["background"] = 0
# id -> name
rev_label_map = dict(zip(label_map.values(), label_map.keys()))


# One colour per entry of label_map, paired up in insertion order.
distinct_colors = [
    "#e6194b", "#3cb44b", "#ffe119", "#0082c8", "#f58231", "#911eb4", "#46f0f0",
    "#f032e6", "#d2f53c", "#fabebe", "#008080", "#000080", "#aa6e28", "#fffac8",
    "#800000", "#aaffc3", "#808000", "#ffd8b1", "#e6beff", "#808080", "#FFFFFF",
]
label_color_map = dict(zip(label_map, distinct_colors))

# Translate the predicted class ids of the first image into class names.
det_labels = [rev_label_map[class_id] for class_id in predictions["predicted_labels"][0]]

# If no objects found, the detected labels will be set to ['0.'], i.e. ['background'] in SSD300.detect_objects() in model.py
if det_labels == ["background"]:
    print("No objects detected.")
else:
    # Annotate. Note that annotated_image is the same object as
    # original_image, so the drawing happens on the loaded image itself.
    annotated_image = original_image
    draw = ImageDraw.Draw(annotated_image)
    font = ImageFont.load_default()

    for box, label in zip(det_boxes, det_labels):
        print(box)
        color = label_color_map[label]
        caption = label.upper()

        # Boxes: outline drawn twice, the second shifted by one pixel, to
        # make the line thicker.
        box_location = box.tolist()
        draw.rectangle(xy=box_location, outline=color)
        draw.rectangle(xy=[coord + 1.0 for coord in box_location], outline=color)

        # Text: getbbox() returns (left, top, right, bottom) measured from
        # the drawing origin, so right/bottom are the extents the caption
        # occupies when drawn at that origin. (Indexing the tuple as if it
        # were (width, height) produced a sliver-sized background and put
        # the caption inside the box instead of above it.)
        _, _, text_right, text_bottom = font.getbbox(caption)
        caption_top = box_location[1] - text_bottom
        text_location = [box_location[0] + 2.0, caption_top]
        textbox_location = [
            box_location[0],
            caption_top,
            box_location[0] + text_right + 4.0,
            box_location[1],
        ]
        draw.rectangle(xy=textbox_location, fill=color)
        draw.text(xy=text_location, text=caption, fill="white", font=font)
    del draw

    annotated_image.show()
