#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import os.path
import sys
import warnings
from pathlib import Path

import tripleblind as tb


# The example depends on several third-party packages.  Import them up front
# so that a missing dependency produces an actionable hint instead of a
# raw traceback.
try:
    import numpy as np
    import pandas as pd
    import torch
    from sklearn.metrics import classification_report

except ImportError:
    # Only a failed import means "not installed".  A bare ``except`` would
    # also trap KeyboardInterrupt/SystemExit and mislabel unrelated failures.
    print("You must install pandas, PyTorch, sklearn, and Numpy first.  Do this via:")
    # The command lists every package named above; "scikit-learn" is the PyPI
    # distribution that provides the ``sklearn`` module.
    print("  pip install torch numpy pandas scikit-learn")
    print("Then rerun this script.")
    sys.exit(1)

# Silence Python warnings for the rest of the run.  The motivation is the
# PyTorch "SourceChangeWarning", but note that this filter is not limited to
# that category -- it ignores every warning.
warnings.filterwarnings("ignore")

# NOTE(review): presumably makes the script's own directory the working
# directory so the relative paths below resolve -- confirm in tripleblind docs.
tb.util.set_script_dir_current()
data_dir = Path("example_data")

# Look for the filename of the locally saved model, recorded in
# "local_model_filename.out" by a previous run of 2_model_train.py
try:
    with open("local_model_filename.out", "r") as f:
        trained_model = f.readline().strip()
except OSError:
    # A missing or unreadable file means training has not produced a model.
    # Catching only OSError keeps Ctrl-C and genuine bugs from being reported
    # as "run the training script first".
    print("You must run 2_model_train.py first.")
    sys.exit(1)

# The recorded path may be stale (or empty); verify it before loading.
if not os.path.exists(trained_model):
    print("ERROR: Unable to find the specified model.")
    sys.exit(1)

############################################################################
# Restore the trained model from the package stored on local disk
#
model_package = tb.Package.load(trained_model)
# model_pointer() yields a pair; only its second element (the path handed to
# tb.load_model) is needed here.
_, model_path = model_package.model_pointer()
model = tb.load_model(model_path)
# Switch to evaluation mode before running any predictions
# (presumably the loaded object is a torch module -- TODO confirm).
model.eval()

# Use the local test dataset for "batch" testing
#
# The feature columns are split into three vertical slices (the first 40
# columns, the next 60, and the remainder), one loader per slice.  All three
# loaders must share one batch size and keep shuffle=False: the inference loop
# zips them together, so batch i of every loader has to hold the same rows.
BATCH_SIZE = 64


def _make_loader(features, labels=None):
    """Build an unshuffled DataLoader over one slice of the test columns.

    Args:
        features: DataFrame holding the feature columns for one model input;
            its values are converted to float32.
        labels: Optional NumPy array of targets to pair with each row.

    Returns:
        A DataLoader yielding batches of up to ``BATCH_SIZE`` rows in file
        order.  Each batch holds the feature tensor, followed by the label
        tensor when ``labels`` is given.
    """
    tensors = [torch.from_numpy(features.values.astype(np.float32))]
    if labels is not None:
        tensors.append(torch.from_numpy(labels))
    dataset = torch.utils.data.TensorDataset(*tensors)
    return torch.utils.data.DataLoader(
        dataset, batch_size=BATCH_SIZE, shuffle=False
    )


df_test = pd.read_csv(data_dir / "sant_psi_vertical_test.csv")
# NOTE(review): test_ids is not used in the visible portion of this script.
test_ids = df_test["ID_code"].copy()
test_y = df_test["target"].copy().values
# Leave only the feature columns so positional slicing below is meaningful.
df_test = df_test.drop(columns=["ID_code", "target"])

# uint8 targets are widened to int64 before being turned into a tensor.
labels = test_y
if labels.dtype == np.uint8:
    labels = labels.astype(np.int64)

# Only the first loader carries the targets alongside its features.
test_loader_0_40 = _make_loader(df_test.iloc[:, :40], labels)

print(f"Running local testing on {len(test_loader_0_40.dataset)} records...")

test_loader_41_100 = _make_loader(df_test.iloc[:, 40:100])
test_loader_101_200 = _make_loader(df_test.iloc[:, 100:])

# Accumulators for the rounded predictions and the matching true labels
y_pred_list = []
y_true_list = []

# Gradients are not needed for testing
with torch.no_grad():
    # The three loaders are walked in lockstep, one batch from each per step
    batches = zip(test_loader_0_40, test_loader_41_100, test_loader_101_200)
    for first_batch, second_batch, third_batch in batches:
        # The first loader's batches carry both the features and the targets
        first_features, batch_targets = first_batch
        raw_output = model(first_features, second_batch[0], third_batch[0])
        # Sigmoid followed by rounding turns each output into 0.0 or 1.0
        rounded = torch.round(torch.sigmoid(raw_output))
        y_pred_list.extend(row.numpy().squeeze().tolist() for row in rounded)
        y_true_list.extend(label.item() for label in batch_targets)

# Save the predictions (no header row, no index column), then print metrics
results = pd.DataFrame(y_pred_list)
results.to_csv("psi_vert_results.csv", header=None, index=None)
report = classification_report(y_true_list, y_pred_list)
print(report)
