#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import os.path
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import r2_score
from sklearn.preprocessing import StandardScaler

import tripleblind as tb


# Suppress the PyTorch "SourceChangeWarning".
# NOTE: no category or message pattern is passed, so this filter silences
# every Python warning raised after this point, not only that one.
warnings.filterwarnings("ignore")


##########################################################################
# Build test data

# Download the full CMAPSS archive so the test datasets can be built from it.
# The raw test_FD00x.txt / RUL_FD00x.txt files are read from data_dir below.
# NOTE(review): "example_data" and "../../.cache" are relative paths and this
# runs before tb.util.set_script_dir_current() is called later in the file,
# so they resolve against the launch directory -- presumably the script is
# meant to be started from its own folder; confirm.
data_dir = Path("example_data")
test_data = tb.util.download_tripleblind_resource(
    "CMAPSSData.zip", save_to_dir=data_dir, cache_dir="../../.cache", expand=True
)


def load_test_df(test_file, rul_file):
    """Load one CMAPSS test split and attach a per-row RUL column.

    Args:
        test_file: Path or file-like object with whitespace-separated rows of
            26 values: engine id, cycle, three settings and sensors s1..s21.
        rul_file: Path or file-like object with one remaining-useful-life
            value per line, one line per engine in ascending id order. Each
            value is the RUL at that engine's last recorded cycle.

    Returns:
        DataFrame with the 26 input columns, in the same row order as
        ``test_file``, plus an integer ``RUL`` column.

    Raises:
        ValueError: If the test file does not have 26 columns or the number
            of RUL rows differs from the number of engines.
    """
    # Raw string: "\s" is not a valid escape in a normal string literal.
    test_df = pd.read_csv(test_file, sep=r"\s+", header=None)
    test_df.columns = ["id", "cycle", "setting1", "setting2", "setting3"] + [
        f"s{i}" for i in range(1, 22)
    ]
    test_rul = pd.read_table(rul_file, header=None)
    test_rul.columns = ["RUL-max-cycle"]

    # Key each given RUL value by (engine id, last cycle) so it lands on the
    # final row of that engine. The ids come from the data itself rather than
    # assuming engines are numbered 1..N.
    max_cycle = test_df.groupby("id")["cycle"].max()
    test_rul["cycle"] = max_cycle.to_numpy()
    test_rul["id"] = max_cycle.index.to_numpy()

    # A left merge keeps test_df's row order (an outer merge may sort by the
    # join keys, which would interleave engines). Every RUL row matches a
    # final-cycle row by construction, so no rows are lost.
    test_df = test_df.merge(test_rul, how="left", on=["cycle", "id"])

    # RUL at any cycle = RUL at the last cycle + (last cycle - this cycle).
    # Only the last row per engine holds a non-NaN "RUL-max-cycle", so the
    # group max broadcasts it to every row of that engine.
    by_engine = test_df.groupby("id")
    test_df["RUL"] = (
        by_engine["RUL-max-cycle"].transform("max")
        + by_engine["cycle"].transform("max")
        - test_df["cycle"]
    ).astype(int)
    return test_df.drop(columns="RUL-max-cycle")


def create_test_csv_from_dataframe(test_df, output_filename):
    """Write the features and labels of ``test_df`` to two CSV files.

    The ``RUL`` column is written to ``<output_filename>_y.csv``. All other
    columns except the first two (by position) are written to
    ``<output_filename>_x.csv``. Both are cast to float32 first, and neither
    file contains an index column.

    The input frame is left unmodified.

    Args:
        test_df: DataFrame containing a ``RUL`` column; its first two
            remaining columns are dropped from the feature file.
        output_filename: Path prefix (str or path-like) for the output files.
    """
    labels = test_df["RUL"].astype(np.float32)
    # Work on derived frames instead of deleting columns from the caller's
    # DataFrame; iloc[:, 2:] skips the two leading columns positionally.
    features = test_df.drop(columns="RUL").iloc[:, 2:].astype(np.float32)

    labels.to_csv(f"{output_filename}_y.csv", header=["RUL"], index=False)
    features.to_csv(f"{output_filename}_x.csv", index=False)


# Convert each of the four CMAPSS subsets (FD001 .. FD004) into a pair of
# feature / label CSV files stored alongside the raw data.
for i in range(1, 5):
    base = f"FD00{i}"
    raw_test_path = data_dir / f"test_{base}.txt"
    raw_rul_path = data_dir / f"RUL_{base}.txt"
    test_df = load_test_df(raw_test_path, raw_rul_path)
    create_test_csv_from_dataframe(test_df, data_dir / f"{base}_test")


##########################################################################

# NOTE(review): presumably this makes the script's own folder the working
# directory, so the relative data_dir below is resolved from there -- confirm
# against the SDK.
tb.util.set_script_dir_current()
data_dir = Path("example_data")


