#!/usr/bin/env Rscript
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

# Install required packages if not already installed.
#
# requireNamespace() only checks whether a package can be loaded; it does not
# attach it (the library() calls below do that) and, with quietly = TRUE, it
# does not warn when the package is missing.
RequiredPackages <- c("reticulate", "stringr")

# Under Rscript there is no prompt to pick a CRAN mirror, so install.packages()
# stops with "trying to use CRAN without setting a mirror" while the "repos"
# option still holds its "@CRAN@" placeholder. Substitute a concrete mirror in
# that case only; interactive sessions keep their usual mirror chooser and any
# repositories the user configured are left untouched.
cran_repos <- getOption("repos")
if (!interactive()) {
  if (is.null(cran_repos)) {
    cran_repos <- c(CRAN = "https://cloud.r-project.org")
  } else {
    cran_repos[cran_repos %in% "@CRAN@"] <- "https://cloud.r-project.org"
  }
}

for (i in RequiredPackages) {
  if (!requireNamespace(i, quietly = TRUE)) {
    install.packages(i, repos = cran_repos)
  }
}

# Load packages
library(reticulate)
library(stringr)

pd <- import("pandas", convert = FALSE)

# from sklearn.metrics import classification_report
sklearn_metrics <- import("sklearn.metrics")
classification_report <- sklearn_metrics$classification_report

tb <- import("tripleblind")

# Presumably makes this script's own directory the working directory so the
# relative paths used below resolve -- TODO confirm against the tripleblind docs.
tb$util$set_script_dir_current()

# Folder (relative path) holding the demo CSV and pickle files read below
data_dir <- "BankTransactionData"

##########################################################################
# GET AUTHENTICATION TOKENS AND ESTABLISH CONNECTION TO THE ROUTER
#
# The session is opened with the API token configured for example_user2.
user2_token <- tb$config$example_user2[["token"]]
tb$initialize(api_token = user2_token)

# Look for a model Asset ID from a previous run of 2_model_train.py.
#
# Only the first line of model_asset_id.out is used, with leading and trailing
# whitespace removed (trimws is the R equivalent of python strip()). Reading a
# single line keeps asset_id a length-one string even if the file carries extra
# blank lines. Passing the path straight to readLines() lets it open and close
# the connection itself, so nothing is left open when the read fails.
#
# If the file is missing, unreadable or empty, the script reports what to do
# and exits with a non-zero status so callers can detect the failure.
asset_id <- tryCatch(
  {
    first_line <- readLines("model_asset_id.out", n = 1L, warn = FALSE)
    if (length(first_line) == 0L || !nzchar(trimws(first_line))) {
      stop("model_asset_id.out does not contain an Asset ID", call. = FALSE)
    }
    trimws(first_line)
  },
  error = function(err) {
    print("You must run 2_model_train.py first.")
    # save = "default" follows the command-line setting; unlike "ask" it does
    # not emit a warning when the script runs non-interactively.
    quit(save = "default", status = 1)
  }
)

# Handle to the algorithm asset identified by the stored Asset ID
alg_santandar <- tb$Asset(asset_id)

# Build the inference input: the demo test set with its "target" column
# dropped, written next to the original as a separate CSV file.
notarget_csv <- file.path(data_dir, "test_small_notarget_demo.csv")
testdf <- pd$read_csv(file.path(data_dir, "test_small_demo.csv"))$drop(columns = "target")
testdf$to_csv(notarget_csv, index = FALSE)

# The job takes its datasets as a list, here holding just that one file path
datasets <- list(notarget_csv)

# Preprocessor: select all columns and request dtype "float32"
preproc_builder <- tb$TabularPreprocessor$builder()
preproc_builder <- preproc_builder$all_columns(TRUE)
preproc <- preproc_builder$dtype("float32")

# Define a job using this model.
# The job name carries the current timestamp, so each run gets its own label.
job_label <- paste0("Santandar Inference - ", toString(Sys.time()))

job <- tb$create_job(
  job_name = job_label,
  operation = alg_santandar,
  dataset = datasets,
  preprocessor = preproc,
  params = list(security = "fed") # fed or smpc
)

# create_job() yielding NULL means no job was created. Explain the most likely
# cause (a missing Agreement on the algorithm asset), point the user at the
# Router page where it can be granted, and stop with a non-zero exit status so
# the failure is visible to whatever launched the script.
if (is.null(job)) {
  owner_login <- tb$config$example_user1[["login"]]
  user_login <- tb$config$example_user2[["login"]]
  user_name <- tb$config$example_user2[["name"]]
  asset_url <- paste0("      ", tb$config$gui_url, "/dashboard/algorithm/", alg_santandar$uuid)

  print("ERROR: Failed to create the job -- do you have an Agreement to run this?")
  cat("\n")
  print(paste0("NOTE: Remote inference requires the user '", owner_login, "' create an"))
  print(paste0("      Agreement on their algorithm asset with user '", user_login, "'"))
  print(paste0("      (", user_name, ") before they can use it to infer.  You can do"))
  print("      this on the Router at:")
  print(asset_url)
  cat("\n")
  print("Program exiting...")
  # save = "default" follows the command-line setting; unlike "ask" it does
  # not emit a warning when the script runs non-interactively.
  quit(save = "default", status = 1)
}

# Load the labels: the "y" entry of the pickled demo test set.
# pandas was imported with convert = FALSE, so y remains a Python object.
labels_pickle <- file.path(data_dir, "test_small_demo.pkl")
y <- pd$read_pickle(labels_pickle)[["y"]]

# Run against the local test dataset: submit the job, wait for it to finish,
# then download the predictions and compare them with the true labels.
if (job$submit()) {
  job$wait_for_completion()

  # NOTE(review): fixed pause before checking the outcome; presumably gives the
  # job status time to settle -- confirm whether it is still needed.
  Sys.sleep(3)

  if (job$success) {
    # Download the result asset as a local CSV, replacing any earlier copy
    predictions_csv <- job$result$asset$download(
      save_as = "tabular_remote_predictions.csv", overwrite = TRUE
    )
    # header = NULL is passed to pandas as None: the file has no header row
    predictions <- pd$read_csv(predictions_csv, header = NULL)

    cat("\n")
    print("Inference results:")
    print(paste(py_to_r(predictions$values), collapse = "  "))
    cat("\n")

    print("Truth:")
    print(paste(py_to_r(y), collapse = "  "))
    cat("\n")

    # classification_report comes from a module imported with conversion
    # enabled, so the report can be written out directly with cat()
    print("Classification report:")
    report <- classification_report(y, predictions$values)
    cat(report)
    cat("\n")
  } else {
    print("Inference Failed")
  }
} else {
  # Without this branch a rejected submission ended the script silently
  print("Job submission failed")
}
