#!/usr/bin/env python3
"""
Generic script to unroll LSTM layers in any Keras H5 model

This script:
- Finds all LSTM and Bidirectional LSTM layers
- Sets unroll=True on each LSTM layer
- Preserves all weights exactly
- Preserves all custom H5 attributes (Imagimob, etc.)
- Validates the conversion

Note: Only LSTM and Bidirectional LSTM are supported in TFLite Micro;
      GRU and SimpleRNN layers are detected but NOT unrolled.

Keras 3 Compatibility:
      When using --keras-version 3, the script automatically strips the
      time_major attribute, which Keras 3 does not support.

Requirements:
    pip install -r unroll_lstm_requirements.txt

Usage:
    python unroll_lstm_layers.py input_model.h5 output_model.h5
    python unroll_lstm_layers.py input_model.h5  # Adds -unrolled suffix
    python unroll_lstm_layers.py model.h5 --keras-version 3  # Use Keras 3

Options:
    --max-timesteps N    Only unroll if timesteps <= N (default: 20)
    --force              Unroll regardless of timestep count
    --dry-run            Show what would be changed without saving
    --keras-version 2|3  Force Keras version (default: 2, use 3 for Keras 3 models)
"""

import os
import sys
import argparse
import json
import h5py
import numpy as np
from pathlib import Path

# Keras version handling
def setup_keras(keras_version=None):
    """Select the Keras implementation via the environment, then import TensorFlow.

    TensorFlow decides between legacy Keras 2 and Keras 3 from the
    ``TF_USE_LEGACY_KERAS`` environment variable, so this must run before
    TensorFlow is imported anywhere else in the process.

    Args:
        keras_version: 2 or None to request legacy Keras 2
            (``TF_USE_LEGACY_KERAS=1``); 3 to request Keras 3.

    Returns:
        ``(tf, keras)``: the ``tensorflow`` module and ``tensorflow.keras``.
    """
    if keras_version == 2 or keras_version is None:
        os.environ['TF_USE_LEGACY_KERAS'] = '1'
    elif keras_version == 3:
        # A TF_USE_LEGACY_KERAS=1 inherited from the user's shell would
        # otherwise silently keep Keras 2 despite --keras-version 3.
        os.environ.pop('TF_USE_LEGACY_KERAS', None)

    import tensorflow as tf
    from tensorflow import keras

    return tf, keras

def load_model_with_keras3_compat(model_path, keras_version):
    """Load an H5 Keras model, stripping ``time_major`` first when targeting Keras 3.

    For ``keras_version == 3`` the file is copied to a temporary path. Every
    ``time_major`` key is removed from the copy's ``model_config`` JSON,
    including keys inside Bidirectional ``layer``/``backward_layer`` configs
    and inside nested sub-model layer lists. The copy is then loaded. The
    input file itself is never modified. Any other version loads the file
    directly.

    Args:
        model_path: Path to the ``.h5`` model file.
        keras_version: 3 to apply the ``time_major`` cleanup; anything else
            loads as-is.

    Returns:
        The loaded model, uncompiled (``compile=False``).
    """
    import shutil
    import tempfile

    import h5py
    from tensorflow import keras

    if keras_version != 3:
        # Keras 2 - load normally
        return keras.models.load_model(model_path, compile=False)

    def strip_time_major(node):
        # Walk the whole JSON tree so nested layer configs are cleaned too.
        if isinstance(node, dict):
            node.pop('time_major', None)
            for value in node.values():
                strip_time_major(value)
        elif isinstance(node, list):
            for item in node:
                strip_time_major(item)

    # delete=False + closing the handle lets h5py/Keras reopen the path
    # (required on Windows, where an open temp file cannot be reopened).
    with tempfile.NamedTemporaryFile(suffix='.h5', delete=False) as tmp_file:
        temp_path = tmp_file.name

    try:
        shutil.copy2(model_path, temp_path)

        with h5py.File(temp_path, 'r+') as f:
            if 'model_config' in f.attrs:
                model_config_str = f.attrs['model_config']
                if isinstance(model_config_str, bytes):
                    model_config_str = model_config_str.decode('utf-8')

                model_config = json.loads(model_config_str)
                strip_time_major(model_config)

                del f.attrs['model_config']
                f.attrs['model_config'] = json.dumps(model_config)

        return keras.models.load_model(temp_path, compile=False)
    finally:
        # Best-effort cleanup; a leftover temp file must not mask the result.
        try:
            os.unlink(temp_path)
        except OSError:
            pass

