#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import tripleblind as tb


# user3 will be the report creator (also owns some data)
tb.initialize(tb.config.example_user3["token"])

# The report is federated across this group; stop early with a clear message
# if it cannot be found. An explicit check is used instead of `assert`, which
# is stripped when Python runs with -O.
group = tb.FederationGroup.find(name="[DEMO] Harmony Data Federation")
if group is None or group.id is None:
    raise SystemExit(
        'Federation group "[DEMO] Harmony Data Federation" was not found.'
    )

# first, create the report template itself
#
# {{demographic}} and {{date}} are filled from the report parameters of the same
# name, the ICDcode / PCScode sections expand once per selected code (OR-ed
# together), and the DIALECT_* sections pick the SQL for the target database.
#
# Length of stay must be the number of days between admission and discharge.
# In PostgreSQL, subtracting two DATE values yields an integer day count;
# casting first also covers TIMESTAMP columns and mirrors the day-boundary
# count that DATEDIFF(Day, ...) returns on MSSQL.
query_template = """
SELECT
    bv.VisitType AS visittype,
    bv.BillingCode AS icd10cm,
    -- 'demographic' will be renamed in postprocessor
    {{demographic}} AS demographic,
    -- use different SQL specifically for mssql azure
    {{#DIALECT_MSSQL}}
    DATEDIFF(Day, bv.AdmitDate, bv.DischargeDate) AS lengthofstay
    {{/DIALECT_MSSQL}}
    {{#DIALECT_ORACLE}}
    (bv.DischargeDate - bv.AdmitDate) AS lengthofstay
    {{/DIALECT_ORACLE}}
    {{#DIALECT_POSTGRESQL}}
    (CAST(bv.DischargeDate AS DATE) - CAST(bv.AdmitDate AS DATE)) AS lengthofstay
    {{/DIALECT_POSTGRESQL}}
FROM
    Patient pv
JOIN
    BilledVisit bv ON pv.PatientID = bv.PatientID
WHERE
    {{#DIALECT_ORACLE}} bv.AdmitDate > DATE '{{date}}' {{/DIALECT_ORACLE}}
    {{^DIALECT_ORACLE}} bv.AdmitDate > '{{date}}' {{/DIALECT_ORACLE}}
    -- ICD and PCS codes are both optional thanks to the template logic
    {{#ICDcode}} {{#_FIRST}}AND ({{/_FIRST}}
    BillingCode = '{{_VALUE}}' {{^_LAST}}OR{{/_LAST}}
    {{#_LAST}}){{/_LAST}} {{/ICDcode}}
    {{#PCScode}} {{#_FIRST}}AND ({{/_FIRST}}
    ProcedureCode = '{{_VALUE}}' {{^_LAST}}OR{{/_LAST}}
    {{#_LAST}}){{/_LAST}} {{/PCScode}}
-- no GROUP BY or ORDER BY, they appear in the AggregationRules instead
"""

# Next, the aggregation rules. These are handed to the report as its
# federation aggregation: results are grouped by the chosen demographic and
# visit type, then length of stay is averaged and ICD codes are counted per
# group. Grouping and ordering live here rather than in the SQL above.
agg_template = tb.report_asset.AggregationRules.create(
    sort_order="asc",  # either "asc" or "desc"
    aggregates={"lengthofstay": "mean", "icd10cm": "count"},
    group_by=["demographic", "visittype"],
)

