#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

from pathlib import Path

import tripleblind as tb


tb.initialize(api_token=tb.config.example_user1["token"], example=True)

# Access tokens for the two example participants.
people = dict(
    alice=tb.config.example_user1["token"],
    bob=tb.config.example_user2["token"],
)

# Directory (expected to be created by 0_setup.py) where the training results
# and the trained model's output asset uuid are saved.
data_dir = Path("example_data")
if not data_dir.exists():
    # Raise -- not merely construct -- SystemExit so the script actually stops
    # here with this message instead of continuing without its prerequisites.
    raise SystemExit("Run 0_setup.py before this script.")


# Look up each party's portion of the vertically partitioned training data,
# restricting each lookup (via `owned_by`) to the team id of its owner.
alice_team_id = tb.config.example_user1["team_id"]
bob_team_id = tb.config.example_user2["team_id"]

alice_asset = tb.TableAsset.find(
    "EXAMPLE - alice's portion of a PSI Vertically Partitioned dataset for training.",
    owned_by=alice_team_id,
)
bob_asset = tb.TableAsset.find(
    "EXAMPLE - bob's portion of a PSI Vertically Partitioned dataset for training.",
    owned_by=bob_team_id,
)

# Tabular preprocessor built with `all_columns()`; no column subset is chosen here.
csv_pre = tb.TabularPreprocessor.builder().all_columns()

# Training hyperparameters.
num_clusters = 8  # Target number of clusters
num_iter = 10  # Number of iterations to run the training

# PSI record matching: presumably ID0 is alice's key column and ID1 is bob's
# (one entry per dataset, in dataset order) -- confirm against the asset schemas.
psi_params = {"match_column": ["ID0", "ID1"]}
kmeans_params = {"num_clusters": num_clusters, "num_iter": num_iter}

# Assemble the vertical KMeans training job over both parties' assets.
job = tb.create_job(
    job_name="PSI Vertical KMeans Training",
    operation=tb.Operation.PSI_VERTICAL_KMEANS_TRAIN,
    dataset=[alice_asset, bob_asset],
    preprocessor=csv_pre,
    params={"psi": psi_params, "kmeans": kmeans_params},
)

# Submit the training job and block until it completes.
job.submit()
job.wait_for_completion()
if not job.success:
    # Show the job details, then exit non-zero: later scripts depend on the
    # files written below, so a failed run must not look like a success.
    print(job)
    raise SystemExit("PSI Vertical KMeans training job failed.")

result = job.result

# Save in the same directory as the training data (`data_dir`, defined above).
# mkdir is a no-op if the directory already exists.
data_dir.mkdir(exist_ok=True)

# Retrieve the output asset and download it into a zip file.
filename = result.asset.retrieve(
    save_as=data_dir / "training_results.zip",
    overwrite=True,
)

# Save the output asset uuid to a file for reference in later scripts.
tb.util.save_to(data_dir / "train_output_asset_uuid.out", result.asset.uuid)

# Load the downloaded package and read the training outputs from it.
package = tb.Package.load(filename)
training_result = package.model()

# Trained model values. zmuv_* are presumably zero-mean/unit-variance
# normalization parameters -- confirm against the SDK documentation.
model_out = training_result["model"]
cluster_means = model_out["cluster_means"]
is_cluster_empty = model_out["is_cluster_empty"]
zmuv_mean = model_out["zmuv_mean"]
zmuv_linear = model_out["zmuv_linear"]

# Outputs computed on the training data.
data_out = training_result["data"]
labels = data_out["labels"]
inertia = data_out["inertia"]