def _strip_compat_suffix(type_name):
    """Drop a trailing ``_Keras3Compat`` from a layer class name, if present."""
    suffix = '_Keras3Compat'
    return type_name[:-len(suffix)] if type_name.endswith(suffix) else type_name


def _layer_input_shape(layer):
    """Return the shape of a layer's first input, or ``(None, None, None)`` if unknown.

    Keras 2 layers expose ``input_shape``; when that attribute is missing
    (Keras 3), fall back to ``layer.input.shape``. Multi-input layers (a list
    of shapes or a list of input tensors) use the first entry, which for an
    RNN is the sequence input.
    """
    unknown = (None, None, None)
    try:
        shape = layer.input_shape
    except AttributeError:
        try:
            layer_input = layer.input
        except AttributeError:
            return unknown
        if isinstance(layer_input, (list, tuple)):
            layer_input = layer_input[0] if layer_input else None
        shape = getattr(layer_input, 'shape', None)
        if shape is None:
            return unknown
    if isinstance(shape, list):
        shape = shape[0] if shape else unknown
    return shape


def _wrapped_rnn_layer(bidirectional):
    """Return the forward RNN held by a Bidirectional wrapper.

    Tries ``.layer`` (Keras 2), then the public ``.forward_layer`` (Keras 2
    and 3), and only then the private ``._layers[0]``.
    """
    for attr in ('layer', 'forward_layer'):
        inner = getattr(bidirectional, attr, None)
        if inner is not None:
            return inner
    return bidirectional._layers[0]


def analyze_rnn_layers(model):
    """Collect information about every recurrent layer in ``model``.

    A layer counts as recurrent when its class name (minus any
    ``_Keras3Compat`` suffix) contains LSTM, GRU, SimpleRNN or Bidirectional.
    This is a substring match, so e.g. ConvLSTM2D is reported too; deciding
    what to unroll is left to the caller.

    Args:
        model: A Keras model; only ``model.layers`` is iterated.

    Returns:
        A list of dicts with keys ``index``, ``name``, ``type``,
        ``wrapped_type``, ``current_unroll``, ``timesteps``, ``input_shape``,
        and ``config``. For Bidirectional layers, ``wrapped_type`` and
        ``current_unroll`` describe the wrapped (forward) RNN; otherwise
        ``wrapped_type`` is None. ``timesteps`` is ``input_shape[1]`` (None
        if unknown).
    """
    # Note: Only LSTM and Bidirectional LSTM are supported in TFLite Micro
    # GRU and SimpleRNN are detected but will not be unrolled
    rnn_types = ('LSTM', 'GRU', 'SimpleRNN', 'Bidirectional')
    rnn_layers = []

    for i, layer in enumerate(model.layers):
        layer_type = _strip_compat_suffix(type(layer).__name__)
        if not any(rnn in layer_type for rnn in rnn_types):
            continue

        config = layer.get_config()
        input_shape = _layer_input_shape(layer)
        timesteps = input_shape[1] if len(input_shape) > 1 else None

        if layer_type == 'Bidirectional':
            wrapped_layer = _wrapped_rnn_layer(layer)
            current_unroll = wrapped_layer.get_config().get('unroll', False)
            # Strip the suffix here as well so a wrapped LSTM_Keras3Compat is
            # recognized as 'LSTM' downstream.
            wrapped_type = _strip_compat_suffix(type(wrapped_layer).__name__)
        else:
            current_unroll = config.get('unroll', False)
            wrapped_type = None

        rnn_layers.append({
            'index': i,
            'name': layer.name,
            'type': layer_type,
            'wrapped_type': wrapped_type,
            'current_unroll': current_unroll,
            'timesteps': timesteps,
            'input_shape': input_shape,
            'config': config,
        })

    return rnn_layers

