#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.
import json
import os
import warnings
import zipfile


warnings.filterwarnings("ignore")  # Suppress some torch warnings
import torch
from transformers import AutoTokenizer

from preprocessor.nlp.transformers import network_join


# Helper that groups one sample's token-level predictions into entity phrases
def _collect_entities(sample_tokens, sample_preds, labels, skip_tokens):
    """Return the ``(phrase, label)`` pairs found in a single sample.

    Args:
        sample_tokens: Token strings of the sample, in order.
        sample_preds: Predicted class index for each token.
        labels: Mapping from the stringified class index to its label name.
            The label ``"None"`` marks a token that is not part of an entity.
        skip_tokens: Token strings that never belong to an entity (padding and
            sequence delimiters). They are not looked up in ``labels`` and they
            end any phrase that is currently open.

    Returns:
        A list of ``(phrase, label)`` tuples in the order they occur.
    """
    entities = []
    phrase = ""
    current_label = None

    for token, pred_index in zip(sample_tokens, sample_preds):
        # Padding / delimiter tokens behave like a "None" prediction: they
        # close the open phrase instead of aborting the whole sample, so
        # left-padded input is handled as well as right-padded input.
        if token in skip_tokens:
            predicted_label = "None"
        else:
            predicted_label = labels[str(pred_index)]

        if predicted_label == "None":
            if phrase:
                entities.append((phrase.strip(), current_label))
                phrase = ""
                current_label = None
            continue

        if token.startswith("##"):
            # WordPiece continuation: glue onto the previous piece.
            phrase += token[2:]
        else:
            # A new word with a different label starts a new entity.
            if phrase and predicted_label != current_label:
                entities.append((phrase.strip(), current_label))
                phrase = ""
            phrase += " " + token

        current_label = predicted_label

    if phrase:
        entities.append((phrase.strip(), current_label))
    return entities


# Helper function to print the tokens and their predicted labels
def postprocessing(tokenized_text, logits, labels, tokenizer):
    """Print the entity phrases predicted for every sample in a batch.

    Args:
        tokenized_text: Tensor of token ids, one row per sample.
        logits: Tensor of per-token class scores; the predicted class is the
            argmax over the last dimension.
        labels: Mapping from the stringified class index to its label name.
            The label ``"None"`` marks a token that is not part of an entity.
        tokenizer: Object providing ``convert_ids_to_tokens`` and
            ``pad_token``; ``cls_token`` / ``sep_token`` are used when present.

    Consecutive tokens that share a label are merged into one phrase, and
    ``##`` word pieces are joined to the preceding piece. Padding and
    CLS/SEP tokens are never reported, whatever the model predicted for them.
    """
    predictions = torch.argmax(logits, dim=-1).detach().cpu().numpy()
    tokens_list = tokenized_text.detach().cpu().numpy()
    token_word_list = [
        tokenizer.convert_ids_to_tokens(sample) for sample in tokens_list
    ]

    # Tokens that carry no text of their own and must not end up in a phrase.
    candidates = (
        tokenizer.pad_token,
        getattr(tokenizer, "cls_token", None),
        getattr(tokenizer, "sep_token", None),
    )
    skip_tokens = {tok for tok in candidates if tok is not None}

    for sample_tokens, sample_preds in zip(token_word_list, predictions):
        print("\n--- Sample ---")
        found = _collect_entities(sample_tokens, sample_preds, labels, skip_tokens)
        for phrase, label in found:
            print(f"{phrase} \t prediction: {label}")


# Unzip the model file into a new directory
os.makedirs("model_package", exist_ok=True)
with zipfile.ZipFile("output-model.zip", "r") as zip_ref:
    zip_ref.extractall("model_package")

# Load the trained model
model = network_join.NetworkJoin.load("model_package")

# Get the labels of the model and their ids from the model config file
json_dir = "model_package/server/config.json"
# JSON is UTF-8 by specification; do not depend on the platform's locale.
with open(json_dir, "r", encoding="utf-8") as file:
    data = json.load(file)
id2label = data.get("id2label")
if id2label is None:
    # Fail here with a clear message rather than later inside postprocessing.
    raise KeyError(f"'id2label' is missing from {json_dir}")

# some test data
test_data = [
    "Taking Venlafaxine for depression.",
    "Mary visited the hospital for severe chest pain.",
]

# Tokenize the test data
# Use the same tokenizer used in the training
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")
tokenized_text = tokenizer(
    test_data, return_tensors="pt", truncation=True, padding=True
)

# Predict the named entities in `test_data`
# Inference only: disable gradient tracking so no autograd graph is built.
with torch.no_grad():
    model_output = model(**tokenized_text)
logits = model_output.logits

# Print the tokens and their predicted labels
postprocessing(tokenized_text["input_ids"], logits, id2label, tokenizer)
