#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import os
import json
import argparse
from pathlib import Path
import logging
import sys

# Pre-flight dependency check: fail early with copy-pasteable install commands
# instead of a bare ImportError traceback further down.
# Keys are importable module names; values are the pip distributions providing them.
required_packages = {
    "psycopg2": "psycopg2-binary",
    "sqlalchemy": "sqlalchemy",
}

missing_packages = []
for package, install_name in required_packages.items():
    try:
        __import__(package)
    except ImportError:
        # Remember the exact command the user has to run for this package.
        missing_packages.append(f"pip install {install_name}")

if missing_packages:
    # Header line followed by one indented install command per missing package.
    print(
        "ERROR: Missing required packages. Please install them with:",
        *(f"  {cmd}" for cmd in missing_packages),
        sep="\n",
    )
    sys.exit(1)

# Import after checking
import psycopg2
from sqlalchemy import create_engine, text

import tripleblind as tb
from tripleblind.util.timer import Timer

# Configure logging: INFO and above, each record prefixed with timestamp and level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

def ensure_cloud_storage_info():
    """
    Return the cloud storage configuration, creating it if necessary.

    Looks for ``cloud_storage_info.json`` in ``tb.config.data_dir``. If the
    file exists and contains a JSON object, that object is returned as-is.
    Otherwise (file missing, unreadable, not valid JSON, or not a JSON
    object) the configuration is rebuilt from environment variables and
    written back to the file, which allows this script to run independently.

    Environment variables used when rebuilding:
        AZURE_STORAGE_ACCOUNT: storage account name (required).
        AZURE_STORAGE_KEY: storage account key (required).
        AZURE_CONTAINER_NAME: container name (defaults to "covid-xrays").

    Returns:
        dict: the cloud storage configuration.

    Raises:
        ValueError: if the configuration has to be rebuilt and the required
            Azure environment variables are not set.
    """
    from datetime import datetime

    info_path = tb.config.data_dir / "cloud_storage_info.json"

    # If the file already exists and is usable, just return its contents
    if info_path.exists():
        try:
            with open(info_path, 'r') as f:
                cloud_info = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and decoding errors.
            logger.warning(f"Error reading existing cloud storage info: {e}")
        else:
            # Callers use dict methods on the result, so anything other than
            # a JSON object is treated the same as a corrupt file.
            if isinstance(cloud_info, dict):
                logger.info(f"Using existing cloud storage configuration from {info_path}")
                return cloud_info
            logger.warning(
                "Error reading existing cloud storage info: "
                f"expected a JSON object, got {type(cloud_info).__name__}"
            )
        # Fall through to recreate the file

    # File doesn't exist or is invalid - create it from environment variables
    logger.info("Creating cloud storage configuration from environment variables")

    # Check for Azure environment variables (only Azure is supported)
    storage_account = os.environ.get("AZURE_STORAGE_ACCOUNT")
    storage_key = os.environ.get("AZURE_STORAGE_KEY")
    if not (storage_account and storage_key):
        logger.error("No Azure storage credentials found in environment variables")
        logger.error("For Azure: AZURE_STORAGE_ACCOUNT, AZURE_STORAGE_KEY, AZURE_CONTAINER_NAME")
        raise ValueError("Missing Azure storage credentials")

    cloud_info = {
        "provider": "azure",
        "storage_account": storage_account,
        "storage_key": storage_key,
        "container_or_bucket": os.environ.get("AZURE_CONTAINER_NAME", "covid-xrays"),
        "upload_success": True,  # Mark as successful since we're creating it now
        "upload_timestamp": datetime.now().isoformat()
    }
    logger.info(f"Using Azure Blob Storage configuration: account={cloud_info['storage_account']}, container={cloud_info['container_or_bucket']}")

    # Save the configuration. Best effort: a failed write is logged but the
    # in-memory configuration is still returned.
    # NOTE: the storage key is written to this file in plain text.
    try:
        with open(info_path, 'w') as f:
            json.dump(cloud_info, f, indent=2)
        logger.info(f"Cloud storage configuration saved to {info_path}")
    except OSError as e:
        logger.error(f"Error saving cloud storage configuration: {e}")

    return cloud_info