def should_unroll(layer_info, max_timesteps, force):
    """Decide whether one analyzed RNN layer should receive ``unroll=True``.

    Checks are applied in this order:

    1. Only plain LSTM, or Bidirectional wrapping an LSTM, is eligible
       (TFLite Micro supports nothing else).
    2. ``force`` accepts any eligible layer.
    3. A layer that is already unrolled is skipped.
    4. An unknown (None) timestep count is skipped.
    5. Otherwise the layer is accepted iff ``timesteps <= max_timesteps``.

    Returns:
        ``(decision, reason)``: a bool and a human-readable explanation.
    """
    kind = layer_info['type']
    eligible = kind == 'LSTM' or (
        kind == 'Bidirectional' and layer_info.get('wrapped_type') == 'LSTM'
    )

    if not eligible:
        return False, f"{kind} not supported in TFLite Micro (only LSTM)"
    if force:
        return True, "forced by user"
    if layer_info['current_unroll']:
        return False, "already unrolled"

    steps = layer_info['timesteps']
    if steps is None:
        return False, "variable timesteps (None) - cannot unroll"

    within_limit = bool(steps <= max_timesteps)
    relation = "<=" if within_limit else ">"
    return within_limit, f"{steps} timesteps {relation} {max_timesteps} limit"

def clean_config_for_keras3(config, keras_version):
    """Remove Keras 2-only ``time_major`` keys from a layer config for Keras 3.

    When ``keras_version == 3``, ``time_major`` is removed from ``config`` and
    from every dict nested inside it. This covers both a Bidirectional
    wrapper's ``'layer'`` and ``'backward_layer'`` sub-configs, and layers
    nested in sub-model configs. For any other version, ``config`` is
    returned untouched.

    Args:
        config: A layer config dict (as from ``layer.get_config()``). It is
            modified in place.
        keras_version: Target Keras major version.

    Returns:
        The same ``config`` object.
    """
    if keras_version != 3:
        return config

    # Iterative walk (no recursion limit) over dicts/lists/tuples.
    pending = [config]
    while pending:
        node = pending.pop()
        if isinstance(node, dict):
            node.pop('time_major', None)
            pending.extend(node.values())
        elif isinstance(node, (list, tuple)):
            pending.extend(node)

    return config

def recreate_model_with_unrolled_layers(original_model, layers_to_unroll, keras_version=2):
    """Rebuild a Sequential model with the selected LSTM layers set to ``unroll=True``.

    Every layer is re-instantiated from its own config with
    ``type(layer).from_config``, so the new model gets fresh, unweighted
    layers. Weights are not copied here.

    Args:
        original_model: The loaded model. Only ``Sequential`` is supported.
        layers_to_unroll: Layer-info dicts; only their ``'index'`` key is read.
        keras_version: 2 or 3; passed to ``clean_config_for_keras3``.

    Returns:
        The new Sequential model, built for ``original_model.input_shape``.

    Raises:
        ValueError: If ``original_model`` is not a Sequential model.
    """
    # Check before instantiating: calling a Functional model's class with
    # only name= would raise a confusing TypeError instead of this message.
    if type(original_model).__name__ != 'Sequential':
        raise ValueError("Only Sequential models are supported. Functional API models require manual conversion.")

    new_model = type(original_model)(name=original_model.name)
    unroll_indices = {info['index'] for info in layers_to_unroll}
    compat_suffix = '_Keras3Compat'

    for i, layer in enumerate(original_model.layers):
        layer_config = layer.get_config()
        layer_type = type(layer).__name__
        if layer_type.endswith(compat_suffix):
            layer_type = layer_type[:-len(compat_suffix)]

        should_modify = i in unroll_indices

        # Remove batch-related config; the model is built explicitly below.
        layer_config.pop('batch_input_shape', None)
        layer_config.pop('batch_size', None)

        # Clean config for Keras 3 compatibility
        layer_config = clean_config_for_keras3(layer_config, keras_version)

        if should_modify and layer_type == 'LSTM':
            layer_config['unroll'] = True
            new_model.add(type(layer).from_config(layer_config))

        elif should_modify and layer_type == 'Bidirectional':
            layer_config['layer']['config']['unroll'] = True
            # Keras 3 (and Keras 2 with an explicit backward layer) serializes
            # the backward RNN separately; without this it stays rolled.
            backward = layer_config.get('backward_layer')
            if isinstance(backward, dict) and isinstance(backward.get('config'), dict):
                backward['config']['unroll'] = True
            new_model.add(type(layer).from_config(layer_config))

        else:
            # GRU/SimpleRNN and non-RNN layers are copied unchanged
            # (GRU/SimpleRNN are not supported in TFLite Micro).
            try:
                new_model.add(type(layer).from_config(layer_config))
            except Exception as e:
                # Fall back to sharing the original layer instance, but say so.
                print(f"Warning: could not recreate layer '{layer.name}' from config ({e}); reusing original layer")
                new_model.add(layer)

    input_shape = original_model.input_shape[1:]
    new_model.build(input_shape=(None,) + input_shape)

    return new_model