def windows(nrows, size, step=2):
    """Yield ``(start, end)`` index pairs for sliding windows over ``nrows`` rows.

    Windows start at 0 and advance by ``step``; ``end`` is ``start + size``.
    The final windows may extend past ``nrows`` -- callers that need only
    complete windows must filter them.

    Args:
        nrows: Total number of rows being windowed.
        size: Number of rows in each window.
        step: Distance between consecutive window starts (default 2).

    Yields:
        Tuples ``(start, start + size)`` for every start below ``nrows``.

    Raises:
        ValueError: If ``step`` is not positive (raised on first iteration).
    """
    if step <= 0:
        raise ValueError("step must be a positive integer")
    for start in range(0, nrows, step):
        yield start, start + size


def segment_signal(features, labels, window_size=30):
    """Cut ``features`` into fixed-length windows with one label per window.

    Only windows that contain exactly ``window_size`` rows are kept. Each
    kept window gets a leading channel axis, and its label is the entry of
    ``labels`` at the window's last row.

    Args:
        features: Row-indexed array of inputs.
        labels: Array indexed by the same rows as ``features``.
        window_size: Number of rows per window.

    Returns:
        Tuple ``(segments, segment_labels)`` of stacked arrays; ``segments``
        is laid out as (batch, channel, rows, cols) for 2-D ``features``.

    Raises:
        ValueError: From ``np.stack`` when no complete window exists.
    """
    # Keep only the spans that are fully inside the data.
    full_spans = [
        (lo, hi)
        for lo, hi in windows(len(features), window_size)
        if len(features[lo:hi]) == window_size
    ]
    # np.newaxis adds the channel axis in front: (channel, rows, cols)
    segments = np.stack([features[np.newaxis, lo:hi] for lo, hi in full_spans])
    segment_labels = np.stack([labels[hi - 1] for _, hi in full_spans])
    return segments, segment_labels


def reformat_data(data_x, data_y):
    """Turn feature/label DataFrames into windowed float32 tensors.

    The features are standardized with a ``StandardScaler`` that is fit on
    the very data passed in, then both arrays are cut into windows by
    ``segment_signal`` using its default window size.

    Args:
        data_x: pandas object holding the feature rows.
        data_y: pandas object holding the label rows.

    Returns:
        Tuple ``(data_x, data_y)`` of ``torch`` float32 tensors.
    """
    raw_features = data_x.values
    # Insert an axis at position 1 so each row's label keeps its own dimension.
    raw_labels = data_y.values[:, np.newaxis]

    scaled_features = StandardScaler().fit_transform(raw_features)
    windowed_x, windowed_y = segment_signal(scaled_features, raw_labels)

    tensor_x = torch.from_numpy(windowed_x.astype(np.float32))
    tensor_y = torch.from_numpy(windowed_y.astype(np.float32))
    return (tensor_x, tensor_y)


# Look for the local model file name saved by a previous run of
# 2_model_train.py (the value is checked as a filesystem path below).
try:
    trained_model = tb.util.load_from("local_model_filename.out")
except Exception as err:
    # Catch Exception rather than using a bare except so Ctrl-C and
    # interpreter exit are not swallowed, and chain the original failure so
    # the real cause stays visible in the traceback.
    raise SystemError("You must run 2_model_train.py first") from err
if not os.path.exists(trained_model):
    raise SystemError("ERROR: Unable to find the specified model.")

############################################################################
# Load the locally stored trained model object
#
pack = tb.Package.load(trained_model)
model = pack.model()
# Switch the model to evaluation mode before running inference.
model.eval()

# Use the local test datasets for "batch" testing
#


def _score_test_set(name):
    """Return the R2 score of the loaded model on one local test dataset.

    Reads ``<name>_x.csv`` and ``<name>_y.csv`` from ``data_dir``, windows
    them with ``reformat_data`` and runs the module-level ``model`` over the
    result in batches of 128 with gradients disabled.

    Args:
        name: Dataset file prefix, e.g. ``"FD001_test"``.

    Returns:
        The value of ``sklearn.metrics.r2_score`` for the predictions.
    """
    data_x = pd.read_csv(data_dir / f"{name}_x.csv")
    data_y = pd.read_csv(data_dir / f"{name}_y.csv")

    X, y = reformat_data(data_x, data_y)

    ds = torch.utils.data.TensorDataset(X, y)
    test_loader = torch.utils.data.DataLoader(ds, batch_size=128)

    y_pred_list = []
    y_true_list = []
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            y_test_pred = model(X_batch)
            # One prediction per sample, with singleton axes squeezed away.
            y_pred_list.extend(
                row.squeeze().tolist() for row in y_test_pred.numpy()
            )
            # .item() requires each label to hold exactly one value.
            y_true_list.extend(label.item() for label in y_batch)

    return r2_score(y_true_list, y_pred_list)


# Score the same model on each of the four CMAPSS test subsets.
for subset in ("FD001", "FD002", "FD003", "FD004"):
    r2_metric = _score_test_set(f"{subset}_test")
    print(f"R2 score({subset}_test): {r2_metric}")
