#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.
import json
import os
import warnings
import zipfile


warnings.filterwarnings("ignore")  # Suppress some torch warnings
import torch
from transformers import AutoTokenizer

from preprocessor.nlp.transformers import network_join


# Extract the trained-model archive into its own directory (reusing the
# directory if it already exists), then load the model from that directory.
PACKAGE_DIR = "model_package"

os.makedirs(PACKAGE_DIR, exist_ok=True)
with zipfile.ZipFile("output-model.zip", mode="r") as archive:
    archive.extractall(PACKAGE_DIR)

model = network_join.NetworkJoin.load(PACKAGE_DIR)

# Read the class-id -> label-name mapping from the config file shipped in the
# model package. JSON object keys are always strings, so the ids come back as
# strings and must be converted with int() before being used as indices.
config_path = "model_package/server/config.json"
with open(config_path, "r", encoding="utf-8") as file:
    data = json.load(file)

id2label = data.get("id2label")
if not id2label:
    # Fail fast with a clear message instead of an AttributeError on
    # `None.items()` (or silently printing no labels) further down.
    raise ValueError(
        f"{config_path} does not define a non-empty 'id2label' mapping; "
        "cannot name the model's output classes."
    )

# Sample sentences to classify.
test_data = [
    "This one should win an Oscar!",
    "I fell asleep halfway through it.  Boring.",
]

# Tokenize the test data.
# The tokenizer must be the same one used during training; otherwise the
# token ids fed to the model would not match its vocabulary.
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")
tokenized_text = tokenizer(
    test_data, return_tensors="pt", truncation=True, padding=True
)

# Classify each text in `test_data`: the model yields one row of logits per
# input text, indexed by label id. Gradients are not needed for inference,
# so autograd tracking is disabled to save memory and time.
# NOTE(review): if NetworkJoin.load does not already put the model in eval
# mode, dropout may still be active here; confirm against NetworkJoin.
with torch.no_grad():
    model_output = model(**tokenized_text)
logits = model_output.logits

SEPARATOR = "=" * 40


def _print_separator():
    """Print a horizontal rule surrounded by blank lines."""
    print()
    print(SEPARATOR)
    print()


# For every input text, print the text followed by the softmax probability
# the model assigns to each label.
for row, sentence in enumerate(test_data):
    _print_separator()
    print(sentence)

    # Turn this text's logits into a probability distribution over labels.
    distribution = torch.softmax(logits[row], 0)

    print()
    for label_id, label_name in id2label.items():
        # Label ids are string keys from the JSON config; convert to index.
        print(label_name, distribution[int(label_id)].item())

_print_separator()
