#!/usr/bin/env Rscript
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

# Make sure every package this script depends on is installed. Nothing is
# attached here; the library() calls that follow do that.
RequiredPackages <- c("reticulate", "stringr")

# A non-interactive R session cannot prompt for a CRAN mirror, and
# install.packages() fails while the "CRAN" repos entry is still the
# "@CRAN@" placeholder. Keep any mirror the user configured and only fill in
# a default when none is set.
cran_repos <- getOption("repos")
if (is.null(cran_repos) || identical(unname(cran_repos["CRAN"]), "@CRAN@")) {
  cran_repos["CRAN"] <- "https://cloud.r-project.org"
}

for (pkg in RequiredPackages) {
  # requireNamespace() tests availability without attaching the package
  if (!requireNamespace(pkg, quietly = TRUE)) {
    install.packages(pkg, repos = cran_repos)
  }
}

# Load packages
library(reticulate)
library(stringr)

# Handles to the Python modules used by the rest of the script, obtained
# through reticulate.
torch <- reticulate::import("torch")

tb <- reticulate::import("tripleblind")

# NOTE(review): presumably makes this script's folder the working directory so
# the relative file names used later resolve next to it -- confirm in the SDK.
tb$util$set_script_dir_current()

# Run identifier shared by every script in this folder; edit "run_id.txt" to
# change it.
run_id <- tb$util$read_run_id()


##########################################################################
# AUTHENTICATE AND CONNECT TO THE ROUTER
#
# ORGANIZATION-ONE PLAYS THE ROLE OF "SANTANDER" IN THIS EXAMPLE
#
# Set up the connection to the TripleBlind instance. Unless stated otherwise,
# every later operation goes through this default session as the user
# 'organization_one'.
tb$initialize(
  api_token = tb$config$example_user1[["token"]]
)

#############################################################################
# Validate that the datasets are available

# Find the three training datasets in the Router index. Each asset name is a
# bank prefix followed by the shared run id, and each lookup passes the team id
# of the example user expected to own that asset.
dataset_train0 <- tb$Asset$find(
  paste0("SAN", run_id),
  owned_by = tb$config$example_user1[["team_id"]]
)
dataset_train1 <- tb$Asset$find(
  paste0("JPM", run_id),
  owned_by = tb$config$example_user2[["team_id"]]
)
dataset_train2 <- tb$Asset$find(
  paste0("PNB", run_id),
  owned_by = tb$config$example_user3[["team_id"]]
)

# All three datasets are handed to the training job, so a NULL result from any
# of the lookups means the positioning step still has to be run.
# stop() is used instead of quit(): it ends an Rscript run with a non-zero
# exit status and, in an interactive session, halts the script without
# closing the session.
datasets_missing <- is.null(dataset_train0) ||
  is.null(dataset_train1) ||
  is.null(dataset_train2)
if (datasets_missing) {
  stop("Must run 1_position_data_on_accesspoint.py first.", call. = FALSE)
}

#############################################################################
# Define the neural network we want to train
#############################################################################

# Name given to the network definition created below; it is also printed when
# the training job is submitted.
training_model_name <- "example-network-santandar-trainer"

# Assemble the network layer by layer. The calls are kept as individual
# top-level statements in this exact order.
# NOTE(review): the two integers passed to add_dense_layer() look like
# (input size, output size) -- each layer's first value matches the previous
# layer's second, and the first layer's 200 matches the "data_shape" used by
# the training job. Confirm against the SDK.
builder <- tb$NetworkBuilder()
builder$add_dense_layer(200L, 120L) # the L suffix makes the number an integer (R's default numeric type is double)
builder$add_relu()
builder$add_dense_layer(120L, 160L)
builder$add_relu()
builder$add_dropout(0.25)
builder$add_dense_layer(160L, 200L)
builder$add_relu()
builder$add_split() # required split layer
builder$add_dense_layer(200L, 160L)
builder$add_relu()
builder$add_dense_layer(160L, 10L)
builder$add_relu()
builder$add_dense_layer(10L, 1L)

# Turn the builder into a named network object; it is passed to the training
# job as its operation and archived once the job has finished.
training_model <- tb$create_network(training_model_name, builder)

#############################################################################
# Designate the files to use and train the network
#