# This post-processing script runs after the query has executed and can be used
# to polish the output before it is returned to the user. It receives two
# arguments: a pandas DataFrame holding the results and a context dictionary
# containing the report metadata and all of the parameters selected by the
# user. "report_values" holds each parameter's display value keyed by its
# display label (e.g. "Demographic to group by"). The context looks like:
#
#  {
#    "name": package.meta.record.name,
#    "description": package.meta.record.description,     # str
#    "initiator_details": job_params.initiator_details,  # Dict[str, str]
#    "attributes": {
#        "report_values": display_params,                # Dict[str, str]
#        "raw_values": raw_params,                       # Dict[str, str]
#        "federation_members": fed_members,              # List[str] (only for federated reports)
#    },
#  }
#
# Source text of the post-processing step. It is only a string here; it is
# embedded into the report template and run later, so every character of it
# (including whitespace) is behavior. It renames the aggregated columns to
# human-readable headers. The "demographic" column is renamed to the display
# value the user picked, read from ctx["attributes"]["report_values"] under the
# key "Demographic to group by". That key must stay in sync with the `display`
# of the demographic report parameter.
# NOTE(review): the DataFrame parameter is named `input` (shadowing the
# builtin) and is mutated in place; it is left unchanged in case the platform
# passes it by keyword — confirm before renaming.
post_processing_script = """
def postprocess(input, ctx):
    input.rename(columns={
        'lengthofstay': 'Mean Length Of Stay (Days)',
        'icd10cm': 'Number of Patients',
        # pull the display value of the selected demographic column out of the ctx
        'demographic': ctx["attributes"]["report_values"]["Demographic to group by"]
    }, inplace=True)
    input.reset_index(drop=True, inplace=True)
    return input
"""
# NOTE: This can also be a path to a file containing the script.  If a path
# is provided, the file will be loaded and embedded into the report template.


# report params
#
# Which patient column to group by. The option value (a pv.* column) is
# substituted for {{demographic}} in the query template. The display label
# "Demographic to group by" is also the key the post-processing script uses to
# look up the chosen option's display value, so keep the two in sync.
demographic_param = tb.report_asset.ReportParameter.create_string(
    name="demographic",
    display="Demographic to group by",
    # Previously a "..." placeholder; this text is shown to report users.
    description="Patient attribute (gender, state, or ZIP code) used to group the results",
    options=[
        tb.report_asset.ParameterOption("pv.Sex", "Gender"),
        tb.report_asset.ParameterOption("pv.State", "State"),
        tb.report_asset.ParameterOption("pv.PostalCode", "ZIP Code"),
    ],
)

# Lower bound on admission date; its value is substituted for {{date}} in the
# query template (formatted as YYYY-MM-DD).
date_param = tb.report_asset.ReportParameter.create_datetime(
    name="date",
    display="Admissions only after this date",
    description="Earliest admit date",
    required=True,
    datetime_format="%Y-%m-%d",
    default_value="2020-01-01",
    min_value="2020-01-01",
    # NOTE(review): the upper bound is hard-coded; presumably later dates
    # cannot be chosen, so this likely needs bumping as the data ages.
    max_value="2025-12-31",
)

# Optional diagnosis-code filter. Each selected code becomes a
# "BillingCode = '...'" term in the ICDcode section of the query template, and
# the terms are OR-ed together.
icd_code_param = tb.report_asset.ReportParameter.create_code(
    name="ICDcode",
    required=False,
    systems=["icd9_cm", "icd10_cm"],
    display="Filter on a diagnosis code (either code is sufficient)",
    description="Choose an ICD code to filter by",
)

# Optional procedure-code filter. Each selected code becomes a
# "ProcedureCode = '...'" term in the PCScode section of the query template,
# and the terms are OR-ed together.
pcs_code_param = tb.report_asset.ReportParameter.create_code(
    name="PCScode",
    required=False,
    systems=["icd9_pcs", "icd10_pcs", "cpt"],
    display="Filter on a procedure code (either code is sufficient)",
    description="Choose a PCS code to filter by",
)

# Register the federated report, wiring together the SQL template, the
# user-facing parameters, the cross-member aggregation and the post-processing
# step. allow_overwrite presumably replaces an existing report of the same name
# when this script is re-run. validate_sql is off, likely because the templated
# SQL is not valid until it has been rendered.
blind_report = tb.report_asset.DatabaseReport.create(
    name="[DEMO] Readmission Rate Report (Federated Report)",
    desc="""Report the count of patients and average length of stay for the
      selected diagnosis and/or procedure codes. Optionally group results by
      sex, state, or ZIP code. This is part of the Hospital Data Federation
      demo.""",
    query_template=query_template,
    params=[demographic_param, date_param, icd_code_param, pcs_code_param],
    post_processing=post_processing_script,
    federation_group=group,
    federation_aggregation=agg_template,
    is_discoverable=True,
    allow_overwrite=True,
    validate_sql=False,
)

print(f"Federated Blind Report Created: {blind_report}")

# Publish an agreement permitting any team to run this report.
# NOTE: the report is not fully usable until every participating data owner
# has also created its own agreement allowing its data to be used.
blind_report.add_agreement(operation=blind_report.uuid, with_team="ANY")