def create_linked_storage_columns(cloud_info):
    """Build the linked_storage_columns mapping for the ``image_path`` column.

    Only Azure Blob Storage is supported by this example. The account name,
    key and container are read from ``cloud_info`` with ``.get``, so any key
    that is absent ends up as ``None`` in the result.
    """
    image_column_spec = dict(
        storage_type="azure",
        content_type="image/png",
        storage_account_name=cloud_info.get("storage_account"),
        storage_key=cloud_info.get("storage_key"),
        file_system=cloud_info.get("container_or_bucket"),
    )
    return {"image_path": image_column_spec}

def get_postgres_connection_info():
    """Get PostgreSQL connection information from environment variables.

    Environment variables:
        PG_HOST: server host (default "localhost").
        PG_PORT: server port (default "5432").
        PG_USER: user name (default "postgres").
        PG_PASSWORD: password (required).
        PG_DATABASE_ORG2: Organization 2 database (default "covid_multimodal_org2").
        PG_DATABASE_ORG3: Organization 3 database (default "covid_multimodal_org3").

    Returns:
        dict: keyed by "org2" and "org3"; each value holds "connection_string",
        "host", "port", "user", "password" and "database". The "user" and
        "password" entries are the raw values; inside "connection_string"
        they are percent-encoded.

    Raises:
        ValueError: if PG_PASSWORD is unset or empty.
    """
    from urllib.parse import quote

    # Use environment variables for PostgreSQL connection
    host = os.environ.get("PG_HOST", "localhost")
    port = os.environ.get("PG_PORT", "5432")
    user = os.environ.get("PG_USER", "postgres")
    password = os.environ.get("PG_PASSWORD")
    databases = {
        "org2": os.environ.get("PG_DATABASE_ORG2", "covid_multimodal_org2"),
        "org3": os.environ.get("PG_DATABASE_ORG3", "covid_multimodal_org3"),
    }

    if not password:
        logger.error("Missing PostgreSQL password. Set PG_PASSWORD environment variable.")
        raise ValueError("Missing PostgreSQL password")

    # Percent-encode the credentials so characters such as '@', '/', ':', '#'
    # or '%' cannot be mistaken for URL delimiters. safe="" encodes '/' too,
    # and quote() (unlike quote_plus) never emits '+', so the value decodes
    # correctly with either unquote() or unquote_plus().
    credentials = f"{quote(user, safe='')}:{quote(password, safe='')}"

    # Both organizations share the server and credentials; only the database differs.
    return {
        org: {
            "connection_string": f"postgresql://{credentials}@{host}:{port}/{database}",
            "host": host,
            "port": port,
            "user": user,
            "password": password,
            "database": database,
        }
        for org, database in databases.items()
    }