# Loss configuration. The name has to match a PyTorch loss class.
#   See: https://pytorch.org/docs/stable/nn.html#loss-functions
# Tested so far: 'BCEWithLogitsLoss', 'NLLLoss', 'CrossEntropyLoss'
loss_name <- "BCEWithLogitsLoss"
# arange(17, 18) is a one-element tensor holding 17; it is run through
# tb$TorchEncoder$encode before being placed in the job parameters.
pos_weight <- tb$TorchEncoder$encode(
  torch$arange(17, 18, dtype = torch$int32)
)

# Optimizer configuration. The name has to match a PyTorch optimizer class.
#   See: https://pytorch.org/docs/stable/optim.html
# Tested so far: 'SGD', 'Adam', 'Adadelta'
optimizer_name <- "Adam"
optimizer_params <- list(lr = 0.001)

# Tabular preprocessor: declares which data to use for training and which
# column holds the classification label. Built one step at a time, feeding
# each call's result into the next.
csv_pre <- tb$TabularPreprocessor$builder()
csv_pre <- csv_pre$add_column("target", target = TRUE)
csv_pre <- csv_pre$all_columns(TRUE)
csv_pre <- csv_pre$dtype("float32")

# Job label: fixed prefix plus the current time, with the space between date
# and time replaced by " @ ". The label is echoed to the console before the
# job is created.
job_label <- paste0("Santandar - ", gsub(" ", " @ ", toString(Sys.time())))
print(job_label)

# Loss and optimizer descriptions, each a name plus its parameter list
loss_meta <- list(
  name = loss_name,
  params = list(pos_weight = pos_weight)
)
optimizer_meta <- list(
  name = optimizer_name,
  params = optimizer_params
)

training_params <- list(
  epochs = 1L,
  loss_meta = loss_meta,
  optimizer_meta = optimizer_meta,
  data_type = "table",
  data_shape = list(200L), # number of columns of data in table
  model_output = "binary" # binary/multiclass/regression
)

# Create the training job over the three datasets located earlier
job <- tb$create_job(
  job_name = job_label,
  operation = training_model,
  dataset = list(dataset_train0, dataset_train1, dataset_train2),
  preprocessor = csv_pre,
  params = training_params
)
print("Training network")

###########################################################################
# Submit the training job, then publish the result: report the trained asset,
# download a local .pth copy, record the file name and asset id for the
# inference scripts, and create an agreement for the second example team.
#
# Failures are raised with stop() rather than quit(): under Rscript that ends
# the run with a non-zero exit status, and in an interactive session it halts
# the script without closing the session.

if (job$submit()) {
  print(paste0("Creating network asset under name: ", training_model_name))
  job$wait_for_completion()

  # Throw away this network definition (no longer needed)
  training_model$archive()

  if (!job$success) {
    stop("Training failed", call. = FALSE)
  }

  # The trained network asset produced by the job; reused for every step below
  trained_network <- job$result$asset

  cat("\n") # blank line
  print("Trained Network Asset ID:")
  print("    ===============================================")
  print(paste0("    ===>  ", trained_network$uuid, " <==="))
  print("    ===============================================")
  print("    Algorithm: Deep Learning Model")
  print(paste0("    Job ID:    ", job$job_name))
  cat("\n")

  # Pull down the model for local validation
  local_filename <- trained_network$download(save_as = "local.pth", overwrite = TRUE)
  print("Trained network has been downloaded as:")
  print(paste0("   ", local_filename))

  # Given a path, writeLines() opens the file and closes it again itself, so
  # no connection is left open if the write fails.

  # Save for use in 3a_local_inference.py
  writeLines(toString(local_filename), con = "local_model_filename.out")

  # Save for use in 3b_fed_inference.py / 3c_smpc_inference.py
  writeLines(toString(trained_network$uuid), con = "model_asset_id.out")

  print("Ready to run local inference.")
  cat("\n")

  # Create an agreement which allows the other team to use this
  # trained model in subsequent steps.
  agreement <- trained_network$add_agreement(
    with_team = tb$config$example_user2[["team_id"]], operation = tb$Operation$EXECUTE
  )
  if (!is.null(agreement)) {
    print("Created Agreement for use of trained Asset.")
  }
} else {
  # A rejected submission used to end the script silently with nothing
  # written for the follow-up scripts; report it instead.
  stop("Job submission failed", call. = FALSE)
}


############################################################################
# The 'local_filename' variable holds the local filename returned when the
# trained PyTorch object was downloaded. It could easily be passed to
# an additional step to run the local inference.
