#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report

import tripleblind as tb


tb.initialize(api_token=tb.config.example_user1["token"], example=True)

# Locate the three test datasets in the Router's index. Each one is looked up
# by its name suffix (a row range) and restricted to the team that owns it.
prefix = "EXAMPLE - "
asset0, asset1, asset2 = (
    tb.TableAsset.find(
        f"{prefix}Classification PSI Regression vertical test {row_range}",
        owned_by=owner["team_id"],
    )
    for row_range, owner in (
        ("0-40", tb.config.example_user1),
        ("41-100", tb.config.example_user2),
        ("101-120", tb.config.example_user3),
    )
)
# Any lookup that came back empty aborts the example.
if not all((asset0, asset1, asset2)):
    raise SystemError("Datasets not found.")

# Load the trained model whose asset ID was saved to disk by the training step.
asset_id = tb.util.load_from("psi_vert_clf_reg_model_asset_id.out")
model = tb.ModelAsset(asset_id)
if not model.is_valid:
    raise SystemError("No model found. You must run 2b_train_clf.py")

# For each partner dataset (user2's and user3's), open a session as that
# dataset's owner and add an agreement with user1's team for operation
# `asset_id` (this model).
for partner_user, partner_asset in (
    (tb.config.example_user2, asset1),
    (tb.config.example_user3, asset2),
):
    partner_session = tb.Session(
        api_token=partner_user["token"], from_default=True
    )
    partner_asset.add_agreement(
        with_team=tb.config.example_user1["team_id"],
        operation=asset_id,
        session=partner_session,
    )


# Preprocessor built with all_columns() (presumably every column of each
# party's table is used as model input).
preproc = tb.TabularPreprocessor.builder().all_columns()

# Distributed logistic-regression inference over the three datasets, with
# rows matched across parties on the "ID" column.
result = model.psi_infer(
    data=[asset0, asset1, asset2],
    match_column="ID",
    regression_type=tb.RegressionType.LOGISTIC,
    preprocessor=preproc,
    job_name="PSI Vertical Regression Distributed Inference",
)

if result:
    # Save the raw result package locally; the returned path is not needed.
    result.asset.retrieve(save_as="fed_inf_result.zip", overwrite=True)
    result_df = result.table.dataframe
    # Empty strings cannot be cast to float, so treat them as 0.
    result_df.replace("", 0, inplace=True)
    result_vals = result_df.values.astype(float)
    print("\nInference results:")
    print("    ", result_vals.flatten())

    # Replace NaN values with 0 for purposes of later reporting; this is a
    # known issue. The result must be assigned back: ndarray.flatten() always
    # returns a copy, so an in-place nan_to_num on it would leave the NaNs
    # in result_vals untouched.
    result_vals = np.nan_to_num(result_vals, nan=0.0)
    predictions = result_vals.round().flatten()

    test_asset = tb.TableAsset.find(
        "EXAMPLE - Classification test data - psi vertical regression",
        owned_by=tb.config.example_user1["team_id"],
    )
    if not test_asset:
        raise SystemError("Test dataset not found.")
    test_asset.retrieve(save_as="test_clf_asset.zip", overwrite=True)
    # Load and sort the test table once; it supplies both the ground truth and
    # the features for the scikit-learn baseline below.
    test_df = tb.Package.load("test_clf_asset.zip").records()
    # Sort by the string form of ID so labels line up with the inference
    # output. NOTE(review): assumes psi_infer returns rows in this same
    # string-sorted ID order; confirm.
    test_df["ID"] = test_df["ID"].astype(str)
    test_df = test_df.sort_values("ID")
    truth = test_df["y"].to_list()
    print("TripleBlind model:")
    print(classification_report(truth, predictions))

    # Baseline for comparison: a plain scikit-learn logistic regression
    # trained locally on the training table.
    train_asset = tb.TableAsset.find(
        "EXAMPLE - Classification train data - psi vertical regression",
        owned_by=tb.config.example_user1["team_id"],
    )
    if not train_asset:
        raise SystemError("Train dataset not found.")
    train_asset.retrieve(save_as="train_clf_asset.zip", overwrite=True)
    train_df = tb.Package.load("train_clf_asset.zip").records()

    # Features are every column except the label "y" and the join key "ID".
    y = train_df["y"].values
    X = train_df.drop(columns=["y", "ID"]).values
    sk_model = LogisticRegression(random_state=0)
    sk_model.fit(X, y)

    sk_pred = sk_model.predict(test_df.drop(columns=["y", "ID"]).values)
    print("SKLearn model:")
    print(classification_report(truth, sk_pred))

else:
    print("Inference failed")
