#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import os
import pandas as pd
import argparse
import time
import sys
import logging
from pathlib import Path
from datetime import datetime

import tripleblind as tb
from tripleblind.util.timer import Timer

# Root logging setup: timestamped records at INFO level and above,
# plus the module-level logger used by every function in this script.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

def upload_to_azure(image_mappings, container_name):
    """Upload every mapped image to an Azure Blob Storage container.

    Credentials are read from the AZURE_STORAGE_ACCOUNT and AZURE_STORAGE_KEY
    environment variables. The container is created if it does not exist yet.
    Existing blobs with the same name are overwritten.

    Args:
        image_mappings: Table-like object indexable by "local_path" and
            "cloud_path" (the script passes a pandas DataFrame). Each pair
            gives the file to read and the blob name to write.
        container_name: Name of the target container.

    Returns:
        False if the Azure SDK is missing, the credentials are not set, or
        the connection/container step raises. True if every upload succeeded.
        When some uploads failed, True only if at least one upload succeeded
        (partial success is deliberately reported as success).
    """
    import mimetypes

    try:
        from azure.core.exceptions import ResourceNotFoundError
        from azure.storage.blob import BlobServiceClient, ContentSettings
    except ImportError:
        logger.error("Azure SDK not installed. Run: pip install azure-storage-blob")
        return False

    # Get Azure credentials from environment variables
    storage_account = os.environ.get("AZURE_STORAGE_ACCOUNT")
    storage_key = os.environ.get("AZURE_STORAGE_KEY")

    if not storage_account or not storage_key:
        logger.error("Please set AZURE_STORAGE_ACCOUNT and AZURE_STORAGE_KEY environment variables")
        return False

    try:
        logger.info(f"Connecting to Azure Blob Storage account: {storage_account}")
        blob_service = BlobServiceClient(
            account_url=f"https://{storage_account}.blob.core.windows.net",
            credential=storage_key
        )

        # Create the container only when the service reports it as missing.
        # Any other failure (bad key, network, invalid name) propagates to the
        # outer handler with its real cause instead of being mistaken for
        # "container does not exist".
        try:
            container_client = blob_service.get_container_client(container_name)
            container_client.get_container_properties()
            logger.info(f"Container {container_name} already exists")
        except ResourceNotFoundError:
            logger.info(f"Creating container: {container_name}")
            blob_service.create_container(container_name)

        total = len(image_mappings)
        successful_uploads = 0
        failed_uploads = 0

        logger.info(f"Starting upload of {total} images to Azure Blob Storage")
        path_pairs = zip(image_mappings["local_path"], image_mappings["cloud_path"])
        for idx, (local_path, cloud_path) in enumerate(path_pairs):
            # Progress indicator
            if idx % 10 == 0:
                logger.info(f"Uploading {idx}/{total}...")

            # A failure on one file must not abort the remaining uploads.
            try:
                if not os.path.exists(local_path):
                    logger.warning(f"File not found: {local_path}")
                    failed_uploads += 1
                    continue

                # Derive the content type from the file extension; compressed
                # or unknown files fall back to a generic binary type.
                guessed_type, encoding = mimetypes.guess_type(local_path)
                if guessed_type and encoding is None:
                    content_type = guessed_type
                else:
                    content_type = "application/octet-stream"

                blob_client = blob_service.get_blob_client(
                    container=container_name,
                    blob=cloud_path
                )

                with open(local_path, "rb") as data:
                    blob_client.upload_blob(
                        data,
                        overwrite=True,
                        content_settings=ContentSettings(content_type=content_type)
                    )
                successful_uploads += 1
            except Exception as upload_error:
                logger.error(f"Error uploading {local_path} to {cloud_path}: {upload_error}")
                failed_uploads += 1

        # Report results
        if failed_uploads == 0:
            logger.info(f"Successfully uploaded all {successful_uploads} images to Azure Blob Storage")
            return True

        logger.warning(f"Completed with {successful_uploads} successful uploads and {failed_uploads} failures")
        return successful_uploads > 0  # Return True if at least one upload succeeded

    except Exception as e:
        logger.error(f"Error connecting to Azure Blob Storage: {e}")
        return False


