#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import zipfile as zipfile
from pathlib import Path

import tripleblind as tb


###############################################################################
# Retrieve local copies of data used throughout this example.  These have been
# cached as private assets that we retrieve and unpack here.  Files are placed
# in an "example_data" folder for use during local inferencing and plotting.
###############################################################################


# Open the default session as the first example user ("alice"). The
# tb.Asset.find lookups below pass no explicit session argument.
tb.initialize(api_token=tb.config.example_user1["token"], example=True)

# Downloaded archives and their extracted contents are placed in this folder.
data_dir = Path("example_data")
data_dir.mkdir(exist_ok=True)

# Per-user configuration entries; only their "team_id" is needed here.
alice_cfg = tb.config.example_user1
bob_cfg = tb.config.example_user2

# Look up the four dataset portions by name: a training slice and an
# inference slice for each of the two example teams.
alice_asset_train = tb.Asset.find(
    "EXAMPLE - alice's portion of a PSI Vertically Partitioned dataset for training.",
    owned_by=alice_cfg["team_id"],
)
bob_asset_train = tb.Asset.find(
    "EXAMPLE - bob's portion of a PSI Vertically Partitioned dataset for training.",
    owned_by=bob_cfg["team_id"],
)
alice_asset_inf = tb.Asset.find(
    "EXAMPLE - alice's portion of a PSI Vertically Partitioned dataset for inference.",
    owned_by=alice_cfg["team_id"],
)
bob_asset_inf = tb.Asset.find(
    "EXAMPLE - bob's portion of a PSI Vertically Partitioned dataset for inference.",
    owned_by=bob_cfg["team_id"],
)

print("Retrieving assets for local runs...")

# Each entry: (asset handle, archive file name, config of the user whose token
# is used for the download). The order is the download order: training
# portions first, then inference portions.
_downloads = (
    (alice_asset_train, "alice_training.zip", tb.config.example_user1),
    (bob_asset_train, "bob_training.zip", tb.config.example_user2),
    (alice_asset_inf, "alice_inf_data.zip", tb.config.example_user1),
    (bob_asset_inf, "bob_inf_data.zip", tb.config.example_user2),
)
for _asset, _archive_name, _owner in _downloads:
    _asset.retrieve(
        save_as=f"{data_dir}/{_archive_name}",
        overwrite=True,
        # NOTE(review): tb.initialize is invoked once per download, so each
        # owner re-initializes twice; confirm whether a reusable per-user
        # session object would be more appropriate here.
        session=tb.initialize(api_token=_owner["token"]),
    )

# Paths of the downloaded archives, in the order they are extracted
# (inference archives before training archives).
files = [
    f"{data_dir}/{_zip_name}"
    for _zip_name in (
        "alice_inf_data.zip",
        "bob_inf_data.zip",
        "alice_training.zip",
        "bob_training.zip",
    )
]

# Unpack every archive directly into data_dir.
for _zip_path in files:
    with zipfile.ZipFile(_zip_path, mode="r") as _archive:
        _archive.extractall(data_dir)
