#!/usr/bin/env python
# Copyright (c) TripleBlind Holdings, Inc. Confidential and Proprietary. All rights reserved.

import sys
import json
import os
import logging
from pathlib import Path

import numpy as np
import pandas as pd
import tripleblind as tb
from tripleblind.util.timer import Timer

# Script-wide logging: INFO and above, each record prefixed with a timestamp
# and its level name.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)


def load_asset_info():
    """
    Load the asset information written by 2_position_database_asset.py.

    Reads ``asset_info.json`` from ``tb.config.data_dir``. The process exits
    with status 1 if the file is missing or does not contain valid JSON,
    because nothing else in this script can proceed without it.

    Returns:
        The parsed JSON content (the main block reads the
        ``org2_train_asset_id`` / ``org2_test_asset_id`` /
        ``org3_train_asset_id`` / ``org3_test_asset_id`` keys from it).
    """
    info_path = tb.config.data_dir / "asset_info.json"
    if not info_path.exists():
        logger.error(f"Asset info not found: {info_path}")
        logger.error("Please run 2_position_database_asset.py first")
        sys.exit(1)

    try:
        with open(info_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except ValueError as e:
        # Covers json.JSONDecodeError and UnicodeDecodeError (both ValueError
        # subclasses): report a corrupt file instead of a raw traceback.
        logger.error(f"Asset info is not valid JSON ({info_path}): {e}")
        logger.error("Please run 2_position_database_asset.py again")
        sys.exit(1)


def ensure_cloud_storage_info():
    """
    Return the cloud storage configuration, creating it if necessary.

    If ``cloud_storage_info.json`` exists in ``tb.config.data_dir`` and holds a
    JSON object, that object is returned. Otherwise a configuration is built
    from the AZURE_STORAGE_ACCOUNT / AZURE_STORAGE_KEY / AZURE_CONTAINER_NAME
    environment variables and saved to that file. This lets the script run
    independently of the upload step.

    When the credentials are not set, a placeholder configuration is saved
    instead and flagged with ``"placeholder": True``. A saved placeholder is
    regenerated on a later call once real credentials are present in the
    environment, so setting them afterwards takes effect instead of being
    shadowed by the stale placeholder file.

    Returns:
        dict: keys ``provider``, ``storage_account``, ``storage_key``,
        ``container_or_bucket``, ``upload_success`` and ``upload_timestamp``.
        Placeholder configurations also carry ``placeholder``.
    """
    from datetime import datetime

    info_path = tb.config.data_dir / "cloud_storage_info.json"

    # Read the credentials once. Both the "reuse existing file" decision and
    # the configuration built below depend on them.
    storage_account = os.environ.get("AZURE_STORAGE_ACCOUNT")
    storage_key = os.environ.get("AZURE_STORAGE_KEY")
    have_env_credentials = bool(storage_account and storage_key)

    if info_path.exists():
        try:
            with open(info_path, "r", encoding="utf-8") as f:
                cloud_info = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.warning(f"Error reading existing cloud storage info: {e}")
            # Fall through to recreate the file
        else:
            if not isinstance(cloud_info, dict):
                logger.warning(
                    f"Existing cloud storage info is not a JSON object: {info_path}"
                )
                # Fall through to recreate the file
            elif have_env_credentials and (
                cloud_info.get("placeholder") is True
                # Placeholders saved by older versions lack the flag; they
                # are recognised by the dummy account name used below.
                or cloud_info.get("storage_account") == "demo-account"
            ):
                logger.info(
                    "Replacing placeholder cloud storage configuration with "
                    "Azure credentials from the environment"
                )
            else:
                logger.info(
                    f"Using existing cloud storage configuration from {info_path}"
                )
                return cloud_info

    # File doesn't exist, is invalid, or is a placeholder that can now be
    # replaced: create it from environment variables.
    logger.info("Creating cloud storage configuration from environment variables")

    # This example only supports Azure Blob Storage
    if have_env_credentials:
        cloud_info = {
            "provider": "azure",
            "storage_account": storage_account,
            "storage_key": storage_key,
            "container_or_bucket": os.environ.get(
                "AZURE_CONTAINER_NAME", "covid-xrays"
            ),
            "upload_success": True,  # Mark as successful since we're creating it now
            "upload_timestamp": datetime.now().isoformat(),
        }
        logger.info(
            f"Using Azure Blob Storage configuration: account={cloud_info['storage_account']}, container={cloud_info['container_or_bucket']}"
        )
    else:
        logger.warning("No Azure storage credentials found in environment variables")
        logger.warning(
            "To use cloud storage with images, set the following environment variables:"
        )
        logger.warning(
            "For Azure: AZURE_STORAGE_ACCOUNT, AZURE_STORAGE_KEY, AZURE_CONTAINER_NAME"
        )

        # Minimal placeholder configuration. The flag lets a later call
        # replace it once real credentials are available.
        cloud_info = {
            "provider": "azure",  # Default provider
            "storage_account": "demo-account",
            "storage_key": "demo-key",
            "container_or_bucket": "covid-xrays",
            "upload_success": True,  # Allow training to proceed
            "upload_timestamp": datetime.now().isoformat(),
            "placeholder": True,
        }
        logger.warning(
            "Using placeholder Azure configuration - this example will still work"
        )

    # Save the configuration (best effort: the in-memory config is returned
    # even if the file cannot be written).
    try:
        with open(info_path, "w", encoding="utf-8") as f:
            json.dump(cloud_info, f, indent=2)
        logger.info(f"Cloud storage configuration saved to {info_path}")
    except OSError as e:
        logger.error(f"Error saving cloud storage configuration: {e}")

    return cloud_info


if __name__ == "__main__":
    logger.info(
        "Starting Multimodal Cloud Vision model training with vertically partitioned PostgreSQL data..."
    )

    # Initialize TripleBlind as the model owner (example user 1)
    tb.initialize(api_token=tb.config.example_user1["token"])

    # Load the asset IDs and make sure the cloud storage configuration file
    # exists (its contents are not used directly below).
    asset_info = load_asset_info()
    ensure_cloud_storage_info()

    # Resolve every asset before doing any work. The test assets are only
    # needed after training, for the inference agreements. Without this check
    # a missing one would surface as an AttributeError after a full training run.
    required_assets = {
        "org2_train_asset_id": "Organization 2 training",
        "org2_test_asset_id": "Organization 2 test",
        "org3_train_asset_id": "Organization 3 training",
        "org3_test_asset_id": "Organization 3 test",
    }
    found_assets = {}
    for info_key, description in required_assets.items():
        asset = tb.Asset.find(asset_info[info_key])
        if not asset:
            logger.error(f"{description} asset not found: {asset_info[info_key]}")
            logger.error("Please run 2_position_database_asset.py again")
            sys.exit(1)
        found_assets[info_key] = asset

    # Organization 2: first set of tabular features plus the target label.
    # Organization 3: second set of tabular features plus the images.
    org2_train_asset = found_assets["org2_train_asset_id"]
    org2_test_asset = found_assets["org2_test_asset_id"]
    org3_train_asset = found_assets["org3_train_asset_id"]
    org3_test_asset = found_assets["org3_test_asset_id"]

    logger.info(f"Found Organization 2 training asset: {org2_train_asset.name}")
    logger.info(f"Found Organization 3 training asset: {org3_train_asset.name}")

    # Dimensions shared by the network layers, the log output and the job's
    # "data_shape" parameter. They are defined once so they cannot drift apart.
    org2_features = 5  # feature_0 .. feature_4
    org3_tabular_features = 5  # feature_5 .. feature_9
    image_side = 64  # images are resized to image_side x image_side
    image_features = image_side * image_side  # grayscale: one value per pixel
    org3_features = org3_tabular_features + image_features
    client_output = 8  # features emitted by each client branch
    server_input = 2 * client_output  # both client outputs combined
    num_classes = 3  # COVID, Normal, Viral_Pneumonia

    # Define the model architecture
    with Timer("Building model architecture"):
        # Tabular branch (Organization 2)
        org2_tabular_builder = tb.NetworkBuilder()
        org2_tabular_builder.add_dense_layer(org2_features, 16)
        org2_tabular_builder.add_relu()
        org2_tabular_builder.add_dropout(0.2)  # Lighter dropout
        org2_tabular_builder.add_dense_layer(16, client_output)
        org2_tabular_builder.add_relu()
        org2_tabular_builder.add_dropout(0.2)  # Lighter dropout
        org2_tabular_builder.add_split()  # Required split layer

        # Combined branch for Organization 3 (tabular features + image pixels)
        logger.info(
            f"Organization 3 input dimensions: {org3_tabular_features} tabular + {image_features} image = {org3_features} total features"
        )
        org3_combined_builder = tb.NetworkBuilder()
        # First layer must match the exact input dimension
        org3_combined_builder.add_dense_layer(org3_features, 128)
        org3_combined_builder.add_relu()
        org3_combined_builder.add_dropout(0.3)  # Moderate dropout
        org3_combined_builder.add_dense_layer(128, 32)  # Middle layer
        org3_combined_builder.add_relu()
        org3_combined_builder.add_dropout(0.3)  # Moderate dropout
        org3_combined_builder.add_dense_layer(32, client_output)  # Output features
        org3_combined_builder.add_relu()
        org3_combined_builder.add_split()  # Required split layer

        # Server model that combines both branches
        server_builder = tb.NetworkBuilder()
        server_builder.add_dense_layer(server_input, 32)
        server_builder.add_relu()
        server_builder.add_dropout(0.3)  # Moderate dropout
        server_builder.add_dense_layer(32, num_classes)

        logger.info(
            "Model architecture defined with two branches for vertically partitioned data (tabular + image)"
        )

        # Log expected dimensions for model
        logger.info("Expected model dimensions:")
        logger.info(
            f"  Organization 2: {org2_features} features -> {client_output} features output"
        )
        logger.info(
            f"  Organization 3: {org3_features} features -> {client_output} features output"
        )
        logger.info(
            f"  Server: {server_input} features input ({client_output} + {client_output}) -> {num_classes} classes output"
        )

        # Create vertical network with two client components
        model = tb.create_vertical_network(
            "covid_multimodal_net",
            server=server_builder,
            clients=[org2_tabular_builder, org3_combined_builder],
        )

    # Define preprocessors
    with Timer("Creating preprocessors"):
        # Tabular preprocessor for Organization 2
        logger.info("Setting up Organization 2 tabular preprocessor")
        org2_tabular_pre = tb.TabularPreprocessor.builder()

        feature_columns = [f"feature_{i}" for i in range(org2_features)]
        for col in feature_columns:
            org2_tabular_pre.add_column(col)
        logger.info(
            f"Added feature columns 0-{org2_features - 1} for Organization 2"
        )

        # Only Organization 2 provides the target for the model,
        # to avoid the "multi-target not supported" error
        org2_tabular_pre.add_column("label", target=True)
        logger.info("Organization 2 will provide the target label")

        # Features must be float32 to match the model weights; the target
        # stays int64 for classification.
        org2_tabular_pre.dtype("float32")
        org2_tabular_pre.target_dtype("int64")

        # image_id is deliberately not added: it is a string column
        logger.info(
            "Note: image_id is not included as it's a string and can't be converted to float"
        )
        logger.info(
            "Set feature dtype to float32 and target dtype to int64 for Organization 2"
        )

        # TabularImage preprocessor for Organization 3 (tabular features + image)
        logger.info(
            "Setting up Organization 3 TabularImage preprocessor for both tabular features and images"
        )
        org3_combined_pre = tb.TabularImagePreprocessor.builder()

        feature_columns_org3 = [
            f"feature_{i}"
            for i in range(org2_features, org2_features + org3_tabular_features)
        ]
        for col in feature_columns_org3:
            org3_combined_pre.add_column(col)

        # Image settings must agree with image_features used by the model
        org3_combined_pre.image_column("image_path")  # Column containing image paths
        org3_combined_pre.resize(image_side, image_side)
        org3_combined_pre.convert("L")  # Grayscale (1 channel)
        org3_combined_pre.channels_first(True)

        # Organization 3 deliberately provides no target column
        logger.info(
            f"Added feature columns {feature_columns_org3} and image processing for Organization 3"
        )
        logger.info("Organization 3 will NOT provide any target label")

        org3_combined_pre.dtype("float32")  # Must match model weights dtype
        logger.info("Set feature dtype to float32 for Organization 3")

        logger.info(
            "All preprocessors created for vertically partitioned data across organizations"
        )

    # Training parameters
    epochs = 5
    batch_size = 16
    test_size = 0.2

    # Loss function and optimizer
    loss_name = "CrossEntropyLoss"  # For multiclass classification
    optimizer_name = "Adam"
    optimizer_params = {"lr": 0.001, "weight_decay": 0.0001}

    # Create and submit the training job
    with Timer("Training model"):
        logger.info("Creating vertical training job with the following configuration:")
        logger.info(
            "  - Datasets: org2_train (with target), org3_train (tabular features + images)"
        )
        logger.info(
            "  - Preprocessors: org2_tabular (provides target), org3_combined (tabular + images)"
        )
        logger.info(
            f"  - Epochs: {epochs}, Batch Size: {batch_size}, Test Size: {test_size}"
        )
        logger.info(f"  - Loss: {loss_name}, Optimizer: {optimizer_name}")
        logger.info(
            f"  - Data shape: [{org2_features}, {org3_features}] (split between organizations: {org2_features} for org2, {org3_features} for org3)"
        )
        logger.info(f"  - Model output: multiclass ({num_classes} classes)")

        # One asset and one explicitly configured preprocessor per organization,
        # in the same order as the vertical network's clients.
        job = tb.create_job(
            job_name="COVID Multimodal Classification with Vertically Partitioned PostgreSQL",
            operation=model,
            dataset=[org2_train_asset, org3_train_asset],
            preprocessor=[org2_tabular_pre, org3_combined_pre],
            params={
                "epochs": epochs,
                "batchsize": batch_size,
                "test_size": test_size,
                "loss_meta": {"name": loss_name},
                "optimizer_meta": {"name": optimizer_name, "params": optimizer_params},
                "lr_scheduler_meta": {
                    "name": "CyclicCosineDecayLR",
                    "params": {"init_decay_epochs": 5, "min_decay_lr": 0.0001},
                },
                "data_type": "multimodal",
                # [org2_features, org3_features (tabular + image)]
                "data_shape": [org2_features, org3_features],
                "model_output": "multiclass",
                "dtype": "float32",
            },
        )

        if not job.submit():
            logger.error("Failed to submit training job!")
            sys.exit(1)

        logger.info("Job submitted successfully! Waiting for completion...")
        job.wait_for_completion()

        if not job.success:
            logger.error("Model training failed!")
            sys.exit(1)

        trained_asset = job.result.asset
        logger.info("Training completed successfully!")
        logger.info("")
        logger.info("Trained Network Asset ID:")
        logger.info("    ===============================================")
        logger.info(f"    ===>  {trained_asset.uuid} <===")
        logger.info("    ===============================================")
        logger.info("    Algorithm: Deep Learning Model")
        logger.info(f"    Job ID:    {job.job_name}")
        logger.info("")

        # Save the result for use in inference
        with open(
            tb.config.data_dir / "trained_model_info.json", "w", encoding="utf-8"
        ) as f:
            json.dump(
                {
                    "model_asset_id": str(trained_asset.uuid),
                    "training_job_id": str(job.id),
                },
                f,
            )

        # Each data owner grants the model owner's team permission to run the
        # trained model against its test asset.
        session_org_2 = tb.Session(
            api_token=tb.config.example_user2["token"], from_default=True
        )  # Organization 2 - first data owner
        session_org_3 = tb.Session(
            api_token=tb.config.example_user3["token"], from_default=True
        )  # Organization 3 - second data owner

        agreement_org2 = org2_test_asset.add_agreement(
            with_team=tb.config.example_user1["team_id"],
            operation=trained_asset,
            session=session_org_2,
        )
        agreement_org3 = org3_test_asset.add_agreement(
            with_team=tb.config.example_user1["team_id"],
            operation=trained_asset,
            session=session_org_3,
        )

        if agreement_org2 and agreement_org3:
            logger.info(
                "Created agreements for both organizations to use the trained model for inference"
            )
        else:
            logger.warning(
                "Could not create inference agreements for both organizations; "
                "inference on the test assets may be refused"
            )
