#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import os

import tripleblind as tb


tb.util.set_script_dir_current()
tb.initialize(api_token=tb.config.example_user3["token"])

run_id = tb.util.read_run_id()

# Retrieve the model Asset ID from a previous run of 1_position_model.py.
# When TB_TEST_SMALL is set, the asset name is suffixed with the run id read
# above; otherwise the fixed EXAMPLE asset name is used.
if os.environ.get("TB_TEST_SMALL"):
    name = f"TEST - PMML Random Forest model - {run_id}"
else:
    name = "EXAMPLE - PMML Random Forest model"
try:
    # The asset is looked up among those owned by example_user1's team.
    asset = tb.Asset.find(name, owned_by=tb.config.example_user1["team_id"])
    model = tb.ModelAsset.cast(asset)
except Exception as err:
    # Catch only ordinary errors: a bare `except:` would also swallow
    # KeyboardInterrupt/SystemExit and misreport them as a missing model.
    raise SystemExit("You must run 1_position_model.py first.") from err


# Run the same inference once per security mode ("fed" and "smpc"). Each
# successful result is printed, saved to results-<mode>.csv, and kept for the
# comparison below. A mode whose inference fails leaves its frame as None.
df_fed, df_smpc = None, None
for security_mode in ("fed", "smpc"):
    print(f"Running {security_mode} inference...")
    result = model.infer(
        data="unknown_flowers.csv",
        preprocessor=tb.TabularPreprocessor.builder()
        .add_column("Sepal.Length")
        .add_column("Sepal.Width")
        .add_column("Petal.Length")
        .add_column("Petal.Width"),
        params={"security": security_mode},
        job_name=f"PMML Tree Example - {security_mode}",
    )

    if not result:
        print(f"Inference failed ({security_mode})")
        continue

    df = result.table.dataframe
    # NOTE(review): assumes the result table has exactly one (prediction)
    # column; assigning a single name to a wider frame would raise.
    df.columns = ["Species"]
    print()
    print(df)
    df.to_csv(f"results-{security_mode}.csv", index=False)
    if security_mode == "smpc":
        df_smpc = df
    else:
        df_fed = df

if df_smpc is None or df_fed is None:
    print("Cannot compare SMPC and Federated results: an inference failed")
else:
    # `all(df_a == df_b)` would iterate the column *labels* of the boolean
    # frame (always truthy here), not its values. DataFrame.equals checks
    # shape, element values and dtypes.
    print(f"Comparing SMPC and Federated results: {df_smpc.equals(df_fed)}")