def copy_weights(original_model, new_model):
    """Transfer weights layer-by-layer from ``original_model`` into ``new_model``.

    Layers are paired by position with ``zip``, so both models must list
    their layers in the same order; surplus layers on either side are
    ignored. Layers with no weights are skipped and not counted.

    Returns:
        ``(copied, failed)``: ``copied`` is the number of layers whose weights
        were set; ``failed`` is a list of ``(layer_name, error_message)``
        tuples for layers where ``get_weights``/``set_weights`` raised.
    """
    copied = 0
    failures = []

    for src, dst in zip(original_model.layers, new_model.layers):
        try:
            layer_weights = src.get_weights()
            if not layer_weights:
                continue
            dst.set_weights(layer_weights)
        except Exception as err:
            failures.append((src.name, str(err)))
        else:
            copied += 1

    return copied, failures

def extract_h5_attributes(h5_path):
    """Read every root-level attribute of an HDF5 file into a plain dict.

    All root attributes are returned, not only third-party metadata; this
    includes ``model_config`` when the file is a Keras model.

    Args:
        h5_path: Path of the HDF5 file (opened read-only).

    Returns:
        A dict mapping attribute name to its value as read by h5py.
    """
    with h5py.File(h5_path, 'r') as h5:
        return {name: h5.attrs[name] for name in h5.attrs.keys()}

def restore_h5_attributes(h5_path, attributes):
    """Write ``attributes`` onto the root of an existing HDF5 file.

    Any attribute that already exists under the same name is deleted before
    the new value is written. A failure on one attribute is recorded and
    does not stop the rest.

    Args:
        h5_path: Path of the HDF5 file (opened read/write).
        attributes: Mapping of attribute name to value.

    Returns:
        ``(restored_count, failed)``: the number of attributes written and a
        list of ``(name, error_message)`` tuples for those that failed.
    """
    written = 0
    errors = []

    with h5py.File(h5_path, 'r+') as h5:
        for name, value in attributes.items():
            try:
                if name in h5.attrs:
                    del h5.attrs[name]
                h5.attrs[name] = value
            except Exception as exc:
                errors.append((name, str(exc)))
            else:
                written += 1

    return written, errors