def upload_to_s3(image_mappings, bucket_name):
    """Upload every mapped image to an AWS S3 bucket.

    Credentials are read from the AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY
    environment variables; AWS_SESSION_TOKEN is forwarded when set so that
    temporary credentials work. The region comes from AWS_REGION and defaults
    to "us-east-1". The bucket is created if a HEAD request reports 404.

    Args:
        image_mappings: Table-like object indexable by "local_path" and
            "cloud_path" (the script passes a pandas DataFrame). Each pair
            gives the file to read and the object key to write.
        bucket_name: Name of the target bucket.

    Returns:
        False if boto3 is missing, the credentials are not set, the bucket
        check is denied or fails, or the connection/bucket step raises.
        True if every upload succeeded. When some uploads failed, True only
        if at least one upload succeeded (partial success is deliberately
        reported as success).
    """
    import mimetypes

    try:
        import boto3
        from botocore.exceptions import ClientError
    except ImportError:
        logger.error("AWS SDK not installed. Run: pip install boto3")
        return False

    # Get AWS credentials from environment variables
    aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID")
    aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
    # Only present for temporary credentials; None keeps the long-lived key
    # behaviour unchanged.
    aws_session_token = os.environ.get("AWS_SESSION_TOKEN") or None
    aws_region = os.environ.get("AWS_REGION", "us-east-1")

    if not aws_access_key or not aws_secret_key:
        logger.error("Please set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables")
        return False

    try:
        logger.info(f"Connecting to AWS S3 in region {aws_region}")
        s3_client = boto3.client(
            's3',
            aws_access_key_id=aws_access_key,
            aws_secret_access_key=aws_secret_key,
            aws_session_token=aws_session_token,
            region_name=aws_region
        )

        # Check if bucket exists, create if it doesn't
        try:
            s3_client.head_bucket(Bucket=bucket_name)
            logger.info(f"Bucket {bucket_name} already exists")
        except ClientError as e:
            error_code = e.response.get('Error', {}).get('Code', '')
            if error_code == '404':  # Bucket not found
                logger.info(f"Creating bucket: {bucket_name}")
                if aws_region == 'us-east-1':
                    # For us-east-1, we don't specify a LocationConstraint
                    s3_client.create_bucket(Bucket=bucket_name)
                else:
                    # For other regions, we need to specify the LocationConstraint
                    s3_client.create_bucket(
                        Bucket=bucket_name,
                        CreateBucketConfiguration={
                            'LocationConstraint': aws_region
                        }
                    )
            elif error_code == '403':  # Access denied
                logger.error(f"Access denied to bucket {bucket_name}. Check your permissions.")
                return False
            else:
                logger.error(f"Error checking bucket: {e}")
                return False

        total = len(image_mappings)
        successful_uploads = 0
        failed_uploads = 0

        logger.info(f"Starting upload of {total} images to S3")
        path_pairs = zip(image_mappings["local_path"], image_mappings["cloud_path"])
        for idx, (local_path, cloud_path) in enumerate(path_pairs):
            # Progress indicator
            if idx % 10 == 0:
                logger.info(f"Uploading {idx}/{total}...")

            # A failure on one file must not abort the remaining uploads.
            try:
                if not os.path.exists(local_path):
                    logger.warning(f"File not found: {local_path}")
                    failed_uploads += 1
                    continue

                # Derive the content type from the file extension; compressed
                # or unknown files fall back to a generic binary type.
                guessed_type, encoding = mimetypes.guess_type(local_path)
                if guessed_type and encoding is None:
                    content_type = guessed_type
                else:
                    content_type = "application/octet-stream"

                s3_client.upload_file(
                    Filename=local_path,
                    Bucket=bucket_name,
                    Key=cloud_path,
                    ExtraArgs={
                        'ContentType': content_type
                    }
                )
                successful_uploads += 1
            except Exception as upload_error:
                logger.error(f"Error uploading {local_path} to {cloud_path}: {upload_error}")
                failed_uploads += 1

        # Report results
        if failed_uploads == 0:
            logger.info(f"Successfully uploaded all {successful_uploads} images to S3")
            return True

        logger.warning(f"Completed with {successful_uploads} successful uploads and {failed_uploads} failures")
        return successful_uploads > 0  # Return True if at least one upload succeeded

    except Exception as e:
        logger.error(f"Error connecting to AWS S3: {e}")
        return False


