#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import os.path
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import classification_report

import tripleblind as tb


# Suppress only PyTorch's "SourceChangeWarning" rather than every warning in
# the process, so unrelated problems (e.g. data or dtype issues) still surface.
warnings.filterwarnings("ignore", category=torch.serialization.SourceChangeWarning)

# Presumably changes the working directory to this script's folder so the
# relative paths below resolve next to the script -- confirm in tb.util.
tb.util.set_script_dir_current()
data_dir = Path("example_data")
data_dir.mkdir(exist_ok=True)

# Look for a model created by a previous run of 1_model_train.py
trained_model = "local.zip"
if not os.path.exists(trained_model):
    # A missing input file is a FileNotFoundError; SystemError is meant for
    # interpreter-internal failures and would mislead anyone catching it.
    raise FileNotFoundError("Unable to find the model, run 1_model_train.py.")

############################################################################
# Load the locally stored trained model object
#
pack = tb.Package.load(trained_model)
model = pack.model()
# Put the model in evaluation mode before inference (presumably a torch
# nn.Module, where this disables training-only layers such as dropout).
model.eval()

############################################################################
# Fetch the small demo test set used for "batch" evaluation
#
test_data = tb.util.download_tripleblind_resource(
    "test_small_demo.csv",
    save_to_dir=data_dir,
    cache_dir="../../.cache",
)

# Every column except "target" is a feature; pop() removes the label column
# from the frame and hands it back as its own Series.
data_X = pd.read_csv(test_data)
data_y = data_X.pop("target")

# Features become a float32 matrix. Labels are cast to int64 and then to a
# float64 column vector of shape (n, 1).
X = torch.from_numpy(data_X.values.astype(np.float32))
y = torch.from_numpy(data_y.values.astype(np.int64).reshape(-1, 1)).double()

# Pair rows of X with rows of y and iterate them in fixed-size batches.
ds = torch.utils.data.TensorDataset(X, y)
test_loader = torch.utils.data.DataLoader(ds, batch_size=128)

############################################################################
# Run batched inference and report results
#
# One predicted label and one true label are collected per test row.
y_pred_list = []
y_true_list = []
with torch.no_grad():  # inference only: no gradients needed
    for X_batch, y_batch in test_loader:
        # Outputs are treated as logits. Sigmoid maps them into (0, 1), and
        # round() snaps each to a 0/1 label (an output of exactly 0.5 rounds
        # to 0, because torch.round rounds half to even).
        y_pred_tag = torch.round(torch.sigmoid(model(X_batch)))
        # Squeeze each row so a single-output model yields plain floats.
        y_pred_list.extend(row.squeeze().tolist() for row in y_pred_tag)
        y_true_list.extend(row.item() for row in y_batch)

# Save one prediction per line (no header or index column), then print
# sklearn's classification report comparing true and predicted labels.
df = pd.DataFrame(y_pred_list)
df.to_csv("tabular_local_predictions.csv", header=False, index=False)
print(classification_report(y_true_list, y_pred_list))