def position_assets(postgres_info, cloud_info, run_id):
    """Position database assets with linked storage columns on access points.

    Creates a training and a testing ``DatabaseDataset`` asset for
    Organization 2 (features 0-4 plus label) and for Organization 3
    (features 5-9 plus the cloud image path), adds agreements on each asset
    with Organization 1's team, and writes two files to ``tb.config.data_dir``:

    * ``asset_info.json`` - the four asset IDs and the run ID.
    * ``expected_outcomes.json`` - the test labels read from Organization 2's
      database (best effort; a failure here is only logged as a warning).

    Args:
        postgres_info: dict with "org2" and "org3" entries, each holding
            "connection_string", "host", "port", "user", "password" and
            "database".
        cloud_info: cloud storage configuration used to build the linked
            storage column settings for Organization 3.
        run_id: suffix appended to asset names for uniqueness.

    Raises:
        SystemExit: with status 1 if creating assets or agreements fails for
            either organization.
    """
    from contextlib import closing

    # Initialize sessions for all three organizations
    tb.initialize(api_token=tb.config.example_user1["token"])  # Organization 1 - model owner
    session_org_2 = tb.Session(
        api_token=tb.config.example_user2["token"], from_default=True
    )  # Organization 2 - first data owner
    session_org_3 = tb.Session(
        api_token=tb.config.example_user3["token"], from_default=True
    )  # Organization 3 - second data owner with images

    # Create linked storage columns configuration for Organization 3 (has images)
    linked_storage_columns = create_linked_storage_columns(cloud_info)

    # Create and position Organization 2 training asset (first set of features)
    try:
        logger.info("Creating Organization 2 training asset with first set of features...")
        org2_train_asset = tb.asset.DatabaseDataset.create(
            connection=postgres_info["org2"]["connection_string"],
            query="SELECT image_id, feature_0, feature_1, feature_2, feature_3, feature_4, label "
                 "FROM covid_train_data "
                 "ORDER BY image_id",
            name=f"covid-multimodal-org2-train-{run_id}",
            desc="COVID-19 X-ray multimodal training dataset - First feature set (Org 2)",
            is_discoverable=True,
            allow_overwrite=True,
            session=session_org_2
        )

        logger.info("Creating Organization 2 testing asset with first set of features...")
        org2_test_asset = tb.asset.DatabaseDataset.create(
            connection=postgres_info["org2"]["connection_string"],
            query="SELECT image_id, feature_0, feature_1, feature_2, feature_3, feature_4, label "
                 "FROM covid_test_data "
                 "ORDER BY image_id",
            name=f"covid-multimodal-org2-test-{run_id}",
            desc="COVID-19 X-ray multimodal testing dataset - First feature set (Org 2)",
            is_discoverable=True,
            allow_overwrite=True,
            session=session_org_2
        )

        # Add agreement with Organization 1 (model owner)
        org2_train_asset.add_agreement(
            with_team=tb.config.example_user1["team_id"],
            operation=tb.Operation.VERTICAL_BLIND_LEARNING,
            session=session_org_2
        )
        org2_test_asset.add_agreement(
            with_team=tb.config.example_user1["team_id"],
            operation=tb.Operation.EXECUTE,
            session=session_org_2
        )
        logger.info("Created agreements for Organization 2 assets with Organization 1")

    except tb.TripleblindAssetAlreadyExists:
        logger.error("Organization 2 asset already exists. Try changing the run_id.")
        # sys.exit rather than the site-provided exit(), which is not
        # guaranteed to exist in every interpreter configuration.
        sys.exit(1)
    except Exception as e:
        logger.error(f"Error creating Organization 2 assets: {str(e)}")
        sys.exit(1)

    # Create and position Organization 3 training asset (second set of features + images)
    try:
        logger.info("Creating Organization 3 training asset with second set of features and images...")
        org3_train_asset = tb.asset.DatabaseDataset.create(
            connection=postgres_info["org3"]["connection_string"],
            query="SELECT feature_5, feature_6, feature_7, feature_8, feature_9, "
                 "cloud_image_path as image_path "
                 "FROM covid_train_data "
                 "ORDER BY image_id",
            name=f"covid-multimodal-org3-train-{run_id}",
            desc="COVID-19 X-ray multimodal training dataset - Second feature set with images (Org 3)",
            is_discoverable=True,
            linked_storage_columns=linked_storage_columns,
            allow_overwrite=True,
            session=session_org_3
        )

        logger.info("Creating Organization 3 testing asset with second set of features and images...")
        org3_test_asset = tb.asset.DatabaseDataset.create(
            connection=postgres_info["org3"]["connection_string"],
            query="SELECT feature_5, feature_6, feature_7, feature_8, feature_9, "
                 "cloud_image_path as image_path "
                 "FROM covid_test_data "
                 "ORDER BY image_id",
            name=f"covid-multimodal-org3-test-{run_id}",
            desc="COVID-19 X-ray multimodal testing dataset - Second feature set with images (Org 3)",
            is_discoverable=True,
            linked_storage_columns=linked_storage_columns,
            allow_overwrite=True,
            session=session_org_3
        )

        # Add agreement with Organization 1 (model owner)
        org3_train_asset.add_agreement(
            with_team=tb.config.example_user1["team_id"],
            operation=tb.Operation.VERTICAL_BLIND_LEARNING,
            session=session_org_3
        )
        org3_test_asset.add_agreement(
            with_team=tb.config.example_user1["team_id"],
            operation=tb.Operation.EXECUTE,
            session=session_org_3
        )
        logger.info("Created agreements for Organization 3 assets with Organization 1")

    except tb.TripleblindAssetAlreadyExists:
        logger.error("Organization 3 asset already exists. Try changing the run_id.")
        sys.exit(1)
    except Exception as e:
        logger.error(f"Error creating Organization 3 assets: {str(e)}")
        sys.exit(1)

    # Switch back to Organization 1 session
    tb.initialize(api_token=tb.config.example_user1["token"])

    # Save asset IDs for use in subsequent scripts
    with open(tb.config.data_dir / "asset_info.json", "w") as f:
        json.dump({
            "org2_train_asset_id": str(org2_train_asset.uuid),
            "org2_test_asset_id": str(org2_test_asset.uuid),
            "org3_train_asset_id": str(org3_train_asset.uuid),
            "org3_test_asset_id": str(org3_test_asset.uuid),
            "run_id": run_id
        }, f)

    # Save expected outcomes for inference evaluation (best effort)
    try:
        org2_db = postgres_info["org2"]
        # closing() guarantees cursor and connection are closed (cursor first)
        # even when the query raises.
        with closing(psycopg2.connect(
            host=org2_db["host"],
            port=org2_db["port"],
            user=org2_db["user"],
            password=org2_db["password"],
            database=org2_db["database"]
        )) as conn_org2, closing(conn_org2.cursor()) as cursor_org2:
            # Get expected labels from test data, in the same order the
            # test assets use (ORDER BY image_id)
            cursor_org2.execute("SELECT label FROM covid_test_data ORDER BY image_id")
            expected_labels = [row[0] for row in cursor_org2.fetchall()]

        # Save expected outcomes for inference script
        with open(tb.config.data_dir / "expected_outcomes.json", "w") as f:
            json.dump({
                "labels": expected_labels
            }, f)

        logger.info("Saved expected outcomes for inference evaluation")
    except Exception as e:
        logger.warning(f"Could not save expected outcomes: {str(e)}")
        logger.warning("Inference evaluation may not be accurate")

    logger.info("Assets positioned successfully across organizations!")
    logger.info(f"Organization 2 Training asset ID: {org2_train_asset.uuid}")
    logger.info(f"Organization 2 Testing asset ID: {org2_test_asset.uuid}")
    logger.info(f"Organization 3 Training asset ID: {org3_train_asset.uuid}")
    logger.info(f"Organization 3 Testing asset ID: {org3_test_asset.uuid}")

