#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

"""
This example script demonstrates SMPC inference using the trained model with a
vertically partitioned inference dataset.
"""

from pathlib import Path

import tripleblind as tb


# Start the session with the first example user's token.
tb.initialize(api_token=tb.config.example_user1["token"], example=True)


# API tokens of the two example users, keyed by the names used in this script.
people = {
    "alice": tb.config.example_user1["token"],
    "bob": tb.config.example_user2["token"],
}

# Working directory for this example: the trained model's asset id is read
# from it and the inference output is saved into it. It is expected to exist
# already once the training step has been run.
data_dir = Path("example_data")
if not data_dir.exists():
    raise SystemExit(
        "You must run 1_train.py before this script to train the model associated with this example"
    )

# Find trained model asset.
trained_model_asset_uuid = tb.util.load_from(data_dir / "train_output_asset_uuid.out")
model = tb.ModelAsset(trained_model_asset_uuid)
if not model.is_valid:
    # Abort instead of just printing: every later step (agreement, inference)
    # needs a usable model and would otherwise fail with a less clear error.
    raise SystemExit("No model found. You must run 1_train.py.")

# Team that owns the model; the data agreement below is made with this team.
model_team_id = model.team_id

# Find dataset assets. Each party's table is looked up by its exact name and
# restricted to the team of the corresponding example user.
alice_asset = tb.TableAsset.find(
    "EXAMPLE - alice's portion of a PSI Vertically Partitioned dataset for inference.",
    owned_by=tb.config.example_user1["team_id"],
)
bob_asset = tb.TableAsset.find(
    "EXAMPLE - bob's portion of a PSI Vertically Partitioned dataset for inference.",
    owned_by=tb.config.example_user2["team_id"],
)
# NOTE(review): assumes find() returns None when nothing matches -- confirm
# against the SDK reference. Failing here gives a clearer message than the
# AttributeError that would otherwise occur when the asset is first used.
if alice_asset is None or bob_asset is None:
    raise SystemExit("Inference datasets not found.")

# Add agreement to use the trained model with the data.
# The agreement on bob's table is created while the session is initialized
# with bob's token (presumably because only the asset owner may grant access
# -- confirm against the SDK's permission rules). It names the model owner's
# team and the trained model as the operation.
tb.initialize(api_token=people["bob"])
bob_asset.add_agreement(
    with_team=model_team_id,
    operation=model,
)
# Switch the session back to alice's token for the remaining steps.
tb.initialize(api_token=people["alice"])

# Preprocessor built to use all columns of the input tables.
pre = tb.TabularPreprocessor.builder().all_columns()


# Run the trained model against both parties' tables; the listed columns are
# passed as the match columns, in the same order as the datasets.
inference_inputs = [alice_asset, bob_asset]
id_columns = ["ID0", "ID1"]
result = model.psi_infer(
    data=inference_inputs,
    match_column=id_columns,
    preprocessor=pre,
)

# A missing result means the inference did not complete; stop here.
if result is None:
    raise SystemExit("Inference failed")

# Download the inference output asset as a zip file inside the data directory,
# replacing any copy left over from a previous run.
zip_path = result.asset.retrieve(
    save_as=f"{data_dir}/inference_output.zip",
    overwrite=True,
)

# Load the downloaded package and pull out the object holding the results
# (presumably a dict-like structure with a "data" entry -- confirm against
# the SDK's Package documentation).
package = tb.Package.load(zip_path)
inference_output = package.model()

# Read both values before printing anything.
labels = inference_output["data"]["labels"]
inertia = inference_output["data"]["inertia"]

# Print the results.
print("\nCluster labels for inference data:")
print(labels)
print("\nInertia for inference data:")
print(inertia)
