#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import os
import json
import pandas as pd
import argparse
from pathlib import Path
import logging
import sys
from datetime import datetime

# Fail early, with installation instructions, when a third-party database
# driver is missing. Keys are importable module names; values are the names
# to hand to pip.
required_packages = dict(
    psycopg2="psycopg2-binary",
    sqlalchemy="sqlalchemy",
)

missing_packages = []
for module_name in required_packages:
    try:
        __import__(module_name)
    except ImportError:
        pip_target = required_packages[module_name]
        missing_packages.append("pip install " + pip_target)

if missing_packages:
    # One install command per line, indented under the header.
    print("ERROR: Missing required packages. Please install them with:")
    print("\n".join(f"  {command}" for command in missing_packages))
    sys.exit(1)

# Import after checking
import psycopg2
from sqlalchemy import create_engine, text

import tripleblind as tb
from tripleblind.util.timer import Timer

# Root logging setup: INFO and above, as "<timestamp> - <level> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)

def ensure_cloud_storage_info():
    """
    Return the cloud storage configuration, creating it when necessary.

    If ``cloud_storage_info.json`` in ``tb.config.data_dir`` exists and holds
    a JSON object, that object is returned unchanged. Otherwise a new
    configuration is built from the ``AZURE_STORAGE_ACCOUNT``,
    ``AZURE_STORAGE_KEY`` and optional ``AZURE_CONTAINER_NAME`` environment
    variables (or from placeholder values when the first two are not set),
    written to that file, and returned. This allows the script to run
    independently.

    Returns:
        dict: The configuration. Configurations built here carry the keys
        ``provider``, ``storage_account``, ``storage_key``,
        ``container_or_bucket``, ``upload_success`` and ``upload_timestamp``;
        ``upload_success`` is False for the placeholder configuration.
    """
    info_path = tb.config.data_dir / "cloud_storage_info.json"

    # If the file already exists and is usable, just return its contents
    if info_path.exists():
        try:
            with open(info_path, 'r') as f:
                cloud_info = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and decoding errors.
            logger.warning(f"Error reading existing cloud storage info: {e}")
            # Fall through to recreate the file
        else:
            # Callers use dict methods on the result, so valid JSON that is
            # not an object (e.g. `[]` or `null`) is treated as invalid.
            if isinstance(cloud_info, dict):
                logger.info(f"Using existing cloud storage configuration from {info_path}")
                return cloud_info
            logger.warning(
                "Error reading existing cloud storage info: expected a JSON object, "
                f"found {type(cloud_info).__name__}"
            )
            # Fall through to recreate the file

    # File doesn't exist or is invalid - create it from environment variables
    logger.info("Creating cloud storage configuration from environment variables")

    storage_account = os.environ.get("AZURE_STORAGE_ACCOUNT")
    storage_key = os.environ.get("AZURE_STORAGE_KEY")

    # We only support Azure for this example
    if storage_account and storage_key:
        cloud_info = {
            "provider": "azure",
            "storage_account": storage_account,
            "storage_key": storage_key,
            "container_or_bucket": os.environ.get("AZURE_CONTAINER_NAME", "covid-xrays"),
            "upload_success": True,  # Mark as successful since we're creating it now
            "upload_timestamp": datetime.now().isoformat()
        }
        logger.info(f"Using Azure Blob Storage configuration: account={cloud_info['storage_account']}, container={cloud_info['container_or_bucket']}")
    else:
        logger.warning("No cloud storage credentials found in environment variables")
        logger.warning("To use cloud storage with images, set the following environment variables:")
        logger.warning("For Azure: AZURE_STORAGE_ACCOUNT, AZURE_STORAGE_KEY, AZURE_CONTAINER_NAME")

        # Create a minimal placeholder configuration.
        # NOTE(review): this placeholder is saved below, so later runs will
        # load it from disk even if the environment variables are set by
        # then - delete the file to regenerate it.
        cloud_info = {
            "provider": "azure",  # Default provider
            "storage_account": "demo-account",
            "storage_key": "demo-key",
            "container_or_bucket": "covid-xrays",
            "upload_success": False,  # Mark as unsuccessful
            "upload_timestamp": datetime.now().isoformat()
        }
        logger.warning("Using placeholder cloud configuration - images may not be accessible")

    # Save the configuration; a failed write is logged but not fatal, the
    # in-memory configuration is still returned.
    try:
        with open(info_path, 'w') as f:
            json.dump(cloud_info, f, indent=2)
        logger.info(f"Cloud storage configuration saved to {info_path}")
    except OSError as e:
        logger.error(f"Error saving cloud storage configuration: {e}")

    return cloud_info

def _ensure_database_exists(cursor, database_name):
    """Create ``database_name`` on the connected server unless it already exists."""
    from psycopg2 import sql

    # The lookup uses a bound parameter and the DDL a quoted identifier, so
    # the name is never spliced into SQL text and the database is created
    # with exactly the spelling that is later used to connect to it.
    cursor.execute("SELECT 1 FROM pg_database WHERE datname = %s", (database_name,))
    if cursor.fetchone() is None:
        cursor.execute(
            sql.SQL("CREATE DATABASE {}").format(sql.Identifier(database_name))
        )
        logger.info(f"Created database: {database_name}")