def _verify_org_tables(org_label, postgres_info, org_key, required_columns=()):
    """Check that one organization's database holds the example tables.

    Args:
        org_label: display name used in log messages (e.g. "Organization 2").
        postgres_info: dict of per-organization connection settings.
        org_key: key into ``postgres_info`` ("org2" or "org3").
        required_columns: column names that must be present in both
            ``covid_train_data`` and ``covid_test_data``.

    Returns:
        bool: True if both tables (and all required columns) exist; False if
        anything is missing or the database could not be queried. All
        problems are logged rather than raised.
    """
    from contextlib import closing

    required_tables = ('covid_train_data', 'covid_test_data')
    setup_hint = "Please run Assets/Multimodal_Cloud_Vision/2_create_database_tables.py first"

    try:
        conn_info = postgres_info[org_key]
        # closing() releases the cursor and the connection on every exit
        # path, including early returns and exceptions.
        with closing(psycopg2.connect(
            host=conn_info["host"],
            port=conn_info["port"],
            user=conn_info["user"],
            password=conn_info["password"],
            database=conn_info["database"]
        )) as conn, closing(conn.cursor()) as cursor:
            # Check if tables exist
            cursor.execute("SELECT tablename FROM pg_tables WHERE schemaname='public'")
            existing_tables = {row[0] for row in cursor.fetchall()}

            missing_tables = [t for t in required_tables if t not in existing_tables]
            if missing_tables:
                logger.error(f"Missing {org_label} PostgreSQL tables: {', '.join(missing_tables)}")
                logger.error(setup_hint)
                return False

            # Count records in each table (table names are the fixed
            # constants above, never user input)
            record_counts = {}
            for table in required_tables:
                cursor.execute(f"SELECT COUNT(*) FROM {table}")
                record_counts[table] = cursor.fetchone()[0]

            # Check required columns in both tables, since the train and
            # test assets are each built from a query on their own table.
            if required_columns:
                columns_by_table = {}
                for table in required_tables:
                    cursor.execute(
                        "SELECT column_name FROM information_schema.columns WHERE table_name=%s",
                        (table,)
                    )
                    columns_by_table[table] = {row[0] for row in cursor.fetchall()}

                missing_columns = [
                    column for column in required_columns
                    if any(column not in columns_by_table[table] for table in required_tables)
                ]
                if missing_columns:
                    for column in missing_columns:
                        logger.error(f"Missing '{column}' column in {org_label} tables")
                    logger.error(setup_hint)
                    return False

            logger.info(f"Verified {org_label} PostgreSQL tables: covid_train_data ({record_counts['covid_train_data']} records) and covid_test_data ({record_counts['covid_test_data']} records)")
            return True

    except Exception as e:
        logger.error(f"Error verifying {org_label} PostgreSQL tables: {str(e)}")
        logger.error("Please ensure PostgreSQL is running and the tables are created")
        return False


