#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

from datetime import datetime

import tripleblind as tb


DATA_OWNER = tb.config.example_user1
CONSUMER = tb.config.example_user2

# Running as the data owner
session = tb.initialize(api_token=DATA_OWNER["token"])

# Use the most recent trained model, owned by the data consumer
model_asset = tb.ModelAsset(tb.util.load_from("trained_model_asset_id.out"))
# Explicit check instead of `assert`: assert statements are removed when Python
# runs with -O, which would let the script continue with an invalid model.
if not model_asset.is_valid:
    raise SystemExit(
        "Model asset ID loaded from trained_model_asset_id.out is not valid"
    )

# Find the live data stream to be used for generating predictions
mongodb_asset = tb.Asset.find(
    "Energy forecast live data", owned_by=DATA_OWNER["team_id"]
)
# NOTE(review): assumes tb.Asset.find returns None when nothing matches —
# confirm against the SDK. Fail early here rather than inside infer().
if mongodb_asset is None:
    raise SystemExit(
        "Could not find the 'Energy forecast live data' asset owned by the data owner"
    )

# Desired spacing between successive inferences, in seconds.
target_interval = 1

# Preprocessor applied to the live data before it reaches the model: built
# fluently from TabularPreprocessor.builder(), selecting column "P", calling
# expand_input_dims([0, 1]) and casting to float32.
pre = (
    tb.TabularPreprocessor.builder()
    .expand_input_dims([0, 1])
    .add_column("P")
    .dtype("float32")
)

# Start the job on the DATA_OWNER's access point, using the model that belongs
# to the CONSUMER team. Because "allow_to_listen" names the model's team, that
# team may also monitor the output stream.
job_name = "Example Continuous Inference"
output_stream = model_asset.infer(
    mongodb_asset,
    job_name=job_name,
    preprocessor=pre,
    params=dict(
        security="fed",
        data_type="numpy",
        final_layer_softmax=False,
        batch_size=1,  # only consulted for FED inference
        inference_interval=target_interval,  # re-run inference on this cadence
        allow_to_listen=model_asset.team_id,  # extra team permitted to watch results
    ),
    stream_output=True,
    silent=False,
)

# Timing bookkeeping for the status display:
#   last      - local wall-clock time the previous result was received here
#   last_data - "data_gathered_timestamp" of the previous result
# Both start at the script's start time, so the first line's deltas measure
# from launch.
last = datetime.now().timestamp()
last_data = last
try:
    print(f"Starting streaming inference session {output_stream.job.id}")
    for status in output_stream.status():
        if isinstance(status, dict):
            # This protocol's status will be in the format:
            #    {
            #        "__type": "InferenceResults",
            #
            #        "data_gathered_timestamp": 1680441227.5509229,
            #        "result": [[5.0]],
            #    }
            #
            # Extract the interesting parts and display them.

            data_gathered_epoch_time = status["data_gathered_timestamp"]
            # Local-time rendering using the locale's date/time representation.
            when = datetime.fromtimestamp(data_gathered_epoch_time).strftime("%x %X")

            inferred_value = status["result"][0][0]

            now = datetime.now().timestamp()
            elapsed = now - last  # seconds since last inference
            drift = target_interval - elapsed  # deviation from the target cadence
            data_delta = data_gathered_epoch_time - last_data  # gap between data gathers

            # Fixed-point ('f') specs: the bare precision specs previously used
            # fall back to scientific notation for values >= 1 s (e.g. "-1e+00").
            # Outputs:
            # Delta: 1.2 s (-0.2) 	at 05/06/2023 03:45:35 PM (+0.94)	Value: 1845.8944091796875
            print(
                f"Delta: {elapsed:<4.1f}s ({drift:+.1f}) \tat {when}"
                f" ({data_delta:+.2f})\tValue: {inferred_value}"
            )
            last = now
            last_data = data_gathered_epoch_time

except BaseException as e:
    # BaseException also covers KeyboardInterrupt, so Ctrl-C is routed to the
    # stream's handler too (presumably to stop/clean up the job — confirm).
    output_stream.handle_exception(e)
