#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

from pathlib import Path

import pandas as pd
from sklearn.model_selection import train_test_split

import tripleblind as tb


# NOTE(review): presumably makes this script's own folder the working
# directory so the relative paths below resolve -- confirm against tb.util.
tb.util.set_script_dir_current()

# Folder that receives the downloaded inputs and the generated CSV files.
data_dir = Path("example_data")
data_dir.mkdir(exist_ok=True)


###########################################################################
# Fetch the dataset files cached on TripleBlind's storage bucket.
#
# These are a cached version of the Gene Expression Dataset found at:
#  https://www.kaggle.com/varimp/gene-expression-classification/notebook
###########################################################################
for resource_name in (
    "actual.csv",
    "data_set_ALL_AML_train.csv",
    "data_set_ALL_AML_independent.csv",
):
    tb.util.download_tripleblind_resource(
        resource_name,
        save_to_dir=data_dir,
        cache_dir="../../.cache",
    )

# Class labels: encode "ALL" as 0 and "AML" as 1.
# Only the "cancer" column is mapped.  A DataFrame-wide replace() would touch
# every column and depends on pandas silently downcasting the resulting object
# column to integers, which is deprecated (FutureWarning since pandas 2.2).
y = pd.read_csv(data_dir / "actual.csv")
y["cancer"] = y["cancer"].map({"ALL": 0, "AML": 1})
if y["cancer"].isna().any():
    # map() turns anything outside the mapping into NaN; stop here instead of
    # letting a broken label leak into the generated training targets.
    raise ValueError("actual.csv contains labels other than 'ALL' and 'AML'")

# Raw expression tables as downloaded.
df_train = pd.read_csv(data_dir / "data_set_ALL_AML_train.csv")
df_test = pd.read_csv(data_dir / "data_set_ALL_AML_independent.csv")

# Discard every column whose header contains "call"; all other columns
# (gene identifiers and per-sample values) are kept.
train_to_keep = [col for col in df_train.columns if "call" not in col]
test_to_keep = [col for col in df_test.columns if "call" not in col]

X_train_tr = df_train[train_to_keep]
X_test_tr = df_test[test_to_keep]

# Column order for the training table: the two gene-identifier columns first,
# then the sample columns "1" through "38" in ascending numeric order.
# NOTE(review): presumably this ordering is what lets the rows be paired with
# the label rows by position later on -- confirm against actual.csv.
train_columns_titles = ["Gene Description", "Gene Accession Number"]
train_columns_titles += [str(sample_id) for sample_id in range(1, 39)]

X_train_tr = X_train_tr.reindex(columns=train_columns_titles)

# Column order for the test table: the two gene-identifier columns first,
# then the sample columns "39" through "72" in ascending numeric order.
test_columns_titles = ["Gene Description", "Gene Accession Number"]
test_columns_titles += [str(sample_id) for sample_id in range(39, 73)]

X_test_tr = X_test_tr.reindex(columns=test_columns_titles)

def _genes_as_features(expression_table):
    """Return the table flipped so each row is a sample and each column a gene.

    After transposing, the second row (taken by position) supplies the new
    column names; the two gene-identifier rows are then removed and every
    remaining column is converted with ``pd.to_numeric``.
    """
    flipped = expression_table.T
    flipped.columns = flipped.iloc[1]
    samples_only = flipped.drop(["Gene Description", "Gene Accession Number"])
    return samples_only.apply(pd.to_numeric)


# Transpose both tables so that gene expressions represent features.
X_train = _genes_as_features(X_train_tr)
X_test = _genes_as_features(X_test_tr)

# The first 38 label rows go with the training table, the remainder with the
# test table (split by position).
y_train = y.iloc[:38]
y_test = y.iloc[38:]

# Split the training samples 70/30 into two partitions.
# NOTE(review): no random_state is passed, so the partitions differ from run
# to run -- pin a seed here if reproducible example data is wanted.
x1, x2, y1, y2 = train_test_split(X_train, y_train.cancer.values, test_size=0.3)

# Attach the labels as a trailing "target" column.  assign() builds new frames
# instead of inserting into the frames returned by the split, so there is no
# chained-assignment warning to silence and no global pandas option to change.
x1 = x1.assign(target=y1)
x2 = x2.assign(target=y2)

x1.to_csv(data_dir / "train0.csv", index=False)
x2.to_csv(data_dir / "train1.csv", index=False)

# NOTE(review): the test table is written without a "target" column (the
# held-out labels are not attached here) -- confirm that is intended.
X_test.to_csv(data_dir / "test.csv", index=False)
