#!/usr/bin/env Rscript
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

# Install any required packages that are missing.
# requireNamespace() only checks availability (no attaching, no warning);
# the packages are attached with library() further down.
# The default `repos` option is the "@CRAN@" placeholder, which
# install.packages() cannot resolve in a non-interactive Rscript session,
# so fall back to the cloud mirror when no real repository is configured.
required_packages <- c("reticulate", "stringr")
is_available <- vapply(
  required_packages, requireNamespace, logical(1),
  quietly = TRUE
)
missing_packages <- required_packages[!is_available]
if (length(missing_packages) > 0L) {
  repos <- getOption("repos")
  if (is.null(repos) || identical(unname(repos["CRAN"]), "@CRAN@")) {
    repos <- "https://cloud.r-project.org"
  }
  install.packages(missing_packages, repos = repos)
}

# Load packages
library(reticulate)
library(stringr)

np <- import("numpy", convert = FALSE)
pd <- import("pandas", convert = FALSE)
torch <- import("torch", convert = FALSE)

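# The next two lines are equivalent to Python's
# `from sklearn.metrics import classification_report`. Unlike the modules
# above, this import keeps reticulate's default conversion (convert = TRUE).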
sklearn_metrics <- import("sklearn.metrics")
classification_report <- sklearn_metrics$classification_report

tb <- import("tripleblind")

data_dir <- "BankTransactionData" # directory holding the test pickle
# NOTE(review): judging by its name, this TripleBlind helper makes the
# script's own directory the working directory, so the relative paths used
# in this script resolve from there — confirm against the tripleblind docs.
tb$util$set_script_dir_current()

# Read the path of the locally saved model written by a previous run of
# 2_model_train.py. Only the first line is used, trimmed of surrounding
# whitespace (trimws() plays the role of Python's str.strip()).
# Passing the file name straight to readLines() lets it open and close the
# connection itself, so nothing leaks if reading fails.
trained_model <- tryCatch(
  trimws(readLines("local_model_filename.out", n = 1L, warn = FALSE)),
  error = function(err) {
    message("You must run 2_model_train.py first.")
    quit(save = "no", status = 1)
  }
)

# An empty name file yields character(0); treat it like a missing model.
if (length(trained_model) == 0L || !file.exists(trained_model)) {
  message("ERROR: Unable to find the specified model.")
  quit(save = "no", status = 1)
}

############################################################################
# Load the locally stored trained model object
#
# torch.load() unpickles the file, so it must come from a trusted source
# (here, the local training run).
# NOTE(review): torch >= 2.6 defaults to weights_only=True, which can reject
# a pickled full model object. If loading fails, consider passing
# weights_only = FALSE — confirm how 2_model_train.py saved the model.
model <- torch$load(trained_model)
# Put the module into evaluation mode. eval() returns the module itself;
# invisible() stops Rscript from auto-printing its repr to stdout.
invisible(model$eval())

# Batch-test the model against the local test dataset.
# The pickle is indexed by "x" (features) and "y" (labels). Features are cast
# to float32; labels are cast to int64 and then turned into a float64 tensor
# with $double(). read_pickle() unpickles, so the file must be trusted.
test_pickle <- file.path(data_dir, "test_small_demo.pkl")
data <- pd$read_pickle(test_pickle)

X <- torch$from_numpy(data["x"]$astype(np$float32))
y <- torch$from_numpy(data["y"]$astype(np$int64))$double()

# The L suffix makes 128 an R integer, which reticulate hands to Python as
# an int (a bare 128 is an R double and would arrive as a Python float).
test_loader <- torch$utils$data$DataLoader(
  torch$utils$data$TensorDataset(X, y),
  batch_size = 128L
)

# Score every batch of the test set, not just the first one.
# iterate() collects the DataLoader's batches into an R list. Each batch is a
# Python (X, y) pair whose items are read with Python's 0-based indexing:
# [[0]] = features, [[1]] = labels.
# A prediction is the sigmoid of the model output rounded to 0/1.
test_batches <- iterate(it = test_loader)
if (length(test_batches) == 0L) {
  stop("The test dataset is empty; nothing to evaluate.", call. = FALSE)
}

# Gradients are not needed for inference. invisible() keeps Rscript from
# auto-printing the collected tensors when with() returns.
invisible(with(torch$no_grad(), {
  batch_results <- lapply(test_batches, function(batch) {
    X_batch <- batch[[0]]
    y_batch <- batch[[1]]
    y_prob <- torch$sigmoid(model$forward(X_batch))
    list(pred = torch$round(y_prob), true = y_batch)
  })
}))

# Join the per-batch tensors along the first dimension. torch$cat receives
# each unnamed R list as a Python list of tensors. Then convert to NumPy.
y_true_list <- torch$cat(lapply(batch_results, `[[`, "true"))$numpy()
y_pred_list <- torch$cat(lapply(batch_results, `[[`, "pred"))$numpy()

# Save the rounded predictions to CSV with no header row and no index
# column, then print scikit-learn's classification report.
predictions_df <- pd$DataFrame(y_pred_list)
# to_csv() returns Python None (pd was imported with convert = FALSE);
# invisible() keeps Rscript from auto-printing it.
invisible(predictions_df$to_csv(
  "tabular_local_predictions.csv",
  header = FALSE,
  index = FALSE
))
report <- classification_report(y_true_list, y_pred_list)
cat(report)
