#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.
from datetime import datetime
from urllib.parse import urljoin

import tripleblind as tb


DATA_OWNER = tb.config.example_user1  # party that started the streaming inference
CONSUMER = tb.config.example_user2  # this script's user; it only listens to the stream

# This script attaches to a streaming inference that another party (the
# DATA_OWNER) is running.  It assumes the DATA_OWNER started that inference
# with "allow_to_listen" set to the ID of the CONSUMER's team; without that,
# this user has no rights to see the output stream.
session = tb.initialize(api_token=CONSUMER["token"])

# Load the trained model whose asset ID was saved to "trained_model_asset_id.out"
# (presumably written by an earlier training step of this example).
model_asset = tb.ModelAsset(tb.util.load_from("trained_model_asset_id.out"))

# Find a job using the trained model belonging to this team, using the
# audit records.  Only the records returned by this single request
# (limit=50) are examined.
response = session.get_as_json(
    url=urljoin(session.endpoint, "/api/audit/assets?limit=50"),
    headers={"Authorization": f"Token {session.token}"},
)

# Compare against the model's ID as a string, computed once rather than per
# record.  Records are read with .get() so an audit entry that lacks one of
# these fields is skipped instead of aborting the search with a KeyError.
# Jobs in the "calc" state are what this script treats as "active".
model_uuid = str(model_asset.uuid)
job_id = next(
    (
        audit.get("job_id")
        for audit in response["results"]
        if audit.get("job_status") == "calc"
        and audit.get("algorithm_id") == model_uuid
    ),
    None,
)

if not job_id:
    print("No active process was found using your model:")
    print(f"    Asset ID: {model_asset.uuid}")
    print(f"        Name: {model_asset.name}")
    raise SystemExit(1)

# Connect to the job and display the inferences as they are generated.
output_stream = tb.job.RemoteStatusOutputStream(job_id=job_id)

# Baselines for the deltas printed with each result.  For the first result
# both deltas are measured from when this script reached this point, not
# from a previous inference.
last = datetime.now().timestamp()  # local time the previous result arrived
last_data = last  # data-gathered timestamp of the previous result
target_interval = 1  # expected spacing between inferences, in seconds

try:
    print(f"Connecting to streaming inference session {output_stream.job_id}")
    for status in output_stream.remote_status():
        # Only dict statuses carry inference results; anything else is ignored.
        if not isinstance(status, dict):
            continue

        # This protocol's status will be in the format:
        #    {
        #        "__type": "InferenceResults",
        #
        #        "data_gathered_timestamp": 1680441227.5509229,
        #        "result": [[5.0]],
        #    }
        #
        # Extract the interesting parts and display them.
        data_gathered_epoch_time = status["data_gathered_timestamp"]
        when = datetime.fromtimestamp(data_gathered_epoch_time).strftime("%x %X")

        inferred_value = status["result"][0][0]

        now = datetime.now().timestamp()
        elapsed = now - last  # seconds since last inference
        data_delta = data_gathered_epoch_time - last_data

        # Fixed-point formats ("f") are used deliberately: without a type the
        # precision counts significant digits, and deltas >= 10 s would be
        # shown in scientific notation (e.g. "1.2e+01").
        # Outputs:
        # Delta: 1.2 s (-0.2) 	at 05/06/2023 03:45:35 PM (+0.94)	Value: 1845.8944091796875
        print(
            f"Delta: {elapsed:<4.1f}s ({target_interval - elapsed:+.1f}) "
            f"\tat {when} ({data_delta:+.2f})\tValue: {inferred_value}"
        )
        last = now
        last_data = data_gathered_epoch_time

except BaseException as e:
    # BaseException (not Exception) so that KeyboardInterrupt/SystemExit are
    # also routed through the stream's own handler.
    output_stream.handle_exception(e)