def verify_postgres_tables(postgres_info):
    """Verify tables in PostgreSQL are correctly set up.

    Both organizations' databases must contain ``covid_train_data`` and
    ``covid_test_data``; Organization 3's tables must additionally have a
    ``cloud_image_path`` column.

    Args:
        postgres_info: dict with "org2" and "org3" entries, each holding
            "host", "port", "user", "password" and "database".

    Returns:
        bool: True only if every check passed for both organizations.
        Failures are logged; no exception is raised for database errors.
    """
    # Run both checks before combining the results so that problems in either
    # database are reported in a single run.
    org2_verified = _verify_org_tables("Organization 2", postgres_info, "org2")
    org3_verified = _verify_org_tables(
        "Organization 3", postgres_info, "org3",
        required_columns=("cloud_image_path",)
    )
    return org2_verified and org3_verified


if __name__ == "__main__":
    # Command line: a single optional --run-id used to keep asset names unique.
    arg_parser = argparse.ArgumentParser(description="Position database assets with linked cloud storage")
    arg_parser.add_argument(
        "--run-id",
        default="001",
        help="Run ID to append to asset names for uniqueness",
    )
    cli_args = arg_parser.parse_args()

    logger.info("Positioning database assets with linked cloud storage across multiple organizations...")

    # Load the cloud storage settings, building them from the environment if needed.
    storage_config = ensure_cloud_storage_info()
    upload_succeeded = storage_config.get("upload_success", False)
    if not upload_succeeded:
        logger.warning("Azure storage configuration may not be valid")
        logger.info("This example will still work with placeholder cloud paths")

    # Collect the PostgreSQL connection settings for both organizations.
    db_settings = get_postgres_connection_info()

    # Stop early when the databases have not been prepared.
    logger.info("Verifying PostgreSQL tables for both organizations...")
    if not verify_postgres_tables(db_settings):
        logger.error("Required PostgreSQL tables are missing. Please run setup scripts first.")
        sys.exit(1)
    logger.info("Successfully verified PostgreSQL tables for both organizations")

    # Create the assets and agreements, timing the whole operation.
    with Timer("Positioning assets across organizations"):
        position_assets(db_settings, storage_config, cli_args.run_id)

    logger.info("\nAssets positioned successfully across organizations. Ready for vertical partition model training!")