def save_cloud_info(provider, container_or_bucket, success=True):
    """Save information about the cloud storage for use by subsequent scripts.

    Writes "cloud_storage_info.json" into ``tb.config.data_dir``, replacing
    any existing file.

    SECURITY NOTE: the file contains the provider credentials read from the
    environment (storage key / AWS secret access key) in plain text. It is
    therefore created with owner-only permissions (0o600) where the platform
    supports them. Treat the file as a secret and do not commit or share it.

    Args:
        provider: "azure" or "s3"; selects which credential fields are added.
            Any other value stores only the common fields.
        container_or_bucket: Azure container name or S3 bucket name.
        success: Whether the upload is considered successful; stored as
            "upload_success".
    """
    import json

    info = {
        "provider": provider,
        "container_or_bucket": container_or_bucket,
        "upload_success": success,
        "upload_timestamp": datetime.now().isoformat()
    }

    # For Azure
    if provider == "azure":
        info["storage_account"] = os.environ.get("AZURE_STORAGE_ACCOUNT", "")
        info["storage_key"] = os.environ.get("AZURE_STORAGE_KEY", "")
    # For AWS
    elif provider == "s3":
        info["aws_access_key_id"] = os.environ.get("AWS_ACCESS_KEY_ID", "")
        info["aws_secret_access_key"] = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
        info["aws_region"] = os.environ.get("AWS_REGION", "us-east-1")

    info_path = tb.config.data_dir / "cloud_storage_info.json"

    # Create the file as owner-only so the secrets are never world-readable,
    # not even briefly between creation and a later chmod.
    fd = os.open(info_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w") as f:
        # The creation mode is ignored when the file already existed, so
        # tighten the permissions explicitly before any secret is written.
        try:
            os.chmod(info_path, 0o600)
        except OSError as chmod_error:
            logger.warning(f"Could not restrict permissions on {info_path}: {chmod_error}")
        json.dump(info, f, indent=2)

    print(f"Cloud storage information saved to {tb.config.data_dir / 'cloud_storage_info.json'}")


if __name__ == "__main__":
    # Script entry point: read the image mapping CSV from the data directory,
    # upload every listed image to the selected provider (retrying the whole
    # batch on failure), then record the outcome for subsequent scripts.
    # Exit status is 0 on success and 1 on any failure.
    parser = argparse.ArgumentParser(description="Upload COVID X-ray images to cloud storage")
    parser.add_argument("--provider", choices=["azure", "s3"], default="azure",
                       help="Cloud storage provider (azure or s3)")
    parser.add_argument("--container", default=os.environ.get("AZURE_CONTAINER_NAME", "covid-xrays"),
                       help="Azure container or S3 bucket name")
    parser.add_argument("--max-retries", type=int, default=3,
                       help="Maximum number of retries for failed operations")
    args = parser.parse_args()

    # Load image mappings
    mappings_path = tb.config.data_dir / "covid_image_mappings.csv"
    if not mappings_path.exists():
        logger.error(f"Image mappings file not found: {mappings_path}")
        logger.error("Please run 0_prepare_data.py first")
        sys.exit(1)

    try:
        image_mappings = pd.read_csv(mappings_path)
    except Exception as e:
        logger.error(f"Error reading image mappings file: {e}")
        sys.exit(1)

    # Fail fast on a malformed CSV: without this check the missing column
    # surfaces as a KeyError inside the upload function, is logged as a
    # connection error, and is retried pointlessly.
    missing_columns = [
        column for column in ("local_path", "cloud_path")
        if column not in image_mappings.columns
    ]
    if missing_columns:
        logger.error(f"Image mappings file is missing required column(s): {', '.join(missing_columns)}")
        sys.exit(1)

    # Check if we have images to upload
    if len(image_mappings) == 0:
        logger.error("No images found to upload")
        sys.exit(1)

    logger.info(f"Found {len(image_mappings)} images to upload to {args.provider}")

    # --max-retries is used as the total number of attempts; always make at
    # least one so that a value of 0 (or a negative number) still uploads.
    max_attempts = max(args.max_retries, 1)
    upload = upload_to_azure if args.provider == "azure" else upload_to_s3

    success = False
    for attempt in range(max_attempts):
        if attempt > 0:
            logger.info(f"Retry attempt {attempt} of {max_attempts}")
            time.sleep(2)  # Wait before retrying

        with Timer(f"Uploading to {args.provider}"):
            success = upload(image_mappings, args.container)

        if success:
            break

    if success:
        logger.info(f"Images successfully uploaded to {args.provider}")
        # Save cloud storage info for subsequent scripts
        save_cloud_info(args.provider, args.container, success=True)
        sys.exit(0)
    else:
        logger.error(f"Failed to upload images to {args.provider} after {max_attempts} attempts")
        save_cloud_info(args.provider, args.container, success=False)
        sys.exit(1)