def prepare_postgres_tables(cloud_info):
    """Upload data to PostgreSQL and create tables with cloud image references.

    Reads the org2/org3 train and test CSV files and the image mapping CSV
    from ``tb.config.data_dir``, adds a ``cloud_image_path`` column to the
    org3 frames, creates the two databases if needed, replaces the
    ``covid_train_data`` and ``covid_test_data`` tables in each database, and
    writes the connection details to ``postgres_connection_info.json`` in the
    same directory.

    Connection settings come from the ``PG_HOST``, ``PG_PORT``, ``PG_USER``,
    ``PG_PASSWORD``, ``PG_DATABASE_ORG2`` and ``PG_DATABASE_ORG3``
    environment variables; all but the password have defaults.

    Args:
        cloud_info: Cloud storage configuration. Currently unused by this
            function; kept so existing callers keep working.

    Returns:
        dict: ``{"org2": {...}, "org3": {...}}`` where each entry holds
        ``connection_string``, ``host``, ``port``, ``user``, ``password``
        and ``database``.

    Raises:
        ValueError: If ``PG_PASSWORD`` is not set.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    from contextlib import closing
    from urllib.parse import quote

    try:
        data_dir = tb.config.data_dir

        # Load training and testing data for both organizations
        org2_train_df = pd.read_csv(data_dir / "covid_multimodal_org2_train.csv")
        org2_test_df = pd.read_csv(data_dir / "covid_multimodal_org2_test.csv")
        org3_train_df = pd.read_csv(data_dir / "covid_multimodal_org3_train.csv")
        org3_test_df = pd.read_csv(data_dir / "covid_multimodal_org3_test.csv")

        # Load image mappings to replace local paths with cloud paths
        image_mappings = pd.read_csv(data_dir / "covid_image_mappings.csv")
        path_map = dict(zip(image_mappings["local_path"], image_mappings["cloud_path"]))

        # Replace local paths with cloud paths for org3 data (which has the images)
        for split_name, frame in (("train", org3_train_df), ("test", org3_test_df)):
            frame["cloud_image_path"] = frame["image_path"].map(path_map)
            # Paths absent from the mapping become NaN; report them instead
            # of uploading the gaps silently.
            unmapped = int(frame["cloud_image_path"].isna().sum())
            if unmapped:
                logger.warning(
                    f"{unmapped} row(s) in the org3 {split_name} data have no cloud image path"
                )

        # Get PostgreSQL connection parameters from environment variables
        host = os.environ.get("PG_HOST", "localhost")
        port = os.environ.get("PG_PORT", "5432")
        user = os.environ.get("PG_USER", "postgres")
        password = os.environ.get("PG_PASSWORD")
        org2_database = os.environ.get("PG_DATABASE_ORG2", "covid_multimodal_org2")
        org3_database = os.environ.get("PG_DATABASE_ORG3", "covid_multimodal_org3")

        if not password:
            logger.error("Missing PostgreSQL password. Set PG_PASSWORD environment variable.")
            raise ValueError("Missing PostgreSQL password")

        # Create connection strings for both databases. User and password are
        # percent-encoded so characters such as '@', ':' or '/' cannot be
        # mistaken for URL delimiters.
        credentials = f"{quote(user, safe='')}:{quote(password, safe='')}"
        org2_connection_string = f"postgresql://{credentials}@{host}:{port}/{org2_database}"
        org3_connection_string = f"postgresql://{credentials}@{host}:{port}/{org3_database}"

        # Create databases if they don't exist
        try:
            # Connect to default 'postgres' database to create our databases.
            # closing() guarantees the cursor and connection are released
            # even when a statement fails.
            with closing(psycopg2.connect(
                host=host,
                port=port,
                user=user,
                password=password,
                database="postgres"
            )) as conn:
                # CREATE DATABASE cannot run inside a transaction block.
                conn.autocommit = True
                with closing(conn.cursor()) as cursor:
                    for database_name in (org2_database, org3_database):
                        _ensure_database_exists(cursor, database_name)
        except Exception as e:
            logger.error(f"Error creating databases: {str(e)}")
            # Continue anyway, the databases might already exist

        # Create engines and upload data to PostgreSQL
        org2_engine = create_engine(org2_connection_string)
        org3_engine = create_engine(org3_connection_string)
        try:
            # Create tables for Organization 2 (first set of features)
            org2_train_df.to_sql("covid_train_data", org2_engine, if_exists="replace", index=False)
            org2_test_df.to_sql("covid_test_data", org2_engine, if_exists="replace", index=False)

            # Create tables for Organization 3 (second set of features + images)
            org3_train_df.to_sql("covid_train_data", org3_engine, if_exists="replace", index=False)
            org3_test_df.to_sql("covid_test_data", org3_engine, if_exists="replace", index=False)
        finally:
            # Release pooled connections whether or not the upload succeeded.
            org2_engine.dispose()
            org3_engine.dispose()

        logger.info(f"Created connections to PostgreSQL at {host}:{port} and uploaded data to both databases")

        # Connection information for client scripts; both entries share the
        # server settings and differ only in the database.
        shared = {"host": host, "port": port, "user": user, "password": password}
        db_info = {
            "org2": {
                "connection_string": org2_connection_string,
                **shared,
                "database": org2_database
            },
            "org3": {
                "connection_string": org3_connection_string,
                **shared,
                "database": org3_database
            }
        }

        # Save database connection info for client scripts.
        # NOTE(review): this file contains the password in plain text.
        connection_info_path = data_dir / "postgres_connection_info.json"
        with open(connection_info_path, "w") as f:
            json.dump(db_info, f, indent=2)

        logger.info(f"Saved database connection information to {connection_info_path}")

        return db_info
    except Exception as e:
        logger.error(f"Error preparing PostgreSQL tables: {str(e)}")
        raise

def _verify_org_tables(org_key, org_label, postgres_info):
    """Check that one organization's database holds both expected tables.

    Args:
        org_key: Key of the organization's entry in ``postgres_info``.
        org_label: Human-readable organization name used in log messages.
        postgres_info: Mapping of organization keys to connection settings.

    Returns:
        bool: True when both tables exist and their row counts were read;
        False when a table is missing or any error occurred (the error is
        logged, not raised).
    """
    from contextlib import closing

    expected_tables = ("covid_train_data", "covid_test_data")
    try:
        params = postgres_info[org_key]
        # closing() releases the cursor and connection on every path,
        # including when a query raises.
        with closing(psycopg2.connect(
            host=params["host"],
            port=params["port"],
            user=params["user"],
            password=params["password"],
            database=params["database"]
        )) as conn:
            with closing(conn.cursor()) as cursor:
                # Check if tables exist
                cursor.execute("SELECT tablename FROM pg_tables WHERE schemaname='public'")
                existing_tables = {row[0] for row in cursor.fetchall()}

                missing_tables = [
                    table for table in expected_tables if table not in existing_tables
                ]
                if missing_tables:
                    logger.warning(f"Missing {org_label} PostgreSQL tables: {', '.join(missing_tables)}")
                    return False

                # Count records in each table
                cursor.execute("SELECT COUNT(*) FROM covid_train_data")
                train_count = cursor.fetchone()[0]

                cursor.execute("SELECT COUNT(*) FROM covid_test_data")
                test_count = cursor.fetchone()[0]

        logger.info(f"Verified {org_label} PostgreSQL tables: covid_train_data ({train_count} records) and covid_test_data ({test_count} records)")
        return True
    except Exception as e:
        logger.error(f"Error verifying {org_label} PostgreSQL tables: {str(e)}")
        return False


def verify_postgres_tables(postgres_info):
    """Verify tables in PostgreSQL are correctly set up.

    Args:
        postgres_info: Mapping with ``"org2"`` and ``"org3"`` entries, each
            providing ``host``, ``port``, ``user``, ``password`` and
            ``database``.

    Returns:
        bool: True only if both organizations' databases contain the
        ``covid_train_data`` and ``covid_test_data`` tables. Problems are
        logged rather than raised.
    """
    organizations = (("org2", "Organization 2"), ("org3", "Organization 3"))
    # Build the full list first so that both organizations are always
    # checked and reported, even when the first one fails.
    results = [
        _verify_org_tables(org_key, org_label, postgres_info)
        for org_key, org_label in organizations
    ]
    return all(results)

# Read-only user creation is not needed for this example
# Clients will use the same PostgreSQL credentials from environment variables

def _main():
    """Script entry point: ensure the cloud config, load PostgreSQL, verify."""
    # No options are defined; parsing still provides --help and rejects
    # unexpected arguments.
    cli = argparse.ArgumentParser(description="Create PostgreSQL tables for the Multimodal Cloud Vision example")
    cli.parse_args()

    logger.info("Creating PostgreSQL tables for the Multimodal Cloud Vision example...")

    # Make sure a cloud storage configuration is available. An unusable one
    # only produces warnings; execution continues instead of exiting.
    storage_config = ensure_cloud_storage_info()
    upload_ok = storage_config.get("upload_success", False)
    if not upload_ok:
        logger.warning("Cloud storage configuration may not be valid")
        logger.warning("This could cause issues when accessing images during training/inference")

    with Timer("Preparing PostgreSQL databases"):
        connection_details = prepare_postgres_tables(storage_config)

    # Confirm that the tables are present in both databases
    logger.info("Verifying PostgreSQL tables for both organizations...")
    if verify_postgres_tables(connection_details):
        logger.info("Successfully verified PostgreSQL tables for both organizations")
    else:
        logger.warning("Could not verify all PostgreSQL tables. Make sure PostgreSQL is running and accessible.")

    # This example creates no read-only user; client scripts reuse the
    # PostgreSQL credentials from the environment variables.
    logger.info("Client scripts will use the same PostgreSQL credentials from PG_* environment variables")

    logger.info("\nDatabase tables created successfully. Ready for asset positioning!")


if __name__ == "__main__":
    _main()