#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import os.path
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import r2_score
from sklearn.preprocessing import StandardScaler

import tripleblind as tb


# Suppress the PyTorch "SourceChangeWarning".
# NOTE: this filter is global, so every other warning category raised by this
# script is hidden as well.
warnings.filterwarnings("ignore")

tb.util.set_script_dir_current()

# Look for the local model file recorded by a previous run of
# 2_nn_model_train.py (the loaded value is checked as a filesystem path below).
try:
    trained_model = tb.util.load_from("local_model_filename.out")
except Exception as err:
    # Catch Exception rather than using a bare "except:" so that Ctrl-C and
    # SystemExit still propagate, and chain the cause to keep the real reason
    # visible in the traceback.
    raise SystemError("You must run 2_nn_model_train.py first.") from err
if not os.path.exists(trained_model):
    raise SystemError("ERROR: Unable to find the specified model.")

############################################################################
## Create test data
############################################################################


def load_test_df(test_file, rul_file) -> pd.DataFrame:
    """Load a CMAPSS test set and attach a per-row "RUL" target column.

    Args:
        test_file: Path to the whitespace-separated test file. It must have
            26 header-less columns, which are named ``id``, ``cycle``,
            ``setting1``..``setting3`` and ``s1``..``s21``.
        rul_file: Path to a header-less, single-column file holding one value
            per engine: the remaining useful life at that engine's last
            recorded cycle. Row ``k`` (1-based) is matched to engine id ``k``.

    Returns:
        The test dataframe with an added integer ``RUL`` column computed as
        ``final RUL + last cycle - cycle`` for every row of each engine.
        Rows come back in the order produced by the outer merge, which may
        differ from the order in ``test_file``.

    Raises:
        ValueError: If the column count of ``test_file`` is not 26, or the
            number of rows in ``rul_file`` does not match the number of
            distinct engine ids (raised by pandas on assignment).
    """
    # Raw string: "\s" is not a valid Python string escape and newer Python
    # versions warn about it at compile time.
    test_df = pd.read_csv(test_file, sep=r"\s+", header=None)
    test_df.columns = ["id", "cycle", "setting1", "setting2", "setting3"] + [
        f"s{i}" for i in range(1, 22)
    ]
    test_rul = pd.read_table(rul_file, header=None)
    test_rul.columns = ["RUL-max-cycle"]

    # Key each given final RUL by (engine id, last recorded cycle) so that it
    # lands on the final row of the matching engine when merged.
    max_cycle = test_df.groupby(["id"])["cycle"].max()
    test_rul["cycle"] = max_cycle.values
    test_rul["id"] = np.arange(1, len(test_rul) + 1)
    test_df = test_df.merge(test_rul, how="outer", on=["cycle", "id"])

    # Only the final row of each engine has a non-NaN "RUL-max-cycle"; the
    # group-wise max broadcasts it to every row of that engine. The string
    # "max" selects pandas' NaN-skipping reduction; passing the Python builtin
    # is deprecated and would not skip NaN if called directly.
    per_engine = test_df.groupby(["id"])
    test_df["RUL"] = (
        per_engine["RUL-max-cycle"].transform("max")
        + per_engine["cycle"].transform("max")
        - test_df["cycle"]
    )
    test_df["RUL"] = test_df["RUL"].astype(int)
    del test_df["RUL-max-cycle"]
    return test_df


# Fetch the CMAPSS archive and expand it into the example data directory
data_dir = Path("example_data")
test_data = tb.util.download_tripleblind_resource(
    "CMAPSSData.zip",
    save_to_dir=data_dir,
    cache_dir="../../.cache",
    expand=True,
)

test_df = load_test_df(data_dir / "test_FD001.txt", data_dir / "RUL_FD001.txt")

# Split off the target column; pop() removes it from the frame in one step
test_y = test_df.pop("RUL").astype(np.float32)

# The two leading columns are dropped by position; everything left is a feature
test_df = test_df.drop(columns=test_df.columns[:2])
test_x = test_df.astype(np.float32)

# Persist features and target as separate CSV files for the evaluation below
test_y.to_csv(data_dir / "FD001_nn_test_y.csv", header=["RUL"], index=False)
test_x.to_csv(data_dir / "FD001_nn_test_x.csv", index=False)


############################################################################
## Test the previously trained model
############################################################################

# Load the local copy of the trained model object and put it in evaluation mode
pack = tb.Package.load(trained_model)
model = pack.model()
model.eval()

# Use the local test dataset for "batch" testing
data_dir = Path("example_data")
features = pd.read_csv(data_dir / "FD001_nn_test_x.csv").to_numpy()
targets = pd.read_csv(data_dir / "FD001_nn_test_y.csv").to_numpy()

# Standardize the features; the scaler is fit on this test data itself
scaler = StandardScaler()
features = scaler.fit_transform(features)

X = torch.from_numpy(features.astype(np.float32))
y = torch.from_numpy(targets.astype(np.float32))

test_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(X, y), batch_size=32
)

# Run inference batch by batch with gradient tracking disabled, keeping each
# batch's predictions paired with its true values
with torch.no_grad():
    batch_outputs = [
        (model(X_batch).numpy(), y_batch.numpy()) for X_batch, y_batch in test_loader
    ]
y_pred_list = np.concatenate([pred for pred, _ in batch_outputs])
y_true_list = np.concatenate([true for _, true in batch_outputs])

# Save the raw predictions and report the coefficient of determination
pd.DataFrame(y_pred_list).to_csv("cmapss_nn_predictions.csv", header=None)
r2_metric = r2_score(y_true_list, y_pred_list)
print(f"R2 score: {r2_metric}")
