#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

from pathlib import Path

import snowflake.connector
import yaml

import tripleblind as tb


# Name of the CSV file (inside ``data_dir``) and the Snowflake table it feeds.
datafile = "prescription_data.csv"
tablename = "pharmacy_prescriptions"

# Switch the working directory first (presumably to this script's folder);
# the relative paths below depend on it, so this call must stay first.
tb.util.set_script_dir_current()
data_dir = Path("example_data")

# Connection settings consumed by main(): USERNAME, PASSWORD, ACCOUNT,
# DATABASE and SCHEMA.
config = yaml.safe_load(Path("snowflake.yaml").read_text())


def main():
    """Rebuild the prescriptions table in Snowflake from the local CSV file.

    Connects with the settings from ``config`` (loaded from snowflake.yaml).
    It then drops any prior table and stage, recreates the table, stages the
    CSV and copies it into the table.

    The cursor and the connection are both closed whether or not a step fails.
    """
    ctx = snowflake.connector.connect(
        user=config["USERNAME"],
        password=config["PASSWORD"],
        account=config["ACCOUNT"],
        database=config["DATABASE"],
        schema=config["SCHEMA"],
    )
    try:
        cs = ctx.cursor()
        try:
            drop_prior_table(cs)
            create_table(cs)
            put_csv_file(cs)
            copy_csv_file(cs)
        finally:
            cs.close()
    finally:
        # Previously the connection was only closed on success; an exception
        # in any step above leaked it.
        ctx.close()

    print("Dataset ready for use.")


def drop_prior_table(cs):
    """Drop the target table and its staging area, if either already exists.

    Each statement's result row is fetched (and discarded) before the next
    statement runs.
    """
    cleanup_statements = (
        f"DROP TABLE IF EXISTS {tablename};",
        f"DROP STAGE IF EXISTS public.stage_{tablename};",
    )
    for statement in cleanup_statements:
        cs.execute(statement)
        cs.fetchone()


def create_table(cs):
    """Create the prescriptions table.

    Columns: date, name, address, drug_name, dosage, frequency (all NOT NULL).
    """
    # NOTE(review): `date` is declared as float and as the primary key;
    # confirm this matches the date column in the CSV being loaded.
    columns = [
        " date       float   not null    primary key",
        " name       text    not null",
        " address    text    not null",
        " drug_name  text    not null",
        " dosage     text    not null",
        " frequency  text    not null",
    ]
    sql_cmd = f"CREATE TABLE {tablename}(\n" + ",\n".join(columns) + "\n)"

    print("Defining table...")
    cs.execute(sql_cmd)
    cs.fetchone()


def put_csv_file(cs):
    """Create the Snowflake stage and upload the local CSV file into it.

    If CREATE STAGE fails with errno 2002 (the stage already exists), loading
    continues. Any other ProgrammingError is printed and re-raised.

    Prints the first column of the PUT result row.
    """
    try:
        # Create a Snowflake staging area
        cs.execute(f"CREATE STAGE public.stage_{tablename}")
        cs.fetchone()
    except snowflake.connector.errors.ProgrammingError as e:
        if e.errno != 2002:
            print("Failed to create staging area:")
            print(e)
            raise
        # errno 2002: the stage already exists, continue loading the database

    # Snowflake's PUT expects '/' separators inside a quoted file:// URI;
    # str() of a Path would produce backslashes on Windows.
    local_path = (data_dir / datafile).as_posix()

    # Transfer dataset to remote staging area
    print("Staging data...")
    cs.execute(
        f"""
        PUT 'file://{local_path}' @public.stage_{tablename}/{datafile}
            AUTO_COMPRESS = FALSE
            SOURCE_COMPRESSION = NONE
            OVERWRITE = TRUE;
    """
    )
    res = cs.fetchone()
    print("  ", res[0])


def copy_csv_file(cs):
    """Load the staged CSV file into the table with COPY INTO.

    The first CSV line is skipped as a header. Prints the first column of the
    result row.
    """
    print("Loading data into table...")
    copy_sql = f"""
        COPY INTO {tablename}
            FROM @public.stage_{tablename}/{datafile}
        FILE_FORMAT = (type = csv field_delimiter = ',' FIELD_OPTIONALLY_ENCLOSED_BY = '"' skip_header = 1);
    """
    cs.execute(copy_sql)
    load_result = cs.fetchone()
    print("  ", load_result[0])


# Run the load only when executed as a script. Note that importing this
# module still runs the module-level setup (working-directory change and
# config load).
if __name__ == "__main__":
    main()
