#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import os.path
import sys
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import classification_report

import tripleblind as tb


# Suppress only PyTorch's "SourceChangeWarning" instead of silencing every
# warning in the process (a blanket filter would also hide genuine problems
# from pandas, scikit-learn, etc.).
warnings.filterwarnings(
    "ignore", category=torch.serialization.SourceChangeWarning
)

# Presumably switches the working directory to this script's folder so the
# relative paths used below resolve consistently — confirm in tb.util.
tb.util.set_script_dir_current()
data_dir = Path("BankTransactionData")

# Look for the path of the locally saved model written by a previous run of
# 2_model_train.py (the first line of local_model_filename.out).
try:
    with open("local_model_filename.out", "r") as f:
        trained_model = f.readline().strip()
except OSError:
    # Only a missing/unreadable file means training has not been run; other
    # exceptions (including Ctrl-C) must not be reported as that.
    print("You must run 2_model_train.py first.")
    sys.exit(1)

# The file may exist but point at a model that was moved or deleted (an empty
# first line also lands here, since os.path.exists("") is False).
if not os.path.exists(trained_model):
    print("ERROR: Unable to find the specified model.")
    sys.exit(1)

############################################################################
# Restore the trained model from its local package and put it in
# evaluation mode before running any inference.
#
pack = tb.Package.load(trained_model)
model = pack.model()
model.eval()

# Fetch the "batch" test split (features first, then labels) from
# TripleBlind's demo data server, caching downloads under ../../.cache.
data_file, target_file = (
    tb.util.download_tripleblind_resource(
        resource_name,
        save_to_dir=data_dir,
        cache_dir="../../.cache",
    )
    for resource_name in ("test_small_X_demo.csv", "test_small_target_demo.csv")
)

data_X = pd.read_csv(data_file)
data_y = pd.read_csv(target_file)

# Features: every CSV column, as a 2-D float32 tensor (rows x columns).
X = torch.from_numpy(data_X.values.astype(np.float32))

# Labels: the "target" column cast to int64, reshaped into an (n, 1) column
# and finally converted to a float64 tensor.
y = torch.from_numpy(
    data_y["target"].astype(np.int64).values.reshape(-1, 1)
).double()

# Pair features with labels and iterate them in fixed-size batches of 128
# (DataLoader's default ordering, i.e. no shuffling).
ds = torch.utils.data.TensorDataset(X, y)
test_loader = torch.utils.data.DataLoader(dataset=ds, batch_size=128)

############################################################################
# Run batched inference over the test set and report the results.
#
# Raw model outputs go through a sigmoid and are rounded, so each prediction
# becomes 0.0 or 1.0 (outputs above 0.5 map to 1). Predictions and true labels
# are collected as plain Python values for the CSV file and scikit-learn.
y_pred_list = []
y_true_list = []
with torch.no_grad():  # inference only: skip autograd bookkeeping
    for X_batch, y_batch in test_loader:
        y_pred_tag = torch.round(torch.sigmoid(model(X_batch)))
        # One entry per row; squeeze() turns a single-element row into a
        # Python scalar, matching the single-column label layout.
        y_pred_list.extend(row.squeeze().tolist() for row in y_pred_tag)
        y_true_list.extend(label.item() for label in y_batch)

# One prediction per line, with no header row and no index column.
df = pd.DataFrame(y_pred_list)
df.to_csv("tabular_local_predictions.csv", header=False, index=False)
print(classification_report(y_true_list, y_pred_list))