def fix_unroll_in_model_config(h5_path, layer_names):
    """Set ``unroll=True`` in a saved model's ``model_config`` JSON for the named layers.

    Only top-level entries of ``model_config['config']['layers']`` are
    examined.
      - ``LSTM`` entries get ``config.unroll = True``.
      - ``Bidirectional`` entries get it on the wrapped ``layer`` config, and
        also on ``backward_layer`` when one is serialized.
      - Other classes are left alone.

    Args:
        h5_path: Path of the saved ``.h5`` model (opened read/write).
        layer_names: Names of the layers to patch.

    Returns:
        The number of layers updated (a Bidirectional layer counts once).

    Raises:
        KeyError: If the file has no ``model_config`` attribute or the JSON
            lacks ``config.layers``.
    """
    target_names = set(layer_names)

    with h5py.File(h5_path, 'r+') as f:
        model_config_str = f.attrs['model_config']
        if isinstance(model_config_str, bytes):
            model_config_str = model_config_str.decode('utf-8')

        model_config = json.loads(model_config_str)
        updated_count = 0

        for layer in model_config['config']['layers']:
            layer_cfg = layer.get('config', {})
            if layer_cfg.get('name', '') not in target_names:
                continue

            class_name = layer.get('class_name')
            # Only LSTM is supported in TFLite Micro
            if class_name == 'LSTM':
                layer_cfg['unroll'] = True
                updated_count += 1
            elif class_name == 'Bidirectional':
                wrapped = layer_cfg.get('layer')
                if isinstance(wrapped, dict) and isinstance(wrapped.get('config'), dict):
                    wrapped['config']['unroll'] = True
                    # The backward RNN is rebuilt from its own config on load,
                    # so it must be patched too or it reloads rolled.
                    backward = layer_cfg.get('backward_layer')
                    if isinstance(backward, dict) and isinstance(backward.get('config'), dict):
                        backward['config']['unroll'] = True
                    updated_count += 1

        del f.attrs['model_config']
        f.attrs['model_config'] = json.dumps(model_config)

        return updated_count

def validate_conversion(original_model, new_model, num_samples=10):
    """Compare the two models' predictions on one shared random batch.

    Draws ``num_samples`` standard-normal float32 samples shaped like
    ``original_model.input_shape`` (batch dimension excluded). Both models
    run on that batch, original first, and the element-wise absolute
    difference of their outputs is measured.

    Returns:
        ``(max_diff, mean_diff)`` of the absolute output difference.
    """
    batch_shape = (num_samples,) + tuple(original_model.input_shape[1:])
    batch = np.random.randn(*batch_shape).astype(np.float32)

    reference = original_model.predict(batch, verbose=0)
    candidate = new_model.predict(batch, verbose=0)
    abs_delta = np.abs(reference - candidate)

    return np.max(abs_delta), np.mean(abs_delta)

