#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

from tripleblind.table_asset import StatFunc

import tripleblind as tb


tb.initialize(api_token=tb.config.example_user3["token"], example=True)


# Locate the three parts of the split example dataset.  Each part is owned by
# the team of a different example user; the computations below are initiated
# by example user 3 (the user initialized above).
table1 = tb.TableAsset.find(
    "EXAMPLE - Split Statistics (part 1)", owned_by=tb.config.example_user1["team_id"]
)
table2 = tb.TableAsset.find(
    "EXAMPLE - Split Statistics (part 2)", owned_by=tb.config.example_user2["team_id"]
)
table3 = tb.TableAsset.find(
    "EXAMPLE - Split Statistics (part 3)", owned_by=tb.config.example_user3["team_id"]
)

# Fail fast with a clear message when an example asset could not be located,
# instead of an opaque AttributeError on the first get_statistics() call.
# NOTE(review): assumes TableAsset.find returns None when nothing matches — if
# it raises instead, this check is simply never reached.
_missing_parts = [
    label
    for label, asset in (("part 1", table1), ("part 2", table2), ("part 3", table3))
    if asset is None
]
if _missing_parts:
    raise SystemError(
        "Could not find example asset(s): "
        + ", ".join(f"EXAMPLE - Split Statistics ({label})" for label in _missing_parts)
    )

# Calculate a variety of statistics on the "Normal" and "Continuous" columns
# over the intersection of two vertically partitioned datasets (table1 and
# table3), whose rows are matched on their identifier columns.
# Every statistic requested for the combined table1 + table3 data.
overall_functions = [
    StatFunc.MAXIMUM,
    StatFunc.MINIMUM,
    StatFunc.MEDIAN,
    StatFunc.QUARTILES,
    StatFunc.MEAN,
    StatFunc.VARIANCE,
    StatFunc.STANDARD_DEVIATION,
    StatFunc.SKEW,
    StatFunc.KURTOSIS,
    StatFunc.COUNT,
    StatFunc.CONFIDENCE_INTERVAL,
    StatFunc.STANDARD_ERROR,
]

result = table1.get_statistics(
    column=["Normal", "Continuous"],
    function=overall_functions,
    combine_with=[table3],
    # One match column per asset, in asset order: table1 keys its rows by
    # 'identifier', table3 keys them by 'id'.
    match_column=["identifier", "id"],
)

# Abort the example outright when the job produced no result.
if not result:
    raise SystemError("Statistic calculation failed.")


# Report the combined statistics.
print("\n=================================")
print("Overall statistics\n")
print(result.dataframe)
print("\n\n")

# Both table1 and table2 have columns named 'Continuous', so the column name
# is disambiguated by prefixing the positional index of the asset to which the
# column belongs.  So "0.Continuous" comes from table1 and "1.Continuous"
# comes from table2 (the first asset in the combine_with list).
#
# Note that organization-three owns neither of the assets used here, but it
# can still initiate the operation because its user has been granted access to
# the data via an agreement.
# Statistics on table1's "Continuous" column, grouped by "DiscreteA", over the
# rows of table1 and table2 whose identifiers match.
result = table1.get_statistics(
    "0.Continuous",
    combine_with=[table2],
    group_by="DiscreteA",
    # table1 keys its rows by 'identifier', table2 keys them by 'id'.
    match_column=["identifier", "id"],
)
# Check the result the same way the first computation does; without this, a
# failed job would only surface as an AttributeError on `result.dataframe`.
if not result:
    raise SystemError("Statistic calculation failed.")
print("\n=================================")
print("table1.Continuous statistics grouped by 'table2.DiscreteA'\n")
print(result.dataframe)
print("\n\n")

# Prefixing the positional index of the asset also applies to the group_by
# parameter: "2.DiscreteB" refers to table3 (index 0 is table1, index 1 is
# table2, index 2 is table3).  PSI VP Blind Stats may also be computed on more
# than two assets.
# Mean of table1's "Normal" column grouped by table3's "DiscreteB" column,
# computed over the rows shared by all three assets.
result = table1.get_statistics(
    "Normal",
    function=[
        StatFunc.MEAN,
    ],
    # Index 2 is table3: 0 is table1 itself, then combine_with in order.
    group_by="2.DiscreteB",
    combine_with=[table2, table3],
    # One match column per asset, in the same order as above.
    match_column=["identifier", "id", "id"],
)
# Same failure handling as the other computations, so a failed job does not
# turn into an AttributeError on `result.dataframe`.
if not result:
    raise SystemError("Statistic calculation failed.")
print("\n=================================")
# The label names the column actually grouped on ("2.DiscreteB" -> table3).
print("Statistics grouped by 'table3.DiscreteB'\n")
print(result.dataframe)
print("\n\n")

# Calculate statistics again, grouping on the un-prefixed "DiscreteB" field.
# This job is expected to be rejected because of the k-grouping setting.
result = table1.get_statistics(
    "Normal",
    function=[
        StatFunc.MEAN,
    ],
    combine_with=[table2, table3],
    group_by="DiscreteB",
    # One match column per asset: table1, table2, table3.
    match_column=["identifier", "id", "id"],
)
if not result:
    print("\nThe job failed due to k-grouping setting -- as expected.")
else:
    # The demo expects this job to fail; report an unexpected success instead
    # of silently printing nothing.
    print("\n=================================")
    print("Unexpected success: statistics grouped by 'DiscreteB'\n")
    print(result.dataframe)
print("\n\n")
