#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score

import tripleblind as tb


# Initialize the default session as example user 1.
tb.initialize(api_token=tb.config.example_user1["token"], example=True)

# Locate the three test tables in the Router's index; each one is owned by a
# different example team.
prefix = "EXAMPLE - "


def _find_test_table(row_range, owner):
    """Look up one test table by its name suffix and owning example user.

    Returns whatever ``tb.TableAsset.find`` returns; the check below treats a
    falsy value as "not found".
    """
    return tb.TableAsset.find(
        f"{prefix}Linear PSI Regression vertical test {row_range}",
        owned_by=owner["team_id"],
    )


asset0 = _find_test_table("0-40", tb.config.example_user1)
asset1 = _find_test_table("41-100", tb.config.example_user2)
asset2 = _find_test_table("101-120", tb.config.example_user3)

if not all([asset0, asset1, asset2]):
    raise SystemError("Datasets not found.")

# Reload the trained model's asset ID from disk (presumably written by
# 1a_train_linear.py, per the error message below) and wrap it in a handle.
model_id_file = "psi_vert_linear_reg_model_asset_id.out"
asset_id = tb.util.load_from(model_id_file)
model = tb.ModelAsset(asset_id)

if not model.is_valid:
    # Without a valid model there is nothing to run inference with.
    raise SystemError("No model found. You must run 1a_train_linear.py")

def _grant_inference_agreement(owner, shared_asset):
    """Open a session as ``owner`` and add an agreement on ``shared_asset``.

    The agreement is made with example team 1 for operation ``asset_id`` (the
    model loaded above). Returns the owner's session.
    """
    owner_session = tb.Session(api_token=owner["token"], from_default=True)
    shared_asset.add_agreement(
        with_team=tb.config.example_user1["team_id"],
        operation=asset_id,
        session=owner_session,
    )
    return owner_session


# Teams 2 and 3 own asset1 and asset2 respectively; each must agree separately.
session_team_2 = _grant_inference_agreement(tb.config.example_user2, asset1)
session_team_3 = _grant_inference_agreement(tb.config.example_user3, asset2)

# Preprocessor built via all_columns(); presumably every column of each table
# is fed to the model — confirm against the TripleBlind SDK docs.
preproc = tb.TabularPreprocessor.builder().all_columns()

# Run PSI inference over all three tables, matching rows on the "ID" column.
inference_job = dict(
    data=[asset0, asset1, asset2],
    match_column="ID",
    regression_type=tb.RegressionType.LINEAR,
    preprocessor=preproc,
    job_name="PSI Regression Distributed Inference",
)
result = model.psi_infer(**inference_job)

# If inference succeeded, score its predictions against the ground-truth test
# labels and against a locally trained scikit-learn LinearRegression baseline.
if result:
    # Download the raw result package locally.
    result.asset.retrieve(save_as="fed_inf_result.zip", overwrite=True)
    result_df = result.table.dataframe
    # replace empty strings with zeros in result_df and convert to float
    result_df.replace("", 0, inplace=True)
    result_vals = result_df.values.astype(float)
    print("\nInference results:")
    print("    ", result_vals.flatten())

    # Replace NaN values with 0 for purposes of later reporting; this is a known
    # issue. flatten() always returns a copy, so converting a flattened array in
    # place would leave result_vals untouched; assign the cleaned array back
    # instead (sklearn's metrics reject NaN input).
    result_vals = np.nan_to_num(result_vals, nan=0.0)
    predictions = result_vals.flatten()

    test_asset = tb.TableAsset.find(
        "EXAMPLE - Linear test data - psi vertical regression",
        owned_by=tb.config.example_user1["team_id"],
    )
    if not test_asset:
        raise SystemError("Test dataset not found.")
    test_asset.retrieve(save_as="test_reg_asset.zip", overwrite=True)
    test_df = tb.Package.load("test_reg_asset.zip").records()

    # Sort by the string form of ID before comparing positionally with the
    # predictions. NOTE(review): assumes the inference rows come back in this
    # same string-ID order — confirm.
    test_df["ID"] = test_df["ID"].astype(str)
    test_df = test_df.sort_values("ID")
    truth = test_df["y"].to_list()
    print(f"Truth: {truth}")

    mse = mean_squared_error(truth, predictions)
    print("TripleBlind model:")
    print(f"\tmse:{mse}")
    print(f"\tr2:{r2_score(truth, predictions)}")

    # Plaintext baseline: fit scikit-learn on the full training table.
    train_asset = tb.TableAsset.find(
        "EXAMPLE - Linear train data - psi vertical regression",
        owned_by=tb.config.example_user1["team_id"],
    )
    if not train_asset:
        raise SystemError("Training dataset not found.")
    train_asset.retrieve(save_as="train_reg_asset.zip", overwrite=True)
    train_df = tb.Package.load("train_reg_asset.zip").records()

    y = train_df["y"].values
    del train_df["y"]
    X = train_df.values
    sk_model = LinearRegression()
    sk_model.fit(X, y)

    # NOTE(review): test_df still holds the (now string-typed) "ID" column here;
    # this assumes train_df's remaining columns line up with test_df's — confirm.
    del test_df["y"]
    sk_pred = sk_model.predict(test_df.values)
    print("SKLearn model:")
    print(f"\tmse:{mean_squared_error(truth, sk_pred)}")
    print(f"\tr2:{r2_score(truth, sk_pred)}")

else:
    print("Inference failed")