def main():
    """Command-line entry point: analyze, unroll, save and verify a model.

    Steps:
      1. Load the model.
      2. Choose the LSTM layers to unroll.
      3. Collect the input file's custom HDF5 attributes.
      4. Rebuild the model with ``unroll=True`` on the chosen layers.
      5. Copy weights.
      6. Save.
      7. Re-attach the custom attributes.
      8. Patch ``model_config``.
      9. Reload the output and compare it with the original.

    Returns:
        Process exit code. 0 on success, when nothing needs unrolling, or on
        a dry run. 1 on load/rebuild/save/validation errors, or when a layer
        is not unrolled in the reloaded output.
    """
    parser = argparse.ArgumentParser(
        description='Unroll LSTM layers in Keras H5 models for TFLite Micro',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  %(prog)s model.h5 model_unrolled.h5
  %(prog)s model.h5 --max-timesteps 10
  %(prog)s model.h5 --force
  %(prog)s model.h5 --dry-run
        """
    )

    parser.add_argument('input_model', help='Input H5 model file')
    parser.add_argument('output_model', nargs='?', help='Output H5 model file (default: input-unrolled.h5)')
    parser.add_argument('--max-timesteps', type=int, default=20,
                       help='Only unroll if timesteps <= N (default: 20)')
    parser.add_argument('--force', action='store_true',
                       help='Unroll regardless of timestep count')
    parser.add_argument('--dry-run', action='store_true',
                       help='Show what would be changed without saving')
    parser.add_argument('--keras-version', type=int, choices=[2, 3], default=2,
                       help='Keras version to use (default: 2). Use 3 for Keras 3 models (handles time_major)')

    args = parser.parse_args()

    # Setup paths
    input_path = Path(args.input_model)
    if not input_path.exists():
        print(f"Error: Input file not found: {input_path}")
        return 1

    if args.output_model:
        output_path = Path(args.output_model)
    else:
        output_path = input_path.parent / f"{input_path.stem}-unrolled{input_path.suffix}"

    print("="*80)
    print("LSTM Layer Unroller for TFLite Micro")
    print("="*80)
    print(f"\nInput:  {input_path}")
    print(f"Output: {output_path}")
    print(f"Max timesteps: {args.max_timesteps}")
    print(f"Force: {args.force}")
    print(f"Dry run: {args.dry_run}")

    # Setup Keras
    print("\n" + "-"*80)
    print("Step 1: Loading model")
    print("-"*80)

    tf, keras = setup_keras(args.keras_version)
    print(f"TensorFlow: {tf.__version__}")

    try:
        model = load_model_with_keras3_compat(str(input_path), args.keras_version)
        print(f"Model loaded: {model.name}")
        if args.keras_version == 3:
            print("  (Keras 3 mode: time_major attribute handled)")
        print(f"  Input shape: {model.input_shape}")
        print(f"  Output shape: {model.output_shape}")
        print(f"  Total layers: {len(model.layers)}")
    except Exception as e:
        print(f"Error loading model: {e}")
        return 1

    # Analyze RNN layers
    print("\n" + "-"*80)
    print("Step 2: Analyzing RNN layers")
    print("-"*80)

    rnn_layers = analyze_rnn_layers(model)

    if not rnn_layers:
        print("No RNN layers found in model")
        return 0

    print(f"\nFound {len(rnn_layers)} RNN layer(s):")

    layers_to_unroll = []

    for layer_info in rnn_layers:
        should_unroll_flag, reason = should_unroll(layer_info, args.max_timesteps, args.force)

        status = "[UNROLL]" if should_unroll_flag else "[SKIP]"
        print(f"\n  {status} Layer {layer_info['index']}: {layer_info['name']}")
        print(f"    Type: {layer_info['type']}")
        if layer_info['wrapped_type']:
            print(f"    Wrapped: {layer_info['wrapped_type']}")
        print(f"    Current unroll: {layer_info['current_unroll']}")
        print(f"    Timesteps: {layer_info['timesteps']}")
        print(f"    Input shape: {layer_info['input_shape']}")
        print(f"    Decision: {reason}")

        if should_unroll_flag:
            layers_to_unroll.append(layer_info)

    if not layers_to_unroll:
        print("\nNo layers need to be unrolled")
        return 0

    print(f"\nWill unroll {len(layers_to_unroll)} layer(s)")

    if args.dry_run:
        print("\n" + "="*80)
        print("DRY RUN - No changes made")
        print("="*80)
        return 0

    # Extract custom attributes
    print("\n" + "-"*80)
    print("Step 3: Extracting custom attributes")
    print("-"*80)

    all_attrs = extract_h5_attributes(str(input_path))
    # Keras's H5 saver writes these itself for the NEW model. Restoring the
    # input file's copies would overwrite the new architecture: unroll=True
    # would be lost, and in Keras 3 mode the unsupported time_major key would
    # be written back into the output. Only the remaining (custom)
    # attributes are carried over.
    keras_managed_attrs = {'model_config', 'keras_version', 'backend'}
    custom_attrs = {name: value for name, value in all_attrs.items()
                    if name not in keras_managed_attrs}
    print(f"Extracted {len(custom_attrs)} custom attributes")

    # [IMAGIMOB SUPPORT] Imagimob-specific attributes are carried over together
    # with every other custom attribute. These include: imedge_config,
    # immodel_config, imunit_config, infineon_config, syntiant_config,
    # Test_tf_cm, Train_tf_cm, Validation_tf_cm, mapping, imcm_list
    imagimob_attrs = [k for k in custom_attrs
                      if any(x in k.lower() for x in ['imedge', 'immodel', 'imunit', 'syntiant', 'infineon'])]
    if imagimob_attrs:
        print(f"  Including {len(imagimob_attrs)} Imagimob attribute(s)")

    # Recreate model
    print("\n" + "-"*80)
    print("Step 4: Recreating model with unrolled layers")
    print("-"*80)

    try:
        new_model = recreate_model_with_unrolled_layers(model, layers_to_unroll, args.keras_version)
        print("New model created")
        if args.keras_version == 3:
            print("  (Using Keras 3 compatibility mode - time_major removed)")
    except Exception as e:
        print(f"Error creating new model: {e}")
        import traceback
        traceback.print_exc()
        return 1

    # Copy weights
    print("\n" + "-"*80)
    print("Step 5: Copying weights")
    print("-"*80)

    weights_copied, weights_failed = copy_weights(model, new_model)
    print(f"Copied weights for {weights_copied} layers")

    if weights_failed:
        print(f"Warning: Failed to copy {len(weights_failed)} layers:")
        for layer_name, error in weights_failed:
            print(f"  • {layer_name}: {error}")

    # Save model
    print("\n" + "-"*80)
    print("Step 6: Saving model")
    print("-"*80)

    try:
        new_model.save(str(output_path))
        print(f"Model saved to: {output_path}")
    except Exception as e:
        print(f"Error saving model: {e}")
        return 1

    # [IMAGIMOB SUPPORT] Restore all custom attributes (including Imagimob configs)
    print("\n" + "-"*80)
    print("Step 7: Restoring custom attributes")
    print("-"*80)

    restored_count, restore_failed = restore_h5_attributes(str(output_path), custom_attrs)
    print(f"Restored {restored_count} custom attributes")

    if restore_failed:
        print(f"Warning: Failed to restore {len(restore_failed)} attributes:")
        for attr_name, error in restore_failed[:5]:
            print(f"  • {attr_name}: {error}")
    # End Imagimob attribute restoration

    # Ensure unroll=True is present in the saved model_config so it
    # persists on reload.
    print("\n" + "-"*80)
    print("Step 8: Fixing model_config JSON")
    print("-"*80)

    layer_names = [info['name'] for info in layers_to_unroll]
    fixed_count = fix_unroll_in_model_config(str(output_path), layer_names)
    print(f"Updated {fixed_count} layer(s) in model_config")

    # Validate
    print("\n" + "-"*80)
    print("Step 9: Validation")
    print("-"*80)

    unroll_failures = 0
    try:
        verification_model = load_model_with_keras3_compat(str(output_path), args.keras_version)
        print("Output model loads successfully")

        # Check unroll status
        print("\nVerifying unroll status:")
        for layer_info in layers_to_unroll:
            layer = verification_model.layers[layer_info['index']]
            config = layer.get_config()

            if type(layer).__name__ == 'Bidirectional':
                unroll = config['layer']['config'].get('unroll', False)
                # A separately serialized backward RNN must be unrolled too.
                backward = config.get('backward_layer')
                if isinstance(backward, dict) and isinstance(backward.get('config'), dict):
                    unroll = unroll and backward['config'].get('unroll', False)
            else:
                unroll = config.get('unroll', False)

            if not unroll:
                unroll_failures += 1
            status = "OK" if unroll else "FAILED"
            print(f"  [{status}] {layer.name}: unroll={unroll}")

        # Validate outputs
        print("\nValidating inference:")
        max_diff, mean_diff = validate_conversion(model, verification_model)
        print(f"  Max difference: {max_diff:.10f}")
        print(f"  Mean difference: {mean_diff:.10f}")

        if max_diff < 1e-5:
            print("  Outputs match perfectly!")
        elif max_diff < 1e-3:
            print("  Warning: Small differences detected (acceptable)")
        else:
            print("  Warning: Significant differences detected!")

    except Exception as e:
        print(f"Validation error: {e}")
        return 1

    if unroll_failures:
        # Do not report SUCCESS (exit 0) when the output is not actually unrolled.
        print(f"\nError: {unroll_failures} layer(s) are not unrolled in the saved model")
        return 1

    # Summary
    print("\n" + "="*80)
    print("SUCCESS")
    print("="*80)
    print(f"\nUnrolled {len(layers_to_unroll)} RNN layer(s)")
    print(f"Preserved {restored_count} custom attributes")
    print(f"Output: {output_path}")
    print("\nThe model is ready to use!")

    return 0

if __name__ == '__main__':
    # main() returns the process exit code; SystemExit(code) is what sys.exit raises.
    raise SystemExit(main())
