From 087b9b24d17a36f4528f2cfcdef5acbbac88ca07 Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Mon, 1 Sep 2025 19:11:56 +0530 Subject: [PATCH 01/19] Update backbone.py --- keras_hub/src/models/backbone.py | 99 ++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) diff --git a/keras_hub/src/models/backbone.py b/keras_hub/src/models/backbone.py index 55aaec239d..51b36518f2 100644 --- a/keras_hub/src/models/backbone.py +++ b/keras_hub/src/models/backbone.py @@ -293,3 +293,102 @@ def export_to_transformers(self, path): ) export_backbone(self, path) + + def _get_save_spec(self, dynamic_batch=True): + """Compatibility shim for Keras/TensorFlow saving utilities. + + TensorFlow's SavedModel / TFLite export paths expect a + `_get_save_spec` method on subclassed models. In some runtime + combinations this method may not be present on the MRO for + our `Backbone` subclass; add a small shim that first delegates to + the superclass, and falls back to constructing simple + `tf.TensorSpec` objects from the functional `inputs` if needed. + + Args: + dynamic_batch: whether to set the batch dimension to `None`. + + Returns: + A TensorSpec, list or dict mirroring the model inputs, or + `None` when specs cannot be inferred. + """ + # Prefer the base implementation if available. + try: + return super()._get_save_spec(dynamic_batch) + except AttributeError: + # Fall back to building specs from `self.inputs`. + try: + from tensorflow.python.framework import tensor_spec + except (ImportError, ModuleNotFoundError): + return None + + inputs = getattr(self, "inputs", None) + if inputs is None: + return None + + def _make_spec(t): + # t is a tf.Tensor-like object + shape = list(t.shape) + if dynamic_batch and len(shape) > 0: + shape[0] = None + # Convert to tuple for TensorSpec + try: + name = getattr(t, "name", None) + return tensor_spec.TensorSpec( + shape=tuple(shape), dtype=t.dtype, name=name + ) + except (ImportError, ModuleNotFoundError): + return None + + # Handle dict/list/single tensor inputs + if isinstance(inputs, dict): + return {k: _make_spec(v) for k, v in inputs.items()} + if isinstance(inputs, (list, tuple)): + return [_make_spec(t) for t in inputs] + return _make_spec(inputs) + + def _trackable_children(self, save_type=None, **kwargs): + """Override to prevent _DictWrapper issues during TensorFlow export. + + This method filters out problematic _DictWrapper objects that cause + TypeError during SavedModel introspection, while preserving all + essential trackable components. 
+ """ + children = super()._trackable_children(save_type, **kwargs) + + # Import _DictWrapper safely + try: + from tensorflow.python.trackable.data_structures import _DictWrapper + except ImportError: + return children + + clean_children = {} + for name, child in children.items(): + # Handle _DictWrapper objects + if isinstance(child, _DictWrapper): + try: + # For list-like _DictWrapper (e.g., transformer_layers) + if hasattr(child, '_data') and isinstance(child._data, list): + # Create a clean list of the trackable items + clean_list = [] + for item in child._data: + if hasattr(item, '_trackable_children'): + clean_list.append(item) + if clean_list: + clean_children[name] = clean_list + # For dict-like _DictWrapper + elif hasattr(child, '_data') and isinstance(child._data, dict): + clean_dict = {} + for k, v in child._data.items(): + if hasattr(v, '_trackable_children'): + clean_dict[k] = v + if clean_dict: + clean_children[name] = clean_dict + # Skip if we can't unwrap safely + except (AttributeError, TypeError): + # Skip problematic _DictWrapper objects + continue + else: + # Keep non-_DictWrapper children as-is + clean_children[name] = child + + return clean_children From de830b1f038c35608885cbb706c2136c7341e9e5 Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Mon, 1 Sep 2025 19:19:17 +0530 Subject: [PATCH 02/19] Update backbone.py --- keras_hub/src/models/backbone.py | 57 ++++++++++++++++---------------- 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/keras_hub/src/models/backbone.py b/keras_hub/src/models/backbone.py index 51b36518f2..4a0c61ae20 100644 --- a/keras_hub/src/models/backbone.py +++ b/keras_hub/src/models/backbone.py @@ -66,6 +66,18 @@ def __setattr__(self, name, value): is_property = isinstance(getattr(type(self), name, None), property) is_unitialized = not hasattr(self, "_initialized") simple_setattr = keras.config.backend() == "torch" + + # Prevent _DictWrapper creation for transformer_layers + if name == "transformer_layers" and isinstance(value, list): + # Use a trackable list wrapper instead of regular list + try: + # Create a proper trackable list + from tensorflow.python.trackable.data_structures import ListWrapper + value = ListWrapper(value) + except ImportError: + # Fallback: keep as regular list + pass + if simple_setattr and (is_property or is_unitialized): return object.__setattr__(self, name, value) return super().__setattr__(name, value) @@ -349,11 +361,14 @@ def _make_spec(t): def _trackable_children(self, save_type=None, **kwargs): """Override to prevent _DictWrapper issues during TensorFlow export. - This method filters out problematic _DictWrapper objects that cause - TypeError during SavedModel introspection, while preserving all - essential trackable components. + This method ensures clean trackable object traversal by avoiding + problematic _DictWrapper objects that cause SavedModel export errors. 
""" - children = super()._trackable_children(save_type, **kwargs) + try: + children = super()._trackable_children(save_type, **kwargs) + except Exception: + # If parent fails, return minimal trackable children + children = {} # Import _DictWrapper safely try: @@ -363,32 +378,16 @@ def _trackable_children(self, save_type=None, **kwargs): clean_children = {} for name, child in children.items(): - # Handle _DictWrapper objects - if isinstance(child, _DictWrapper): - try: - # For list-like _DictWrapper (e.g., transformer_layers) - if hasattr(child, '_data') and isinstance(child._data, list): - # Create a clean list of the trackable items - clean_list = [] - for item in child._data: - if hasattr(item, '_trackable_children'): - clean_list.append(item) - if clean_list: - clean_children[name] = clean_list - # For dict-like _DictWrapper - elif hasattr(child, '_data') and isinstance(child._data, dict): - clean_dict = {} - for k, v in child._data.items(): - if hasattr(v, '_trackable_children'): - clean_dict[k] = v - if clean_dict: - clean_children[name] = clean_dict - # Skip if we can't unwrap safely - except (AttributeError, TypeError): - # Skip problematic _DictWrapper objects + try: + # Skip _DictWrapper objects entirely to avoid introspection issues + if hasattr(child, '__class__') and '_DictWrapper' in child.__class__.__name__: continue - else: - # Keep non-_DictWrapper children as-is + + # Test if child supports introspection safely + _ = getattr(child, '__dict__', None) clean_children[name] = child + except (TypeError, AttributeError): + # Skip objects that cause introspection errors + continue return clean_children From 62d2484f0b76b359ae5754aedadc6204a5846c13 Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Mon, 1 Sep 2025 19:25:13 +0530 Subject: [PATCH 03/19] Update task.py --- keras_hub/src/models/task.py | 45 ++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/keras_hub/src/models/task.py b/keras_hub/src/models/task.py index d273759b46..1a01cbdb24 100644 --- a/keras_hub/src/models/task.py +++ b/keras_hub/src/models/task.py @@ -76,6 +76,17 @@ def __setattr__(self, name, value): is_property = isinstance(getattr(type(self), name, None), property) is_unitialized = not hasattr(self, "_initialized") is_torch = keras.config.backend() == "torch" + + # Prevent _DictWrapper creation for list attributes + if isinstance(value, list) and hasattr(self, "_initialized"): + # Use a trackable list wrapper instead of regular list + try: + from tensorflow.python.trackable.data_structures import ListWrapper + value = ListWrapper(value) + except ImportError: + # Fallback: keep as regular list + pass + if is_torch and (is_property or is_unitialized): return object.__setattr__(self, name, value) return super().__setattr__(name, value) @@ -369,3 +380,37 @@ def add_layer(layer, info): print_fn=print_fn, **kwargs, ) + + def _trackable_children(self, save_type=None, **kwargs): + """Override to prevent _DictWrapper issues during TensorFlow export. + + This method ensures clean trackable object traversal by avoiding + problematic _DictWrapper objects that cause SavedModel export errors. 
+ """ + try: + children = super()._trackable_children(save_type, **kwargs) + except Exception: + # If parent fails, return minimal trackable children + children = {} + + # Import _DictWrapper safely + try: + from tensorflow.python.trackable.data_structures import _DictWrapper + except ImportError: + return children + + clean_children = {} + for name, child in children.items(): + try: + # Skip _DictWrapper objects entirely to avoid introspection issues + if hasattr(child, '__class__') and '_DictWrapper' in child.__class__.__name__: + continue + + # Test if child supports introspection safely + _ = getattr(child, '__dict__', None) + clean_children[name] = child + except (TypeError, AttributeError): + # Skip objects that cause introspection errors + continue + + return clean_children From 3b71125f58616a66b1549c1b088e3b665576139e Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Tue, 2 Sep 2025 09:55:21 +0530 Subject: [PATCH 04/19] Revert "Update task.py" This reverts commit 62d2484f0b76b359ae5754aedadc6204a5846c13. --- keras_hub/src/models/task.py | 45 ------------------------------------ 1 file changed, 45 deletions(-) diff --git a/keras_hub/src/models/task.py b/keras_hub/src/models/task.py index 1a01cbdb24..d273759b46 100644 --- a/keras_hub/src/models/task.py +++ b/keras_hub/src/models/task.py @@ -76,17 +76,6 @@ def __setattr__(self, name, value): is_property = isinstance(getattr(type(self), name, None), property) is_unitialized = not hasattr(self, "_initialized") is_torch = keras.config.backend() == "torch" - - # Prevent _DictWrapper creation for list attributes - if isinstance(value, list) and hasattr(self, "_initialized"): - # Use a trackable list wrapper instead of regular list - try: - from tensorflow.python.trackable.data_structures import ListWrapper - value = ListWrapper(value) - except ImportError: - # Fallback: keep as regular list - pass - if is_torch and (is_property or is_unitialized): return object.__setattr__(self, name, value) return super().__setattr__(name, value) @@ -380,37 +369,3 @@ def add_layer(layer, info): print_fn=print_fn, **kwargs, ) - - def _trackable_children(self, save_type=None, **kwargs): - """Override to prevent _DictWrapper issues during TensorFlow export. - - This method ensures clean trackable object traversal by avoiding - problematic _DictWrapper objects that cause SavedModel export errors. - """ - try: - children = super()._trackable_children(save_type, **kwargs) - except Exception: - # If parent fails, return minimal trackable children - children = {} - - # Import _DictWrapper safely - try: - from tensorflow.python.trackable.data_structures import _DictWrapper - except ImportError: - return children - - clean_children = {} - for name, child in children.items(): - try: - # Skip _DictWrapper objects entirely to avoid introspection issues - if hasattr(child, '__class__') and '_DictWrapper' in child.__class__.__name__: - continue - - # Test if child supports introspection safely - _ = getattr(child, '__dict__', None) - clean_children[name] = child - except (TypeError, AttributeError): - # Skip objects that cause introspection errors - continue - - return clean_children From 3d453ff9ead0cb7a7aa1f502917d0db75899215a Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Tue, 2 Sep 2025 09:55:27 +0530 Subject: [PATCH 05/19] Revert "Update backbone.py" This reverts commit de830b1f038c35608885cbb706c2136c7341e9e5. 
--- keras_hub/src/models/backbone.py | 57 ++++++++++++++++---------------- 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/keras_hub/src/models/backbone.py b/keras_hub/src/models/backbone.py index 4a0c61ae20..51b36518f2 100644 --- a/keras_hub/src/models/backbone.py +++ b/keras_hub/src/models/backbone.py @@ -66,18 +66,6 @@ def __setattr__(self, name, value): is_property = isinstance(getattr(type(self), name, None), property) is_unitialized = not hasattr(self, "_initialized") simple_setattr = keras.config.backend() == "torch" - - # Prevent _DictWrapper creation for transformer_layers - if name == "transformer_layers" and isinstance(value, list): - # Use a trackable list wrapper instead of regular list - try: - # Create a proper trackable list - from tensorflow.python.trackable.data_structures import ListWrapper - value = ListWrapper(value) - except ImportError: - # Fallback: keep as regular list - pass - if simple_setattr and (is_property or is_unitialized): return object.__setattr__(self, name, value) return super().__setattr__(name, value) @@ -361,14 +349,11 @@ def _make_spec(t): def _trackable_children(self, save_type=None, **kwargs): """Override to prevent _DictWrapper issues during TensorFlow export. - This method ensures clean trackable object traversal by avoiding - problematic _DictWrapper objects that cause SavedModel export errors. + This method filters out problematic _DictWrapper objects that cause + TypeError during SavedModel introspection, while preserving all + essential trackable components. """ - try: - children = super()._trackable_children(save_type, **kwargs) - except Exception: - # If parent fails, return minimal trackable children - children = {} + children = super()._trackable_children(save_type, **kwargs) # Import _DictWrapper safely try: @@ -378,16 +363,32 @@ def _trackable_children(self, save_type=None, **kwargs): clean_children = {} for name, child in children.items(): - try: - # Skip _DictWrapper objects entirely to avoid introspection issues - if hasattr(child, '__class__') and '_DictWrapper' in child.__class__.__name__: + # Handle _DictWrapper objects + if isinstance(child, _DictWrapper): + try: + # For list-like _DictWrapper (e.g., transformer_layers) + if hasattr(child, '_data') and isinstance(child._data, list): + # Create a clean list of the trackable items + clean_list = [] + for item in child._data: + if hasattr(item, '_trackable_children'): + clean_list.append(item) + if clean_list: + clean_children[name] = clean_list + # For dict-like _DictWrapper + elif hasattr(child, '_data') and isinstance(child._data, dict): + clean_dict = {} + for k, v in child._data.items(): + if hasattr(v, '_trackable_children'): + clean_dict[k] = v + if clean_dict: + clean_children[name] = clean_dict + # Skip if we can't unwrap safely + except (AttributeError, TypeError): + # Skip problematic _DictWrapper objects continue - - # Test if child supports introspection safely - _ = getattr(child, '__dict__', None) + else: + # Keep non-_DictWrapper children as-is clean_children[name] = child - except (TypeError, AttributeError): - # Skip objects that cause introspection errors - continue return clean_children From 92b1254aef54fb05d19d4d5014fe5646e16387c7 Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Tue, 9 Sep 2025 15:03:28 +0530 Subject: [PATCH 06/19] export export working 1st commit --- keras_hub/src/exporters/__init__.py | 85 +++++++++ keras_hub/src/exporters/base.py | 262 ++++++++++++++++++++++++++++ keras_hub/src/exporters/configs.py | 169 ++++++++++++++++++ 
keras_hub/src/exporters/lite_rt.py | 225 ++++++++++++++++++++++++ keras_hub/src/models/__init__.py | 12 ++ 5 files changed, 753 insertions(+) create mode 100644 keras_hub/src/exporters/__init__.py create mode 100644 keras_hub/src/exporters/base.py create mode 100644 keras_hub/src/exporters/configs.py create mode 100644 keras_hub/src/exporters/lite_rt.py diff --git a/keras_hub/src/exporters/__init__.py b/keras_hub/src/exporters/__init__.py new file mode 100644 index 0000000000..d8ed946e7f --- /dev/null +++ b/keras_hub/src/exporters/__init__.py @@ -0,0 +1,85 @@ +"""Keras-Hub exporters module. + +This module provides export functionality for Keras-Hub models to various formats. +It follows a clean OOP design with proper separation of concerns. +""" + +from keras_hub.src.exporters.base import ( + KerasHubExporterConfig, + KerasHubExporter, + ExporterRegistry +) +from keras_hub.src.exporters.configs import ( + CausalLMExporterConfig, + TextClassifierExporterConfig, + Seq2SeqLMExporterConfig, + TextModelExporterConfig +) +from keras_hub.src.exporters.lite_rt import ( + LiteRTExporter, + export_lite_rt +) + +# Register configurations for different model types +ExporterRegistry.register_config("causal_lm", CausalLMExporterConfig) +ExporterRegistry.register_config("text_classifier", TextClassifierExporterConfig) +ExporterRegistry.register_config("seq2seq_lm", Seq2SeqLMExporterConfig) +ExporterRegistry.register_config("text_model", TextModelExporterConfig) + +# Register exporters for different formats +ExporterRegistry.register_exporter("lite_rt", LiteRTExporter) + + +def export_model(model, filepath: str, format: str = "lite_rt", **kwargs): + """Export a Keras-Hub model to the specified format. + + This is the main export function that automatically detects the model type + and uses the appropriate exporter configuration. + + Args: + model: The Keras-Hub model to export + filepath: Path where to save the exported model (without extension) + format: Export format (currently supports "lite_rt") + **kwargs: Additional arguments passed to the exporter + """ + # Get the appropriate configuration for this model + config = ExporterRegistry.get_config_for_model(model) + + # Get the exporter for the specified format + exporter = ExporterRegistry.get_exporter(format, config, **kwargs) + + # Export the model + exporter.export(filepath) + + +# Add export method to Task base class +def _add_export_method_to_task(): + """Add the export method to the Task base class.""" + try: + from keras_hub.src.models.task import Task + + def export(self, filepath: str, format: str = "lite_rt", **kwargs) -> None: + """Export the model to the specified format. + + Args: + filepath: str. Path where to save the exported model (without extension) + format: str. Export format. 
Currently supports "lite_rt" + **kwargs: Additional arguments passed to the exporter + """ + export_model(self, filepath, format=format, **kwargs) + + # Add the method to the Task class if it doesn't exist + if not hasattr(Task, 'export'): + Task.export = export + print("โœ… Added export method to Task base class") + else: + # Override the existing method to use our Keras-Hub specific implementation + Task.export = export + print("โœ… Overrode export method in Task base class with Keras-Hub implementation") + + except Exception as e: + print(f"โš ๏ธ Failed to add export method to Task class: {e}") + + +# Auto-initialize when this module is imported +_add_export_method_to_task() diff --git a/keras_hub/src/exporters/base.py b/keras_hub/src/exporters/base.py new file mode 100644 index 0000000000..e8087e0b80 --- /dev/null +++ b/keras_hub/src/exporters/base.py @@ -0,0 +1,262 @@ +"""Base classes for Keras-Hub model exporters. + +This module provides the foundation for exporting Keras-Hub models to various formats. +It follows the Optimum pattern of having different exporters for different model types and formats. +""" + +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional, Union, List +import sys + +# Add the keras path to import from local repo +sys.path.insert(0, '/Users/hellorahul/Projects/keras') + +try: + import keras + from keras.src.export.export_utils import get_input_signature + KERAS_AVAILABLE = True +except ImportError: + KERAS_AVAILABLE = False + keras = None + + +class KerasHubExporterConfig(ABC): + """Base configuration class for Keras-Hub model exporters. + + This class defines the interface for exporter configurations that specify + how different types of Keras-Hub models should be exported. + """ + + # Model type this exporter handles (e.g., "causal_lm", "text_classifier") + MODEL_TYPE: str = None + + # Expected input structure for this model type + EXPECTED_INPUTS: List[str] = [] + + # Default sequence length if not specified + DEFAULT_SEQUENCE_LENGTH: int = 128 + + def __init__(self, model, **kwargs): + """Initialize the exporter configuration. + + Args: + model: The Keras-Hub model to export + **kwargs: Additional configuration parameters + """ + self.model = model + self.config = kwargs + self._validate_model() + + def _validate_model(self): + """Validate that the model is compatible with this exporter.""" + if not self._is_model_compatible(): + raise ValueError( + f"Model {type(self.model)} is not compatible with " + f"{self.__class__.__name__} (expected {self.MODEL_TYPE})" + ) + + @abstractmethod + def _is_model_compatible(self) -> bool: + """Check if the model is compatible with this exporter.""" + pass + + @abstractmethod + def get_input_signature(self, sequence_length: Optional[int] = None) -> Dict[str, Any]: + """Get the input signature for the model. + + Args: + sequence_length: Optional sequence length override + + Returns: + Dictionary mapping input names to their specifications + """ + pass + + def get_dummy_inputs(self, sequence_length: Optional[int] = None) -> Dict[str, Any]: + """Generate dummy inputs for model tracing. 
+ + Args: + sequence_length: Optional sequence length override + + Returns: + Dictionary of dummy inputs for the model + """ + if sequence_length is None: + if hasattr(self.model, 'preprocessor') and self.model.preprocessor: + sequence_length = getattr(self.model.preprocessor, 'sequence_length', self.DEFAULT_SEQUENCE_LENGTH) + else: + sequence_length = self.DEFAULT_SEQUENCE_LENGTH + + dummy_inputs = {} + + if "token_ids" in self.EXPECTED_INPUTS: + dummy_inputs["token_ids"] = keras.ops.ones((1, sequence_length), dtype='int32') + if "padding_mask" in self.EXPECTED_INPUTS: + dummy_inputs["padding_mask"] = keras.ops.ones((1, sequence_length), dtype='bool') + if "encoder_token_ids" in self.EXPECTED_INPUTS: + dummy_inputs["encoder_token_ids"] = keras.ops.ones((1, sequence_length), dtype='int32') + if "encoder_padding_mask" in self.EXPECTED_INPUTS: + dummy_inputs["encoder_padding_mask"] = keras.ops.ones((1, sequence_length), dtype='bool') + if "decoder_token_ids" in self.EXPECTED_INPUTS: + dummy_inputs["decoder_token_ids"] = keras.ops.ones((1, sequence_length), dtype='int32') + if "decoder_padding_mask" in self.EXPECTED_INPUTS: + dummy_inputs["decoder_padding_mask"] = keras.ops.ones((1, sequence_length), dtype='bool') + + return dummy_inputs + + +class KerasHubExporter(ABC): + """Base class for Keras-Hub model exporters. + + This class provides the common interface for exporting Keras-Hub models + to different formats (LiteRT, ONNX, etc.). + """ + + def __init__(self, config: KerasHubExporterConfig, **kwargs): + """Initialize the exporter. + + Args: + config: Exporter configuration specifying model type and parameters + **kwargs: Additional exporter-specific parameters + """ + self.config = config + self.model = config.model + self.export_kwargs = kwargs + + @abstractmethod + def export(self, filepath: str) -> None: + """Export the model to the specified filepath. + + Args: + filepath: Path where to save the exported model + """ + pass + + def _ensure_model_built(self, sequence_length: Optional[int] = None): + """Ensure the model is properly built with correct input structure. + + Args: + sequence_length: Optional sequence length for dummy inputs + """ + if not self.model.built: + print("๐Ÿ”ง Building model with sample inputs...") + + dummy_inputs = self.config.get_dummy_inputs(sequence_length) + + try: + # Build the model with the correct input structure + _ = self.model(dummy_inputs, training=False) + print("โœ… Model built successfully") + except Exception as e: + print(f"โš ๏ธ Model building failed: {e}") + # Try alternative approach + try: + input_shapes = {key: tensor.shape for key, tensor in dummy_inputs.items()} + self.model.build(input_shape=input_shapes) + print("โœ… Model built using .build() method") + except Exception as e2: + print(f"โŒ Alternative building method also failed: {e2}") + raise + + +class ExporterRegistry: + """Registry for mapping model types to their appropriate exporters.""" + + _configs = {} + _exporters = {} + + @classmethod + def register_config(cls, model_type: str, config_class: type): + """Register an exporter configuration for a model type. + + Args: + model_type: The model type identifier (e.g., "causal_lm") + config_class: The configuration class for this model type + """ + cls._configs[model_type] = config_class + + @classmethod + def register_exporter(cls, format_name: str, exporter_class: type): + """Register an exporter for a specific format. 
+ + Args: + format_name: The export format identifier (e.g., "lite_rt") + exporter_class: The exporter class for this format + """ + cls._exporters[format_name] = exporter_class + + @classmethod + def get_config_for_model(cls, model) -> KerasHubExporterConfig: + """Get the appropriate configuration for a model. + + Args: + model: The Keras-Hub model + + Returns: + An appropriate exporter configuration + + Raises: + ValueError: If no suitable configuration is found + """ + # Try to detect model type + model_type = cls._detect_model_type(model) + + if model_type not in cls._configs: + raise ValueError(f"No exporter configuration found for model type: {model_type}") + + config_class = cls._configs[model_type] + return config_class(model) + + @classmethod + def get_exporter(cls, format_name: str, config: KerasHubExporterConfig, **kwargs): + """Get an exporter for the specified format. + + Args: + format_name: The export format + config: The exporter configuration + **kwargs: Additional parameters for the exporter + + Returns: + An appropriate exporter instance + """ + if format_name not in cls._exporters: + raise ValueError(f"No exporter found for format: {format_name}") + + exporter_class = cls._exporters[format_name] + return exporter_class(config, **kwargs) + + @classmethod + def _detect_model_type(cls, model) -> str: + """Detect the model type from the model instance. + + Args: + model: The Keras-Hub model + + Returns: + The detected model type + """ + # Import here to avoid circular imports + try: + from keras_hub.src.models.causal_lm import CausalLM + from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM + except ImportError: + CausalLM = None + Seq2SeqLM = None + + model_class_name = model.__class__.__name__ + + if CausalLM and isinstance(model, CausalLM): + return "causal_lm" + elif 'TextClassifier' in model_class_name: + return "text_classifier" + elif Seq2SeqLM and isinstance(model, Seq2SeqLM): + return "seq2seq_lm" + elif 'ImageClassifier' in model_class_name: + return "image_classifier" + else: + # Fallback to text model if it has a preprocessor with tokenizer + if hasattr(model, 'preprocessor') and model.preprocessor: + if hasattr(model.preprocessor, 'tokenizer'): + return "text_model" + + return "unknown" diff --git a/keras_hub/src/exporters/configs.py b/keras_hub/src/exporters/configs.py new file mode 100644 index 0000000000..e6461df89a --- /dev/null +++ b/keras_hub/src/exporters/configs.py @@ -0,0 +1,169 @@ +"""Configuration classes for different Keras-Hub model types. + +This module provides specific configurations for exporting different types +of Keras-Hub models, following the Optimum pattern. 
+""" + +from typing import Dict, Any, Optional +from keras_hub.src.exporters.base import KerasHubExporterConfig +from keras_hub.src.api_export import keras_hub_export + + +@keras_hub_export("keras_hub.exporters.CausalLMExporterConfig") +class CausalLMExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Causal Language Models (GPT, LLaMA, etc.).""" + + MODEL_TYPE = "causal_lm" + EXPECTED_INPUTS = ["token_ids", "padding_mask"] + DEFAULT_SEQUENCE_LENGTH = 128 + + def _is_model_compatible(self) -> bool: + """Check if model is a causal language model.""" + try: + from keras_hub.src.models.causal_lm import CausalLM + return isinstance(self.model, CausalLM) + except ImportError: + # Fallback to class name checking + return 'CausalLM' in self.model.__class__.__name__ + + def get_input_signature(self, sequence_length: Optional[int] = None) -> Dict[str, Any]: + """Get input signature for causal LM models.""" + if sequence_length is None: + if hasattr(self.model, 'preprocessor') and self.model.preprocessor: + sequence_length = getattr(self.model.preprocessor, 'sequence_length', self.DEFAULT_SEQUENCE_LENGTH) + else: + sequence_length = self.DEFAULT_SEQUENCE_LENGTH + + import keras + return { + "token_ids": keras.layers.InputSpec( + shape=(None, sequence_length), + dtype='int32', + name='token_ids' + ), + "padding_mask": keras.layers.InputSpec( + shape=(None, sequence_length), + dtype='bool', + name='padding_mask' + ) + } + + +@keras_hub_export("keras_hub.exporters.TextClassifierExporterConfig") +class TextClassifierExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Text Classification models.""" + + MODEL_TYPE = "text_classifier" + EXPECTED_INPUTS = ["token_ids", "padding_mask"] + DEFAULT_SEQUENCE_LENGTH = 128 + + def _is_model_compatible(self) -> bool: + """Check if model is a text classifier.""" + return 'TextClassifier' in self.model.__class__.__name__ + + def get_input_signature(self, sequence_length: Optional[int] = None) -> Dict[str, Any]: + """Get input signature for text classifier models.""" + if sequence_length is None: + if hasattr(self.model, 'preprocessor') and self.model.preprocessor: + sequence_length = getattr(self.model.preprocessor, 'sequence_length', self.DEFAULT_SEQUENCE_LENGTH) + else: + sequence_length = self.DEFAULT_SEQUENCE_LENGTH + + import keras + return { + "token_ids": keras.layers.InputSpec( + shape=(None, sequence_length), + dtype='int32', + name='token_ids' + ), + "padding_mask": keras.layers.InputSpec( + shape=(None, sequence_length), + dtype='bool', + name='padding_mask' + ) + } + + +@keras_hub_export("keras_hub.exporters.Seq2SeqLMExporterConfig") +class Seq2SeqLMExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Sequence-to-Sequence Language Models.""" + + MODEL_TYPE = "seq2seq_lm" + EXPECTED_INPUTS = ["encoder_token_ids", "encoder_padding_mask", "decoder_token_ids", "decoder_padding_mask"] + DEFAULT_SEQUENCE_LENGTH = 128 + + def _is_model_compatible(self) -> bool: + """Check if model is a seq2seq language model.""" + try: + from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM + return isinstance(self.model, Seq2SeqLM) + except ImportError: + return 'Seq2SeqLM' in self.model.__class__.__name__ + + def get_input_signature(self, sequence_length: Optional[int] = None) -> Dict[str, Any]: + """Get input signature for seq2seq models.""" + if sequence_length is None: + if hasattr(self.model, 'preprocessor') and self.model.preprocessor: + sequence_length = getattr(self.model.preprocessor, 'sequence_length', 
self.DEFAULT_SEQUENCE_LENGTH) + else: + sequence_length = self.DEFAULT_SEQUENCE_LENGTH + + import keras + return { + "encoder_token_ids": keras.layers.InputSpec( + shape=(None, sequence_length), + dtype='int32', + name='encoder_token_ids' + ), + "encoder_padding_mask": keras.layers.InputSpec( + shape=(None, sequence_length), + dtype='bool', + name='encoder_padding_mask' + ), + "decoder_token_ids": keras.layers.InputSpec( + shape=(None, sequence_length), + dtype='int32', + name='decoder_token_ids' + ), + "decoder_padding_mask": keras.layers.InputSpec( + shape=(None, sequence_length), + dtype='bool', + name='decoder_padding_mask' + ) + } + + +@keras_hub_export("keras_hub.exporters.TextModelExporterConfig") +class TextModelExporterConfig(KerasHubExporterConfig): + """Generic exporter configuration for text models.""" + + MODEL_TYPE = "text_model" + EXPECTED_INPUTS = ["token_ids", "padding_mask"] + DEFAULT_SEQUENCE_LENGTH = 128 + + def _is_model_compatible(self) -> bool: + """Check if model is a text model (fallback).""" + # This is a fallback config for text models that don't fit other categories + return hasattr(self.model, 'preprocessor') and self.model.preprocessor and hasattr(self.model.preprocessor, 'tokenizer') + + def get_input_signature(self, sequence_length: Optional[int] = None) -> Dict[str, Any]: + """Get input signature for generic text models.""" + if sequence_length is None: + if hasattr(self.model, 'preprocessor') and self.model.preprocessor: + sequence_length = getattr(self.model.preprocessor, 'sequence_length', self.DEFAULT_SEQUENCE_LENGTH) + else: + sequence_length = self.DEFAULT_SEQUENCE_LENGTH + + import keras + return { + "token_ids": keras.layers.InputSpec( + shape=(None, sequence_length), + dtype='int32', + name='token_ids' + ), + "padding_mask": keras.layers.InputSpec( + shape=(None, sequence_length), + dtype='bool', + name='padding_mask' + ) + } diff --git a/keras_hub/src/exporters/lite_rt.py b/keras_hub/src/exporters/lite_rt.py new file mode 100644 index 0000000000..b2c5d35939 --- /dev/null +++ b/keras_hub/src/exporters/lite_rt.py @@ -0,0 +1,225 @@ +"""LiteRT exporter for Keras-Hub models. + +This module provides LiteRT export functionality specifically designed for Keras-Hub models, +handling their unique input structures and requirements. +""" + +import sys +from typing import Optional + +# Add the keras path to import from local repo +sys.path.insert(0, '/Users/hellorahul/Projects/keras') + +from keras_hub.src.exporters.base import KerasHubExporter, KerasHubExporterConfig +from keras_hub.src.api_export import keras_hub_export + +try: + from keras.src.export.lite_rt_exporter import LiteRTExporter as KerasLiteRTExporter + KERAS_LITE_RT_AVAILABLE = True +except ImportError: + KERAS_LITE_RT_AVAILABLE = False + KerasLiteRTExporter = None + + +@keras_hub_export("keras_hub.exporters.LiteRTExporter") +class LiteRTExporter(KerasHubExporter): + """LiteRT exporter for Keras-Hub models. + + This exporter handles the conversion of Keras-Hub models to TensorFlow Lite format, + properly managing the dictionary input structures that Keras-Hub models expect. + """ + + def __init__(self, config: KerasHubExporterConfig, + max_sequence_length: Optional[int] = None, + aot_compile_targets: Optional[list] = None, + verbose: Optional[int] = None, + **kwargs): + """Initialize the LiteRT exporter. 
+ + Args: + config: Exporter configuration for the model + max_sequence_length: Maximum sequence length for conversion + aot_compile_targets: List of AOT compilation targets + verbose: Verbosity level + **kwargs: Additional arguments passed to the underlying exporter + """ + super().__init__(config, **kwargs) + + if not KERAS_LITE_RT_AVAILABLE: + raise ImportError( + "Keras LiteRT exporter is not available. " + "Make sure you have Keras with LiteRT support installed." + ) + + self.max_sequence_length = max_sequence_length + self.aot_compile_targets = aot_compile_targets + self.verbose = verbose or 0 + + # Get sequence length from model if not provided + if self.max_sequence_length is None: + if hasattr(self.model, 'preprocessor') and self.model.preprocessor: + self.max_sequence_length = getattr( + self.model.preprocessor, + 'sequence_length', + self.config.DEFAULT_SEQUENCE_LENGTH + ) + else: + self.max_sequence_length = self.config.DEFAULT_SEQUENCE_LENGTH + + def export(self, filepath: str) -> None: + """Export the Keras-Hub model to LiteRT format. + + Args: + filepath: Path where to save the exported model (without extension) + """ + if self.verbose: + print(f"๐Ÿš€ Starting LiteRT export for {self.config.MODEL_TYPE} model...") + print(f" Model: {type(self.model).__name__}") + print(f" Expected inputs: {self.config.EXPECTED_INPUTS}") + print(f" Sequence length: {self.max_sequence_length}") + + # Ensure model is built with correct input structure + self._ensure_model_built(self.max_sequence_length) + + # Get the proper input signature for this model type + input_signature = self.config.get_input_signature(self.max_sequence_length) + + if self.verbose: + print(f" Input signature: {list(input_signature.keys())}") + + # Create a wrapper that adapts the Keras-Hub model to work with Keras LiteRT exporter + wrapped_model = self._create_export_wrapper() + + # Create the Keras LiteRT exporter with the wrapped model + keras_exporter = KerasLiteRTExporter( + wrapped_model, + input_signature=input_signature, + max_sequence_length=self.max_sequence_length, + aot_compile_targets=self.aot_compile_targets, + verbose=self.verbose, + **self.export_kwargs + ) + + try: + # Export using the Keras exporter + keras_exporter.export(filepath) + + if self.verbose: + print(f"โœ… Export completed successfully!") + print(f"๐Ÿ“ Model saved to: {filepath}.tflite") + + except Exception as e: + if self.verbose: + print(f"โŒ Export failed: {e}") + raise + + def _create_export_wrapper(self): + """Create a wrapper model that handles the input structure conversion. + + This wrapper converts between the list-based inputs that Keras LiteRT exporter + provides and the dictionary-based inputs that Keras-Hub models expect. 
+ """ + import keras + + class KerasHubModelWrapper(keras.Model): + """Wrapper that adapts Keras-Hub models for export.""" + + def __init__(self, keras_hub_model, expected_inputs, input_signature, verbose=False): + super().__init__() + self.keras_hub_model = keras_hub_model + self.expected_inputs = expected_inputs + self.input_signature = input_signature + self.verbose = verbose + + # Create Input layers based on the input signature + self._input_layers = [] + for input_name in expected_inputs: + if input_name in input_signature: + spec = input_signature[input_name] + # Ensure we preserve the correct dtype + input_layer = keras.layers.Input( + shape=spec.shape[1:], # Remove batch dimension + dtype=spec.dtype, + name=input_name + ) + self._input_layers.append(input_layer) + + if self.verbose: + print(f"Created input layer: {input_name} - shape={spec.shape} dtype={spec.dtype}") + + # Store references to the original model's variables + self._variables = keras_hub_model.variables + self._trainable_variables = keras_hub_model.trainable_variables + self._non_trainable_variables = keras_hub_model.non_trainable_variables + + @property + def variables(self): + return self._variables + + @property + def trainable_variables(self): + return self._trainable_variables + + @property + def non_trainable_variables(self): + return self._non_trainable_variables + + @property + def inputs(self): + """Return the input layers for the Keras exporter to use.""" + return self._input_layers + + def call(self, inputs, training=None, mask=None): + """Convert list inputs to dictionary format and call the original model.""" + if isinstance(inputs, dict): + # Already in dictionary format + return self.keras_hub_model(inputs, training=training, mask=mask) + + # Convert list inputs to dictionary format + if not isinstance(inputs, (list, tuple)): + inputs = [inputs] + + # Map inputs to expected dictionary structure + input_dict = {} + for i, input_name in enumerate(self.expected_inputs): + if i < len(inputs): + input_dict[input_name] = inputs[i] + else: + # Handle missing inputs - this shouldn't happen but let's be safe + print(f"โš ๏ธ Missing input for {input_name}") + + return self.keras_hub_model(input_dict, training=training, mask=mask) + + def get_config(self): + """Return the configuration of the wrapped model.""" + return self.keras_hub_model.get_config() + + return KerasHubModelWrapper( + self.model, + self.config.EXPECTED_INPUTS, + self.config.get_input_signature(self.max_sequence_length), + verbose=self.verbose + ) + + +# Convenience function for direct export +@keras_hub_export("keras_hub.exporters.export_lite_rt") +def export_lite_rt(model, filepath: str, **kwargs) -> None: + """Export a Keras-Hub model to LiteRT format. + + This is a convenience function that automatically detects the model type + and exports it using the appropriate configuration. 
+ + Args: + model: The Keras-Hub model to export + filepath: Path where to save the exported model (without extension) + **kwargs: Additional arguments passed to the exporter + """ + from keras_hub.src.exporters.base import ExporterRegistry + + # Get the appropriate configuration for this model + config = ExporterRegistry.get_config_for_model(model) + + # Create and use the LiteRT exporter + exporter = LiteRTExporter(config, **kwargs) + exporter.export(filepath) diff --git a/keras_hub/src/models/__init__.py b/keras_hub/src/models/__init__.py index e69de29bb2..649c8413b4 100644 --- a/keras_hub/src/models/__init__.py +++ b/keras_hub/src/models/__init__.py @@ -0,0 +1,12 @@ +"""Import and initialize Keras-Hub export functionality. + +This module automatically extends Keras-Hub models with export capabilities +when imported. +""" + +# Import the exporters functionality +try: + from keras_hub.src.exporters import * + # The __init__.py file automatically adds the export method to Task base class +except ImportError as e: + print(f"โš ๏ธ Failed to import Keras-Hub export functionality: {e}") From e46241d48a5c0a05f50028e841f24effd0b41fbb Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Wed, 10 Sep 2025 10:15:04 +0530 Subject: [PATCH 07/19] refactoring --- keras_hub/src/export/__init__.py | 9 +++ keras_hub/src/{exporters => export}/base.py | 62 ++++++++----------- .../src/{exporters => export}/configs.py | 10 +-- .../src/{exporters => export}/lite_rt.py | 12 ++-- .../__init__.py => export/registry.py} | 51 +++++++-------- keras_hub/src/models/__init__.py | 9 ++- 6 files changed, 71 insertions(+), 82 deletions(-) create mode 100644 keras_hub/src/export/__init__.py rename keras_hub/src/{exporters => export}/base.py (80%) rename keras_hub/src/{exporters => export}/configs.py (95%) rename keras_hub/src/{exporters => export}/lite_rt.py (96%) rename keras_hub/src/{exporters/__init__.py => export/registry.py} (58%) diff --git a/keras_hub/src/export/__init__.py b/keras_hub/src/export/__init__.py new file mode 100644 index 0000000000..f2063779ef --- /dev/null +++ b/keras_hub/src/export/__init__.py @@ -0,0 +1,9 @@ +from keras_hub.src.export.base import ExporterRegistry +from keras_hub.src.export.base import KerasHubExporter +from keras_hub.src.export.base import KerasHubExporterConfig +from keras_hub.src.export.configs import CausalLMExporterConfig +from keras_hub.src.export.configs import Seq2SeqLMExporterConfig +from keras_hub.src.export.configs import TextClassifierExporterConfig +from keras_hub.src.export.configs import TextModelExporterConfig +from keras_hub.src.export.lite_rt import export_lite_rt +from keras_hub.src.export.lite_rt import LiteRTExporter diff --git a/keras_hub/src/exporters/base.py b/keras_hub/src/export/base.py similarity index 80% rename from keras_hub/src/exporters/base.py rename to keras_hub/src/export/base.py index e8087e0b80..63a32952e2 100644 --- a/keras_hub/src/exporters/base.py +++ b/keras_hub/src/export/base.py @@ -6,10 +6,6 @@ from abc import ABC, abstractmethod from typing import Dict, Any, Optional, Union, List -import sys - -# Add the keras path to import from local repo -sys.path.insert(0, '/Users/hellorahul/Projects/keras') try: import keras @@ -44,15 +40,15 @@ def __init__(self, model, **kwargs): **kwargs: Additional configuration parameters """ self.model = model - self.config = kwargs + self.config_kwargs = kwargs self._validate_model() def _validate_model(self): """Validate that the model is compatible with this exporter.""" if not self._is_model_compatible(): raise 
ValueError( - f"Model {type(self.model)} is not compatible with " - f"{self.__class__.__name__} (expected {self.MODEL_TYPE})" + f"Model {self.model.__class__.__name__} is not compatible " + f"with {self.__class__.__name__}" ) @abstractmethod @@ -62,37 +58,37 @@ def _is_model_compatible(self) -> bool: @abstractmethod def get_input_signature(self, sequence_length: Optional[int] = None) -> Dict[str, Any]: - """Get the input signature for the model. + """Get the input signature for this model type. Args: - sequence_length: Optional sequence length override + sequence_length: Optional sequence length for input tensors Returns: - Dictionary mapping input names to their specifications + Dictionary mapping input names to their signatures """ pass def get_dummy_inputs(self, sequence_length: Optional[int] = None) -> Dict[str, Any]: - """Generate dummy inputs for model tracing. + """Generate dummy inputs for model building and testing. Args: - sequence_length: Optional sequence length override + sequence_length: Optional sequence length for dummy inputs Returns: - Dictionary of dummy inputs for the model + Dictionary of dummy inputs """ if sequence_length is None: - if hasattr(self.model, 'preprocessor') and self.model.preprocessor: - sequence_length = getattr(self.model.preprocessor, 'sequence_length', self.DEFAULT_SEQUENCE_LENGTH) - else: - sequence_length = self.DEFAULT_SEQUENCE_LENGTH - + sequence_length = self.DEFAULT_SEQUENCE_LENGTH + dummy_inputs = {} + # Common inputs for most Keras-Hub models if "token_ids" in self.EXPECTED_INPUTS: dummy_inputs["token_ids"] = keras.ops.ones((1, sequence_length), dtype='int32') if "padding_mask" in self.EXPECTED_INPUTS: dummy_inputs["padding_mask"] = keras.ops.ones((1, sequence_length), dtype='bool') + + # Encoder-decoder specific inputs if "encoder_token_ids" in self.EXPECTED_INPUTS: dummy_inputs["encoder_token_ids"] = keras.ops.ones((1, sequence_length), dtype='int32') if "encoder_padding_mask" in self.EXPECTED_INPUTS: @@ -167,21 +163,21 @@ class ExporterRegistry: @classmethod def register_config(cls, model_type: str, config_class: type): - """Register an exporter configuration for a model type. + """Register a configuration class for a model type. Args: - model_type: The model type identifier (e.g., "causal_lm") - config_class: The configuration class for this model type + model_type: The model type (e.g., "causal_lm") + config_class: The configuration class """ cls._configs[model_type] = config_class - + @classmethod def register_exporter(cls, format_name: str, exporter_class: type): - """Register an exporter for a specific format. + """Register an exporter class for a format. 
Args: - format_name: The export format identifier (e.g., "lite_rt") - exporter_class: The exporter class for this format + format_name: The export format (e.g., "lite_rt") + exporter_class: The exporter class """ cls._exporters[format_name] = exporter_class @@ -193,16 +189,12 @@ def get_config_for_model(cls, model) -> KerasHubExporterConfig: model: The Keras-Hub model Returns: - An appropriate exporter configuration - - Raises: - ValueError: If no suitable configuration is found + An appropriate exporter configuration instance """ - # Try to detect model type model_type = cls._detect_model_type(model) if model_type not in cls._configs: - raise ValueError(f"No exporter configuration found for model type: {model_type}") + raise ValueError(f"No configuration found for model type: {model_type}") config_class = cls._configs[model_type] return config_class(model) @@ -254,9 +246,5 @@ def _detect_model_type(cls, model) -> str: elif 'ImageClassifier' in model_class_name: return "image_classifier" else: - # Fallback to text model if it has a preprocessor with tokenizer - if hasattr(model, 'preprocessor') and model.preprocessor: - if hasattr(model.preprocessor, 'tokenizer'): - return "text_model" - - return "unknown" + # Default to text model for generic Keras-Hub models + return "text_model" diff --git a/keras_hub/src/exporters/configs.py b/keras_hub/src/export/configs.py similarity index 95% rename from keras_hub/src/exporters/configs.py rename to keras_hub/src/export/configs.py index e6461df89a..803c7c3a38 100644 --- a/keras_hub/src/exporters/configs.py +++ b/keras_hub/src/export/configs.py @@ -5,11 +5,11 @@ """ from typing import Dict, Any, Optional -from keras_hub.src.exporters.base import KerasHubExporterConfig +from keras_hub.src.export.base import KerasHubExporterConfig from keras_hub.src.api_export import keras_hub_export -@keras_hub_export("keras_hub.exporters.CausalLMExporterConfig") +@keras_hub_export("keras_hub.export.CausalLMExporterConfig") class CausalLMExporterConfig(KerasHubExporterConfig): """Exporter configuration for Causal Language Models (GPT, LLaMA, etc.).""" @@ -49,7 +49,7 @@ def get_input_signature(self, sequence_length: Optional[int] = None) -> Dict[str } -@keras_hub_export("keras_hub.exporters.TextClassifierExporterConfig") +@keras_hub_export("keras_hub.export.TextClassifierExporterConfig") class TextClassifierExporterConfig(KerasHubExporterConfig): """Exporter configuration for Text Classification models.""" @@ -84,7 +84,7 @@ def get_input_signature(self, sequence_length: Optional[int] = None) -> Dict[str } -@keras_hub_export("keras_hub.exporters.Seq2SeqLMExporterConfig") +@keras_hub_export("keras_hub.export.Seq2SeqLMExporterConfig") class Seq2SeqLMExporterConfig(KerasHubExporterConfig): """Exporter configuration for Sequence-to-Sequence Language Models.""" @@ -133,7 +133,7 @@ def get_input_signature(self, sequence_length: Optional[int] = None) -> Dict[str } -@keras_hub_export("keras_hub.exporters.TextModelExporterConfig") +@keras_hub_export("keras_hub.export.TextModelExporterConfig") class TextModelExporterConfig(KerasHubExporterConfig): """Generic exporter configuration for text models.""" diff --git a/keras_hub/src/exporters/lite_rt.py b/keras_hub/src/export/lite_rt.py similarity index 96% rename from keras_hub/src/exporters/lite_rt.py rename to keras_hub/src/export/lite_rt.py index b2c5d35939..3bf5a89a04 100644 --- a/keras_hub/src/exporters/lite_rt.py +++ b/keras_hub/src/export/lite_rt.py @@ -4,13 +4,9 @@ handling their unique input structures and requirements. 
""" -import sys from typing import Optional -# Add the keras path to import from local repo -sys.path.insert(0, '/Users/hellorahul/Projects/keras') - -from keras_hub.src.exporters.base import KerasHubExporter, KerasHubExporterConfig +from keras_hub.src.export.base import KerasHubExporter, KerasHubExporterConfig from keras_hub.src.api_export import keras_hub_export try: @@ -21,7 +17,7 @@ KerasLiteRTExporter = None -@keras_hub_export("keras_hub.exporters.LiteRTExporter") +@keras_hub_export("keras_hub.export.LiteRTExporter") class LiteRTExporter(KerasHubExporter): """LiteRT exporter for Keras-Hub models. @@ -203,7 +199,7 @@ def get_config(self): # Convenience function for direct export -@keras_hub_export("keras_hub.exporters.export_lite_rt") +@keras_hub_export("keras_hub.export.export_lite_rt") def export_lite_rt(model, filepath: str, **kwargs) -> None: """Export a Keras-Hub model to LiteRT format. @@ -215,7 +211,7 @@ def export_lite_rt(model, filepath: str, **kwargs) -> None: filepath: Path where to save the exported model (without extension) **kwargs: Additional arguments passed to the exporter """ - from keras_hub.src.exporters.base import ExporterRegistry + from keras_hub.src.export.base import ExporterRegistry # Get the appropriate configuration for this model config = ExporterRegistry.get_config_for_model(model) diff --git a/keras_hub/src/exporters/__init__.py b/keras_hub/src/export/registry.py similarity index 58% rename from keras_hub/src/exporters/__init__.py rename to keras_hub/src/export/registry.py index d8ed946e7f..56df0e31b6 100644 --- a/keras_hub/src/exporters/__init__.py +++ b/keras_hub/src/export/registry.py @@ -1,33 +1,28 @@ -"""Keras-Hub exporters module. +"""Registry initialization for Keras-Hub export functionality. -This module provides export functionality for Keras-Hub models to various formats. -It follows a clean OOP design with proper separation of concerns. +This module initializes the export registry with available configurations and exporters. 
""" -from keras_hub.src.exporters.base import ( - KerasHubExporterConfig, - KerasHubExporter, - ExporterRegistry -) -from keras_hub.src.exporters.configs import ( +from keras_hub.src.export.base import ExporterRegistry +from keras_hub.src.export.configs import ( CausalLMExporterConfig, TextClassifierExporterConfig, Seq2SeqLMExporterConfig, TextModelExporterConfig ) -from keras_hub.src.exporters.lite_rt import ( - LiteRTExporter, - export_lite_rt -) +from keras_hub.src.export.lite_rt import LiteRTExporter + -# Register configurations for different model types -ExporterRegistry.register_config("causal_lm", CausalLMExporterConfig) -ExporterRegistry.register_config("text_classifier", TextClassifierExporterConfig) -ExporterRegistry.register_config("seq2seq_lm", Seq2SeqLMExporterConfig) -ExporterRegistry.register_config("text_model", TextModelExporterConfig) +def initialize_export_registry(): + """Initialize the export registry with available configurations and exporters.""" + # Register configurations for different model types + ExporterRegistry.register_config("causal_lm", CausalLMExporterConfig) + ExporterRegistry.register_config("text_classifier", TextClassifierExporterConfig) + ExporterRegistry.register_config("seq2seq_lm", Seq2SeqLMExporterConfig) + ExporterRegistry.register_config("text_model", TextModelExporterConfig) -# Register exporters for different formats -ExporterRegistry.register_exporter("lite_rt", LiteRTExporter) + # Register exporters for different formats + ExporterRegistry.register_exporter("lite_rt", LiteRTExporter) def export_model(model, filepath: str, format: str = "lite_rt", **kwargs): @@ -42,6 +37,9 @@ def export_model(model, filepath: str, format: str = "lite_rt", **kwargs): format: Export format (currently supports "lite_rt") **kwargs: Additional arguments passed to the exporter """ + # Ensure registry is initialized + initialize_export_registry() + # Get the appropriate configuration for this model config = ExporterRegistry.get_config_for_model(model) @@ -52,8 +50,7 @@ def export_model(model, filepath: str, format: str = "lite_rt", **kwargs): exporter.export(filepath) -# Add export method to Task base class -def _add_export_method_to_task(): +def add_export_method_to_task(): """Add the export method to the Task base class.""" try: from keras_hub.src.models.task import Task @@ -71,15 +68,11 @@ def export(self, filepath: str, format: str = "lite_rt", **kwargs) -> None: # Add the method to the Task class if it doesn't exist if not hasattr(Task, 'export'): Task.export = export - print("โœ… Added export method to Task base class") - else: - # Override the existing method to use our Keras-Hub specific implementation - Task.export = export - print("โœ… Overrode export method in Task base class with Keras-Hub implementation") except Exception as e: print(f"โš ๏ธ Failed to add export method to Task class: {e}") -# Auto-initialize when this module is imported -_add_export_method_to_task() +# Initialize the registry when this module is imported +initialize_export_registry() +add_export_method_to_task() diff --git a/keras_hub/src/models/__init__.py b/keras_hub/src/models/__init__.py index 649c8413b4..43ccfbd194 100644 --- a/keras_hub/src/models/__init__.py +++ b/keras_hub/src/models/__init__.py @@ -4,9 +4,12 @@ when imported. 
""" -# Import the exporters functionality +# Import the export functionality try: - from keras_hub.src.exporters import * - # The __init__.py file automatically adds the export method to Task base class + from keras_hub.src.export.registry import add_export_method_to_task + from keras_hub.src.export.registry import initialize_export_registry + # Initialize export functionality + initialize_export_registry() + add_export_method_to_task() except ImportError as e: print(f"โš ๏ธ Failed to import Keras-Hub export functionality: {e}") From 6e970e2a76f29656418d09fbefa09b8dba5f2918 Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Wed, 10 Sep 2025 10:30:34 +0530 Subject: [PATCH 08/19] refactor --- keras_hub/src/export/registry.py | 96 +++++++++++++++++++++++++++----- keras_hub/src/models/__init__.py | 4 +- 2 files changed, 85 insertions(+), 15 deletions(-) diff --git a/keras_hub/src/export/registry.py b/keras_hub/src/export/registry.py index 56df0e31b6..0a0dd50706 100644 --- a/keras_hub/src/export/registry.py +++ b/keras_hub/src/export/registry.py @@ -25,7 +25,7 @@ def initialize_export_registry(): ExporterRegistry.register_exporter("lite_rt", LiteRTExporter) -def export_model(model, filepath: str, format: str = "lite_rt", **kwargs): +def export_model(model, filepath: str, format: str = "lite_rt", verbose=None, **kwargs): """Export a Keras-Hub model to the specified format. This is the main export function that automatically detects the model type @@ -35,11 +35,16 @@ def export_model(model, filepath: str, format: str = "lite_rt", **kwargs): model: The Keras-Hub model to export filepath: Path where to save the exported model (without extension) format: Export format (currently supports "lite_rt") + verbose: Whether to print verbose output during export **kwargs: Additional arguments passed to the exporter """ # Ensure registry is initialized initialize_export_registry() + # Pass verbose parameter to the exporter + if verbose is not None: + kwargs['verbose'] = verbose + # Get the appropriate configuration for this model config = ExporterRegistry.get_config_for_model(model) @@ -48,31 +53,96 @@ def export_model(model, filepath: str, format: str = "lite_rt", **kwargs): # Export the model exporter.export(filepath) + """Export a Keras-Hub model to the specified format. + + This is the main export function that automatically detects the model type + and uses the appropriate exporter configuration. + + Args: + model: The Keras-Hub model to export + filepath: Path where to save the exported model (without extension) + format: Export format (currently supports "lite_rt") + verbose: Whether to print verbose output + **kwargs: Additional arguments passed to the exporter + """ + # Ensure registry is initialized + initialize_export_registry() + + # Get the appropriate configuration for this model + config = ExporterRegistry.get_config_for_model(model) + + # Get the exporter for the specified format + exporter = ExporterRegistry.get_exporter(format, config, verbose=verbose, **kwargs) + + # Export the model + exporter.export(filepath) -def add_export_method_to_task(): - """Add the export method to the Task base class.""" +def extend_export_method_for_keras_hub(): + """Extend the export method for Keras-Hub models to handle dictionary inputs.""" try: from keras_hub.src.models.task import Task + import keras - def export(self, filepath: str, format: str = "lite_rt", **kwargs) -> None: - """Export the model to the specified format. 
+ # Store the original export method + original_export = Task.export if hasattr(Task, 'export') else keras.Model.export + + def keras_hub_export(self, filepath: str, format: str = "lite_rt", verbose=None, **kwargs): + """Extended export method for Keras-Hub models. + + This method extends Keras' export functionality to properly handle + Keras-Hub models that expect dictionary inputs. Args: filepath: str. Path where to save the exported model (without extension) - format: str. Export format. Currently supports "lite_rt" + format: str. Export format. Supports "lite_rt", "tf_saved_model", etc. + verbose: bool. Whether to print verbose output during export **kwargs: Additional arguments passed to the exporter """ - export_model(self, filepath, format=format, **kwargs) + # Check if this is a Keras-Hub model that needs special handling + if format == "lite_rt" and self._is_keras_hub_model(): + # Use our Keras-Hub specific export logic + export_model(self, filepath, format=format, verbose=verbose, **kwargs) + else: + # Fall back to the original Keras export method + original_export(self, filepath, format=format, verbose=verbose, **kwargs) + + def _is_keras_hub_model(self): + """Check if this model is a Keras-Hub model that needs special handling.""" + # Check if it's a Task (most Keras-Hub models inherit from Task) + if hasattr(self, '__class__'): + class_name = self.__class__.__name__ + module_name = self.__class__.__module__ + + # Check if it's from keras_hub package + if 'keras_hub' in module_name: + return True + + # Check if it has keras-hub specific attributes + if hasattr(self, 'preprocessor') and hasattr(self, 'backbone'): + return True + + # Check for common Keras-Hub model names + keras_hub_model_names = ['CausalLM', 'Seq2SeqLM', 'TextClassifier', 'ImageClassifier'] + if any(name in class_name for name in keras_hub_model_names): + return True + + return False + + # Add the helper method to the class + Task._is_keras_hub_model = _is_keras_hub_model + + # Override the export method + Task.export = keras_hub_export + + print("โœ… Extended export method for Keras-Hub models") - # Add the method to the Task class if it doesn't exist - if not hasattr(Task, 'export'): - Task.export = export - except Exception as e: - print(f"โš ๏ธ Failed to add export method to Task class: {e}") + print(f"โš ๏ธ Failed to extend export method for Keras-Hub models: {e}") + import traceback + traceback.print_exc() # Initialize the registry when this module is imported initialize_export_registry() -add_export_method_to_task() +extend_export_method_for_keras_hub() diff --git a/keras_hub/src/models/__init__.py b/keras_hub/src/models/__init__.py index 43ccfbd194..896e87678e 100644 --- a/keras_hub/src/models/__init__.py +++ b/keras_hub/src/models/__init__.py @@ -6,10 +6,10 @@ # Import the export functionality try: - from keras_hub.src.export.registry import add_export_method_to_task + from keras_hub.src.export.registry import extend_export_method_for_keras_hub from keras_hub.src.export.registry import initialize_export_registry # Initialize export functionality initialize_export_registry() - add_export_method_to_task() + extend_export_method_for_keras_hub() except ImportError as e: print(f"โš ๏ธ Failed to import Keras-Hub export functionality: {e}") From 15ad9f370309e107fd6ab98f070b096c6854cff0 Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Wed, 10 Sep 2025 11:01:48 +0530 Subject: [PATCH 09/19] Update registry.py --- keras_hub/src/export/registry.py | 37 ++++++-------------------------- 1 file changed, 6 
insertions(+), 31 deletions(-) diff --git a/keras_hub/src/export/registry.py b/keras_hub/src/export/registry.py index 0a0dd50706..8856735592 100644 --- a/keras_hub/src/export/registry.py +++ b/keras_hub/src/export/registry.py @@ -25,7 +25,7 @@ def initialize_export_registry(): ExporterRegistry.register_exporter("lite_rt", LiteRTExporter) -def export_model(model, filepath: str, format: str = "lite_rt", verbose=None, **kwargs): +def export_model(model, filepath: str, format: str = "lite_rt", **kwargs): """Export a Keras-Hub model to the specified format. This is the main export function that automatically detects the model type @@ -35,16 +35,11 @@ def export_model(model, filepath: str, format: str = "lite_rt", verbose=None, ** model: The Keras-Hub model to export filepath: Path where to save the exported model (without extension) format: Export format (currently supports "lite_rt") - verbose: Whether to print verbose output during export - **kwargs: Additional arguments passed to the exporter + **kwargs: Additional arguments passed to the exporter (including verbose) """ # Ensure registry is initialized initialize_export_registry() - # Pass verbose parameter to the exporter - if verbose is not None: - kwargs['verbose'] = verbose - # Get the appropriate configuration for this model config = ExporterRegistry.get_config_for_model(model) @@ -53,29 +48,6 @@ def export_model(model, filepath: str, format: str = "lite_rt", verbose=None, ** # Export the model exporter.export(filepath) - """Export a Keras-Hub model to the specified format. - - This is the main export function that automatically detects the model type - and uses the appropriate exporter configuration. - - Args: - model: The Keras-Hub model to export - filepath: Path where to save the exported model (without extension) - format: Export format (currently supports "lite_rt") - verbose: Whether to print verbose output - **kwargs: Additional arguments passed to the exporter - """ - # Ensure registry is initialized - initialize_export_registry() - - # Get the appropriate configuration for this model - config = ExporterRegistry.get_config_for_model(model) - - # Get the exporter for the specified format - exporter = ExporterRegistry.get_exporter(format, config, verbose=verbose, **kwargs) - - # Export the model - exporter.export(filepath) def extend_export_method_for_keras_hub(): @@ -102,7 +74,10 @@ def keras_hub_export(self, filepath: str, format: str = "lite_rt", verbose=None, # Check if this is a Keras-Hub model that needs special handling if format == "lite_rt" and self._is_keras_hub_model(): # Use our Keras-Hub specific export logic - export_model(self, filepath, format=format, verbose=verbose, **kwargs) + # Make sure we don't duplicate the verbose parameter + if verbose is not None and 'verbose' not in kwargs: + kwargs['verbose'] = verbose + export_model(self, filepath, format=format, **kwargs) else: # Fall back to the original Keras export method original_export(self, filepath, format=format, verbose=verbose, **kwargs) From 02ca0d97485cbd870b2f98fa70d9db9df0a188f6 Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Mon, 15 Sep 2025 14:18:07 +0530 Subject: [PATCH 10/19] Refactor export logic and improve error handling Refactored exporter and registry logic for better type safety and error handling. Improved input signature methods in config classes by extracting sequence length logic. Enhanced LiteRT exporter with clearer verbose handling and stricter error reporting. 
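For orientation, a minimal usage sketch of the `export_model` entry point that these registry patches build up, assuming keras_hub with this series applied; the preset name and output path are illustrative placeholders rather than anything fixed by the commits.

import keras_hub

from keras_hub.src.export.registry import export_model

# Any Keras-Hub task should work; the GPT-2 preset is used only as an example.
model = keras_hub.models.CausalLM.from_preset("gpt2_base_en")

# Resolves CausalLMExporterConfig through ExporterRegistry, builds a
# LiteRTExporter, and writes "<filepath>.tflite".
export_model(model, "/tmp/gpt2_base_en", format="lite_rt", verbose=True)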
Registry now conditionally registers LiteRT exporter and extends export method only if dependencies are available. --- keras_hub/src/export/base.py | 30 ++++++------ keras_hub/src/export/configs.py | 80 ++++++++++++++++++++++++-------- keras_hub/src/export/lite_rt.py | 36 +++++++------- keras_hub/src/export/registry.py | 47 ++++++++++--------- 4 files changed, 117 insertions(+), 76 deletions(-) diff --git a/keras_hub/src/export/base.py b/keras_hub/src/export/base.py index 63a32952e2..777e0b3727 100644 --- a/keras_hub/src/export/base.py +++ b/keras_hub/src/export/base.py @@ -5,7 +5,7 @@ """ from abc import ABC, abstractmethod -from typing import Dict, Any, Optional, Union, List +from typing import Dict, Any, Optional, List, Type try: import keras @@ -128,31 +128,27 @@ def export(self, filepath: str) -> None: """ pass - def _ensure_model_built(self, sequence_length: Optional[int] = None): + def _ensure_model_built(self, sequence_length: Optional[int] = None) -> None: """Ensure the model is properly built with correct input structure. Args: sequence_length: Optional sequence length for dummy inputs """ if not self.model.built: - print("๐Ÿ”ง Building model with sample inputs...") - dummy_inputs = self.config.get_dummy_inputs(sequence_length) try: # Build the model with the correct input structure _ = self.model(dummy_inputs, training=False) - print("โœ… Model built successfully") except Exception as e: - print(f"โš ๏ธ Model building failed: {e}") - # Try alternative approach + # Try alternative approach using build() method try: input_shapes = {key: tensor.shape for key, tensor in dummy_inputs.items()} self.model.build(input_shape=input_shapes) - print("โœ… Model built using .build() method") - except Exception as e2: - print(f"โŒ Alternative building method also failed: {e2}") - raise + except Exception: + raise ValueError( + f"Failed to build model: {e}. Please ensure the model is properly constructed." + ) class ExporterRegistry: @@ -162,7 +158,7 @@ class ExporterRegistry: _exporters = {} @classmethod - def register_config(cls, model_type: str, config_class: type): + def register_config(cls, model_type: str, config_class: Type[KerasHubExporterConfig]) -> None: """Register a configuration class for a model type. Args: @@ -172,7 +168,7 @@ def register_config(cls, model_type: str, config_class: type): cls._configs[model_type] = config_class @classmethod - def register_exporter(cls, format_name: str, exporter_class: type): + def register_exporter(cls, format_name: str, exporter_class: Type[KerasHubExporter]) -> None: """Register an exporter class for a format. Args: @@ -190,6 +186,9 @@ def get_config_for_model(cls, model) -> KerasHubExporterConfig: Returns: An appropriate exporter configuration instance + + Raises: + ValueError: If no configuration is found for the model type """ model_type = cls._detect_model_type(model) @@ -200,7 +199,7 @@ def get_config_for_model(cls, model) -> KerasHubExporterConfig: return config_class(model) @classmethod - def get_exporter(cls, format_name: str, config: KerasHubExporterConfig, **kwargs): + def get_exporter(cls, format_name: str, config: KerasHubExporterConfig, **kwargs) -> KerasHubExporter: """Get an exporter for the specified format. 
Args: @@ -210,6 +209,9 @@ def get_exporter(cls, format_name: str, config: KerasHubExporterConfig, **kwargs Returns: An appropriate exporter instance + + Raises: + ValueError: If no exporter is found for the format """ if format_name not in cls._exporters: raise ValueError(f"No exporter found for format: {format_name}") diff --git a/keras_hub/src/export/configs.py b/keras_hub/src/export/configs.py index 803c7c3a38..b6e563415b 100644 --- a/keras_hub/src/export/configs.py +++ b/keras_hub/src/export/configs.py @@ -27,12 +27,16 @@ def _is_model_compatible(self) -> bool: return 'CausalLM' in self.model.__class__.__name__ def get_input_signature(self, sequence_length: Optional[int] = None) -> Dict[str, Any]: - """Get input signature for causal LM models.""" + """Get input signature for causal LM models. + + Args: + sequence_length: Optional sequence length. If None, will be inferred from model. + + Returns: + Dictionary mapping input names to their specifications + """ if sequence_length is None: - if hasattr(self.model, 'preprocessor') and self.model.preprocessor: - sequence_length = getattr(self.model.preprocessor, 'sequence_length', self.DEFAULT_SEQUENCE_LENGTH) - else: - sequence_length = self.DEFAULT_SEQUENCE_LENGTH + sequence_length = self._get_sequence_length() import keras return { @@ -47,6 +51,12 @@ def get_input_signature(self, sequence_length: Optional[int] = None) -> Dict[str name='padding_mask' ) } + + def _get_sequence_length(self) -> int: + """Get sequence length from model or use default.""" + if hasattr(self.model, 'preprocessor') and self.model.preprocessor: + return getattr(self.model.preprocessor, 'sequence_length', self.DEFAULT_SEQUENCE_LENGTH) + return self.DEFAULT_SEQUENCE_LENGTH @keras_hub_export("keras_hub.export.TextClassifierExporterConfig") @@ -62,12 +72,16 @@ def _is_model_compatible(self) -> bool: return 'TextClassifier' in self.model.__class__.__name__ def get_input_signature(self, sequence_length: Optional[int] = None) -> Dict[str, Any]: - """Get input signature for text classifier models.""" + """Get input signature for text classifier models. + + Args: + sequence_length: Optional sequence length. If None, will be inferred from model. + + Returns: + Dictionary mapping input names to their specifications + """ if sequence_length is None: - if hasattr(self.model, 'preprocessor') and self.model.preprocessor: - sequence_length = getattr(self.model.preprocessor, 'sequence_length', self.DEFAULT_SEQUENCE_LENGTH) - else: - sequence_length = self.DEFAULT_SEQUENCE_LENGTH + sequence_length = self._get_sequence_length() import keras return { @@ -82,6 +96,12 @@ def get_input_signature(self, sequence_length: Optional[int] = None) -> Dict[str name='padding_mask' ) } + + def _get_sequence_length(self) -> int: + """Get sequence length from model or use default.""" + if hasattr(self.model, 'preprocessor') and self.model.preprocessor: + return getattr(self.model.preprocessor, 'sequence_length', self.DEFAULT_SEQUENCE_LENGTH) + return self.DEFAULT_SEQUENCE_LENGTH @keras_hub_export("keras_hub.export.Seq2SeqLMExporterConfig") @@ -101,12 +121,16 @@ def _is_model_compatible(self) -> bool: return 'Seq2SeqLM' in self.model.__class__.__name__ def get_input_signature(self, sequence_length: Optional[int] = None) -> Dict[str, Any]: - """Get input signature for seq2seq models.""" + """Get input signature for seq2seq models. + + Args: + sequence_length: Optional sequence length. If None, will be inferred from model. 
+ + Returns: + Dictionary mapping input names to their specifications + """ if sequence_length is None: - if hasattr(self.model, 'preprocessor') and self.model.preprocessor: - sequence_length = getattr(self.model.preprocessor, 'sequence_length', self.DEFAULT_SEQUENCE_LENGTH) - else: - sequence_length = self.DEFAULT_SEQUENCE_LENGTH + sequence_length = self._get_sequence_length() import keras return { @@ -131,6 +155,12 @@ def get_input_signature(self, sequence_length: Optional[int] = None) -> Dict[str name='decoder_padding_mask' ) } + + def _get_sequence_length(self) -> int: + """Get sequence length from model or use default.""" + if hasattr(self.model, 'preprocessor') and self.model.preprocessor: + return getattr(self.model.preprocessor, 'sequence_length', self.DEFAULT_SEQUENCE_LENGTH) + return self.DEFAULT_SEQUENCE_LENGTH @keras_hub_export("keras_hub.export.TextModelExporterConfig") @@ -147,12 +177,16 @@ def _is_model_compatible(self) -> bool: return hasattr(self.model, 'preprocessor') and self.model.preprocessor and hasattr(self.model.preprocessor, 'tokenizer') def get_input_signature(self, sequence_length: Optional[int] = None) -> Dict[str, Any]: - """Get input signature for generic text models.""" + """Get input signature for generic text models. + + Args: + sequence_length: Optional sequence length. If None, will be inferred from model. + + Returns: + Dictionary mapping input names to their specifications + """ if sequence_length is None: - if hasattr(self.model, 'preprocessor') and self.model.preprocessor: - sequence_length = getattr(self.model.preprocessor, 'sequence_length', self.DEFAULT_SEQUENCE_LENGTH) - else: - sequence_length = self.DEFAULT_SEQUENCE_LENGTH + sequence_length = self._get_sequence_length() import keras return { @@ -167,3 +201,9 @@ def get_input_signature(self, sequence_length: Optional[int] = None) -> Dict[str name='padding_mask' ) } + + def _get_sequence_length(self) -> int: + """Get sequence length from model or use default.""" + if hasattr(self.model, 'preprocessor') and self.model.preprocessor: + return getattr(self.model.preprocessor, 'sequence_length', self.DEFAULT_SEQUENCE_LENGTH) + return self.DEFAULT_SEQUENCE_LENGTH diff --git a/keras_hub/src/export/lite_rt.py b/keras_hub/src/export/lite_rt.py index 3bf5a89a04..269c0fc4c6 100644 --- a/keras_hub/src/export/lite_rt.py +++ b/keras_hub/src/export/lite_rt.py @@ -28,7 +28,7 @@ class LiteRTExporter(KerasHubExporter): def __init__(self, config: KerasHubExporterConfig, max_sequence_length: Optional[int] = None, aot_compile_targets: Optional[list] = None, - verbose: Optional[int] = None, + verbose: bool = False, **kwargs): """Initialize the LiteRT exporter. 
@@ -36,7 +36,7 @@ def __init__(self, config: KerasHubExporterConfig, config: Exporter configuration for the model max_sequence_length: Maximum sequence length for conversion aot_compile_targets: List of AOT compilation targets - verbose: Verbosity level + verbose: Enable verbose logging **kwargs: Additional arguments passed to the underlying exporter """ super().__init__(config, **kwargs) @@ -49,7 +49,7 @@ def __init__(self, config: KerasHubExporterConfig, self.max_sequence_length = max_sequence_length self.aot_compile_targets = aot_compile_targets - self.verbose = verbose or 0 + self.verbose = verbose # Get sequence length from model if not provided if self.max_sequence_length is None: @@ -69,10 +69,7 @@ def export(self, filepath: str) -> None: filepath: Path where to save the exported model (without extension) """ if self.verbose: - print(f"๐Ÿš€ Starting LiteRT export for {self.config.MODEL_TYPE} model...") - print(f" Model: {type(self.model).__name__}") - print(f" Expected inputs: {self.config.EXPECTED_INPUTS}") - print(f" Sequence length: {self.max_sequence_length}") + print(f"Starting LiteRT export for {self.config.MODEL_TYPE} model") # Ensure model is built with correct input structure self._ensure_model_built(self.max_sequence_length) @@ -80,9 +77,6 @@ def export(self, filepath: str) -> None: # Get the proper input signature for this model type input_signature = self.config.get_input_signature(self.max_sequence_length) - if self.verbose: - print(f" Input signature: {list(input_signature.keys())}") - # Create a wrapper that adapts the Keras-Hub model to work with Keras LiteRT exporter wrapped_model = self._create_export_wrapper() @@ -92,7 +86,7 @@ def export(self, filepath: str) -> None: input_signature=input_signature, max_sequence_length=self.max_sequence_length, aot_compile_targets=self.aot_compile_targets, - verbose=self.verbose, + verbose=1 if self.verbose else 0, **self.export_kwargs ) @@ -100,6 +94,13 @@ def export(self, filepath: str) -> None: # Export using the Keras exporter keras_exporter.export(filepath) + if self.verbose: + print(f"Export completed successfully to: {filepath}.tflite") + + except Exception as e: + raise RuntimeError(f"LiteRT export failed: {e}") from e + keras_exporter.export(filepath) + if self.verbose: print(f"โœ… Export completed successfully!") print(f"๐Ÿ“ Model saved to: {filepath}.tflite") @@ -120,12 +121,11 @@ def _create_export_wrapper(self): class KerasHubModelWrapper(keras.Model): """Wrapper that adapts Keras-Hub models for export.""" - def __init__(self, keras_hub_model, expected_inputs, input_signature, verbose=False): + def __init__(self, keras_hub_model, expected_inputs, input_signature): super().__init__() self.keras_hub_model = keras_hub_model self.expected_inputs = expected_inputs self.input_signature = input_signature - self.verbose = verbose # Create Input layers based on the input signature self._input_layers = [] @@ -139,9 +139,6 @@ def __init__(self, keras_hub_model, expected_inputs, input_signature, verbose=Fa name=input_name ) self._input_layers.append(input_layer) - - if self.verbose: - print(f"Created input layer: {input_name} - shape={spec.shape} dtype={spec.dtype}") # Store references to the original model's variables self._variables = keras_hub_model.variables @@ -181,8 +178,8 @@ def call(self, inputs, training=None, mask=None): if i < len(inputs): input_dict[input_name] = inputs[i] else: - # Handle missing inputs - this shouldn't happen but let's be safe - print(f"โš ๏ธ Missing input for {input_name}") + # Handle missing 
inputs + raise ValueError(f"Missing input for {input_name}") return self.keras_hub_model(input_dict, training=training, mask=mask) @@ -193,8 +190,7 @@ def get_config(self): return KerasHubModelWrapper( self.model, self.config.EXPECTED_INPUTS, - self.config.get_input_signature(self.max_sequence_length), - verbose=self.verbose + self.config.get_input_signature(self.max_sequence_length) ) diff --git a/keras_hub/src/export/registry.py b/keras_hub/src/export/registry.py index 8856735592..e125e220d3 100644 --- a/keras_hub/src/export/registry.py +++ b/keras_hub/src/export/registry.py @@ -10,7 +10,6 @@ Seq2SeqLMExporterConfig, TextModelExporterConfig ) -from keras_hub.src.export.lite_rt import LiteRTExporter def initialize_export_registry(): @@ -22,7 +21,12 @@ def initialize_export_registry(): ExporterRegistry.register_config("text_model", TextModelExporterConfig) # Register exporters for different formats - ExporterRegistry.register_exporter("lite_rt", LiteRTExporter) + try: + from keras_hub.src.export.lite_rt import LiteRTExporter + ExporterRegistry.register_exporter("lite_rt", LiteRTExporter) + except ImportError: + # LiteRT not available + pass def export_model(model, filepath: str, format: str = "lite_rt", **kwargs): @@ -35,7 +39,7 @@ def export_model(model, filepath: str, format: str = "lite_rt", **kwargs): model: The Keras-Hub model to export filepath: Path where to save the exported model (without extension) format: Export format (currently supports "lite_rt") - **kwargs: Additional arguments passed to the exporter (including verbose) + **kwargs: Additional arguments passed to the exporter """ # Ensure registry is initialized initialize_export_registry() @@ -56,35 +60,35 @@ def extend_export_method_for_keras_hub(): from keras_hub.src.models.task import Task import keras - # Store the original export method - original_export = Task.export if hasattr(Task, 'export') else keras.Model.export + # Store the original export method if it exists + original_export = getattr(Task, 'export', None) or getattr(keras.Model, 'export', None) - def keras_hub_export(self, filepath: str, format: str = "lite_rt", verbose=None, **kwargs): + def keras_hub_export(self, filepath: str, format: str = "lite_rt", verbose: bool = False, **kwargs): """Extended export method for Keras-Hub models. This method extends Keras' export functionality to properly handle Keras-Hub models that expect dictionary inputs. Args: - filepath: str. Path where to save the exported model (without extension) - format: str. Export format. Supports "lite_rt", "tf_saved_model", etc. - verbose: bool. Whether to print verbose output during export + filepath: Path where to save the exported model (without extension) + format: Export format. Supports "lite_rt", "tf_saved_model", etc. 
+ verbose: Whether to print verbose output during export **kwargs: Additional arguments passed to the exporter """ # Check if this is a Keras-Hub model that needs special handling if format == "lite_rt" and self._is_keras_hub_model(): # Use our Keras-Hub specific export logic - # Make sure we don't duplicate the verbose parameter - if verbose is not None and 'verbose' not in kwargs: - kwargs['verbose'] = verbose + kwargs['verbose'] = verbose export_model(self, filepath, format=format, **kwargs) else: # Fall back to the original Keras export method - original_export(self, filepath, format=format, verbose=verbose, **kwargs) + if original_export: + original_export(self, filepath, format=format, verbose=verbose, **kwargs) + else: + raise NotImplementedError(f"Export format '{format}' not supported for this model type") def _is_keras_hub_model(self): """Check if this model is a Keras-Hub model that needs special handling.""" - # Check if it's a Task (most Keras-Hub models inherit from Task) if hasattr(self, '__class__'): class_name = self.__class__.__name__ module_name = self.__class__.__module__ @@ -104,18 +108,17 @@ def _is_keras_hub_model(self): return False - # Add the helper method to the class + # Add the helper method and export method to the Task class Task._is_keras_hub_model = _is_keras_hub_model - - # Override the export method Task.export = keras_hub_export - print("โœ… Extended export method for Keras-Hub models") - + except ImportError: + # Task class not available, skip extension + pass except Exception as e: - print(f"โš ๏ธ Failed to extend export method for Keras-Hub models: {e}") - import traceback - traceback.print_exc() + # Log error but don't fail import + import warnings + warnings.warn(f"Failed to extend export method for Keras-Hub models: {e}") # Initialize the registry when this module is imported From 442fdd316dd70f6a3fd54cd7d138c09a292f6a76 Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Mon, 22 Sep 2025 11:06:05 +0530 Subject: [PATCH 11/19] reformat --- keras_hub/src/export/__init__.py | 2 +- keras_hub/src/export/base.py | 176 +++++++++++++++----------- keras_hub/src/export/configs.py | 207 ++++++++++++++++++------------- keras_hub/src/export/lite_rt.py | 157 +++++++++++++---------- keras_hub/src/export/registry.py | 109 ++++++++++------ keras_hub/src/models/__init__.py | 1 + keras_hub/src/models/backbone.py | 20 +-- 7 files changed, 400 insertions(+), 272 deletions(-) diff --git a/keras_hub/src/export/__init__.py b/keras_hub/src/export/__init__.py index f2063779ef..224ae3dec9 100644 --- a/keras_hub/src/export/__init__.py +++ b/keras_hub/src/export/__init__.py @@ -5,5 +5,5 @@ from keras_hub.src.export.configs import Seq2SeqLMExporterConfig from keras_hub.src.export.configs import TextClassifierExporterConfig from keras_hub.src.export.configs import TextModelExporterConfig -from keras_hub.src.export.lite_rt import export_lite_rt from keras_hub.src.export.lite_rt import LiteRTExporter +from keras_hub.src.export.lite_rt import export_lite_rt diff --git a/keras_hub/src/export/base.py b/keras_hub/src/export/base.py index 777e0b3727..5c58511192 100644 --- a/keras_hub/src/export/base.py +++ b/keras_hub/src/export/base.py @@ -1,15 +1,23 @@ """Base classes for Keras-Hub model exporters. -This module provides the foundation for exporting Keras-Hub models to various formats. -It follows the Optimum pattern of having different exporters for different model types and formats. +This module provides the foundation for exporting Keras-Hub models to various +formats. 
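# Condensed sketch of the dispatch that extend_export_method_for_keras_hub()
# installs on Task.export in the registry changes above. Names mirror the
# diffs; the Keras-Hub detection is reduced to a module-name check purely for
# illustration and is not the exact heuristic from the patch.
import keras

from keras_hub.src.export.registry import export_model


def sketch_task_export(task, filepath, format="lite_rt", verbose=False, **kwargs):
    if format == "lite_rt" and "keras_hub" in type(task).__module__:
        # Keras-Hub path: the registry resolves a config, then a LiteRTExporter.
        export_model(task, filepath, format=format, verbose=verbose, **kwargs)
    else:
        # Everything else falls back to the stock keras.Model.export behavior.
        keras.Model.export(task, filepath, format=format, **kwargs)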
It follows the Optimum pattern of having different exporters for +different model types and formats. """ -from abc import ABC, abstractmethod -from typing import Dict, Any, Optional, List, Type +from abc import ABC +from abc import abstractmethod +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Type try: import keras - from keras.src.export.export_utils import get_input_signature + # Removed unused import: from keras.src.export.export_utils import + # get_input_signature + KERAS_AVAILABLE = True except ImportError: KERAS_AVAILABLE = False @@ -18,23 +26,23 @@ class KerasHubExporterConfig(ABC): """Base configuration class for Keras-Hub model exporters. - + This class defines the interface for exporter configurations that specify how different types of Keras-Hub models should be exported. """ - + # Model type this exporter handles (e.g., "causal_lm", "text_classifier") MODEL_TYPE: str = None - + # Expected input structure for this model type EXPECTED_INPUTS: List[str] = [] - + # Default sequence length if not specified DEFAULT_SEQUENCE_LENGTH: int = 128 - + def __init__(self, model, **kwargs): """Initialize the exporter configuration. - + Args: model: The Keras-Hub model to export **kwargs: Additional configuration parameters @@ -42,7 +50,7 @@ def __init__(self, model, **kwargs): self.model = model self.config_kwargs = kwargs self._validate_model() - + def _validate_model(self): """Validate that the model is compatible with this exporter.""" if not self._is_model_compatible(): @@ -50,67 +58,83 @@ def _validate_model(self): f"Model {self.model.__class__.__name__} is not compatible " f"with {self.__class__.__name__}" ) - + @abstractmethod def _is_model_compatible(self) -> bool: """Check if the model is compatible with this exporter.""" pass - + @abstractmethod - def get_input_signature(self, sequence_length: Optional[int] = None) -> Dict[str, Any]: + def get_input_signature( + self, sequence_length: Optional[int] = None + ) -> Dict[str, Any]: """Get the input signature for this model type. - + Args: sequence_length: Optional sequence length for input tensors - + Returns: Dictionary mapping input names to their signatures """ pass - - def get_dummy_inputs(self, sequence_length: Optional[int] = None) -> Dict[str, Any]: + + def get_dummy_inputs( + self, sequence_length: Optional[int] = None + ) -> Dict[str, Any]: """Generate dummy inputs for model building and testing. 
- + Args: sequence_length: Optional sequence length for dummy inputs - + Returns: Dictionary of dummy inputs """ if sequence_length is None: sequence_length = self.DEFAULT_SEQUENCE_LENGTH - + dummy_inputs = {} - + # Common inputs for most Keras-Hub models if "token_ids" in self.EXPECTED_INPUTS: - dummy_inputs["token_ids"] = keras.ops.ones((1, sequence_length), dtype='int32') + dummy_inputs["token_ids"] = keras.ops.ones( + (1, sequence_length), dtype="int32" + ) if "padding_mask" in self.EXPECTED_INPUTS: - dummy_inputs["padding_mask"] = keras.ops.ones((1, sequence_length), dtype='bool') - + dummy_inputs["padding_mask"] = keras.ops.ones( + (1, sequence_length), dtype="bool" + ) + # Encoder-decoder specific inputs if "encoder_token_ids" in self.EXPECTED_INPUTS: - dummy_inputs["encoder_token_ids"] = keras.ops.ones((1, sequence_length), dtype='int32') + dummy_inputs["encoder_token_ids"] = keras.ops.ones( + (1, sequence_length), dtype="int32" + ) if "encoder_padding_mask" in self.EXPECTED_INPUTS: - dummy_inputs["encoder_padding_mask"] = keras.ops.ones((1, sequence_length), dtype='bool') + dummy_inputs["encoder_padding_mask"] = keras.ops.ones( + (1, sequence_length), dtype="bool" + ) if "decoder_token_ids" in self.EXPECTED_INPUTS: - dummy_inputs["decoder_token_ids"] = keras.ops.ones((1, sequence_length), dtype='int32') + dummy_inputs["decoder_token_ids"] = keras.ops.ones( + (1, sequence_length), dtype="int32" + ) if "decoder_padding_mask" in self.EXPECTED_INPUTS: - dummy_inputs["decoder_padding_mask"] = keras.ops.ones((1, sequence_length), dtype='bool') - + dummy_inputs["decoder_padding_mask"] = keras.ops.ones( + (1, sequence_length), dtype="bool" + ) + return dummy_inputs class KerasHubExporter(ABC): """Base class for Keras-Hub model exporters. - + This class provides the common interface for exporting Keras-Hub models to different formats (LiteRT, ONNX, etc.). """ - + def __init__(self, config: KerasHubExporterConfig, **kwargs): """Initialize the exporter. - + Args: config: Exporter configuration specifying model type and parameters **kwargs: Additional exporter-specific parameters @@ -118,114 +142,128 @@ def __init__(self, config: KerasHubExporterConfig, **kwargs): self.config = config self.model = config.model self.export_kwargs = kwargs - + @abstractmethod def export(self, filepath: str) -> None: """Export the model to the specified filepath. - + Args: filepath: Path where to save the exported model """ pass - - def _ensure_model_built(self, sequence_length: Optional[int] = None) -> None: + + def _ensure_model_built( + self, sequence_length: Optional[int] = None + ) -> None: """Ensure the model is properly built with correct input structure. - + Args: sequence_length: Optional sequence length for dummy inputs """ if not self.model.built: dummy_inputs = self.config.get_dummy_inputs(sequence_length) - + try: # Build the model with the correct input structure _ = self.model(dummy_inputs, training=False) except Exception as e: # Try alternative approach using build() method try: - input_shapes = {key: tensor.shape for key, tensor in dummy_inputs.items()} + input_shapes = { + key: tensor.shape + for key, tensor in dummy_inputs.items() + } self.model.build(input_shape=input_shapes) except Exception: raise ValueError( - f"Failed to build model: {e}. Please ensure the model is properly constructed." + f"Failed to build model: {e}. Please ensure the model " + "is properly constructed." 
) class ExporterRegistry: """Registry for mapping model types to their appropriate exporters.""" - + _configs = {} _exporters = {} - + @classmethod - def register_config(cls, model_type: str, config_class: Type[KerasHubExporterConfig]) -> None: + def register_config( + cls, model_type: str, config_class: Type[KerasHubExporterConfig] + ) -> None: """Register a configuration class for a model type. - + Args: model_type: The model type (e.g., "causal_lm") config_class: The configuration class """ cls._configs[model_type] = config_class - + @classmethod - def register_exporter(cls, format_name: str, exporter_class: Type[KerasHubExporter]) -> None: + def register_exporter( + cls, format_name: str, exporter_class: Type[KerasHubExporter] + ) -> None: """Register an exporter class for a format. - + Args: format_name: The export format (e.g., "lite_rt") exporter_class: The exporter class """ cls._exporters[format_name] = exporter_class - + @classmethod def get_config_for_model(cls, model) -> KerasHubExporterConfig: """Get the appropriate configuration for a model. - + Args: model: The Keras-Hub model - + Returns: An appropriate exporter configuration instance - + Raises: ValueError: If no configuration is found for the model type """ model_type = cls._detect_model_type(model) - + if model_type not in cls._configs: - raise ValueError(f"No configuration found for model type: {model_type}") - + raise ValueError( + f"No configuration found for model type: {model_type}" + ) + config_class = cls._configs[model_type] return config_class(model) - + @classmethod - def get_exporter(cls, format_name: str, config: KerasHubExporterConfig, **kwargs) -> KerasHubExporter: + def get_exporter( + cls, format_name: str, config: KerasHubExporterConfig, **kwargs + ) -> KerasHubExporter: """Get an exporter for the specified format. - + Args: format_name: The export format config: The exporter configuration **kwargs: Additional parameters for the exporter - + Returns: An appropriate exporter instance - + Raises: ValueError: If no exporter is found for the format """ if format_name not in cls._exporters: raise ValueError(f"No exporter found for format: {format_name}") - + exporter_class = cls._exporters[format_name] return exporter_class(config, **kwargs) - + @classmethod def _detect_model_type(cls, model) -> str: """Detect the model type from the model instance. - + Args: model: The Keras-Hub model - + Returns: The detected model type """ @@ -236,16 +274,16 @@ def _detect_model_type(cls, model) -> str: except ImportError: CausalLM = None Seq2SeqLM = None - + model_class_name = model.__class__.__name__ - + if CausalLM and isinstance(model, CausalLM): return "causal_lm" - elif 'TextClassifier' in model_class_name: + elif "TextClassifier" in model_class_name: return "text_classifier" elif Seq2SeqLM and isinstance(model, Seq2SeqLM): return "seq2seq_lm" - elif 'ImageClassifier' in model_class_name: + elif "ImageClassifier" in model_class_name: return "image_classifier" else: # Default to text model for generic Keras-Hub models diff --git a/keras_hub/src/export/configs.py b/keras_hub/src/export/configs.py index b6e563415b..f933f0791a 100644 --- a/keras_hub/src/export/configs.py +++ b/keras_hub/src/export/configs.py @@ -4,206 +4,241 @@ of Keras-Hub models, following the Optimum pattern. 
""" -from typing import Dict, Any, Optional -from keras_hub.src.export.base import KerasHubExporterConfig +from typing import Any +from typing import Dict +from typing import Optional + from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.export.base import KerasHubExporterConfig @keras_hub_export("keras_hub.export.CausalLMExporterConfig") class CausalLMExporterConfig(KerasHubExporterConfig): """Exporter configuration for Causal Language Models (GPT, LLaMA, etc.).""" - + MODEL_TYPE = "causal_lm" EXPECTED_INPUTS = ["token_ids", "padding_mask"] DEFAULT_SEQUENCE_LENGTH = 128 - + def _is_model_compatible(self) -> bool: """Check if model is a causal language model.""" try: from keras_hub.src.models.causal_lm import CausalLM + return isinstance(self.model, CausalLM) except ImportError: # Fallback to class name checking - return 'CausalLM' in self.model.__class__.__name__ - - def get_input_signature(self, sequence_length: Optional[int] = None) -> Dict[str, Any]: + return "CausalLM" in self.model.__class__.__name__ + + def get_input_signature( + self, sequence_length: Optional[int] = None + ) -> Dict[str, Any]: """Get input signature for causal LM models. - + Args: - sequence_length: Optional sequence length. If None, will be inferred from model. - + sequence_length: Optional sequence length. If None, will be inferred + from model. + Returns: Dictionary mapping input names to their specifications """ if sequence_length is None: sequence_length = self._get_sequence_length() - + import keras + return { "token_ids": keras.layers.InputSpec( - shape=(None, sequence_length), - dtype='int32', - name='token_ids' + shape=(None, sequence_length), dtype="int32", name="token_ids" ), "padding_mask": keras.layers.InputSpec( - shape=(None, sequence_length), - dtype='bool', - name='padding_mask' - ) + shape=(None, sequence_length), dtype="bool", name="padding_mask" + ), } - + def _get_sequence_length(self) -> int: """Get sequence length from model or use default.""" - if hasattr(self.model, 'preprocessor') and self.model.preprocessor: - return getattr(self.model.preprocessor, 'sequence_length', self.DEFAULT_SEQUENCE_LENGTH) + if hasattr(self.model, "preprocessor") and self.model.preprocessor: + return getattr( + self.model.preprocessor, + "sequence_length", + self.DEFAULT_SEQUENCE_LENGTH, + ) return self.DEFAULT_SEQUENCE_LENGTH @keras_hub_export("keras_hub.export.TextClassifierExporterConfig") class TextClassifierExporterConfig(KerasHubExporterConfig): """Exporter configuration for Text Classification models.""" - + MODEL_TYPE = "text_classifier" EXPECTED_INPUTS = ["token_ids", "padding_mask"] DEFAULT_SEQUENCE_LENGTH = 128 - + def _is_model_compatible(self) -> bool: """Check if model is a text classifier.""" - return 'TextClassifier' in self.model.__class__.__name__ - - def get_input_signature(self, sequence_length: Optional[int] = None) -> Dict[str, Any]: + return "TextClassifier" in self.model.__class__.__name__ + + def get_input_signature( + self, sequence_length: Optional[int] = None + ) -> Dict[str, Any]: """Get input signature for text classifier models. - + Args: - sequence_length: Optional sequence length. If None, will be inferred from model. - + sequence_length: Optional sequence length. If None, will be inferred + from model. 
+ Returns: Dictionary mapping input names to their specifications """ if sequence_length is None: sequence_length = self._get_sequence_length() - + import keras + return { "token_ids": keras.layers.InputSpec( - shape=(None, sequence_length), - dtype='int32', - name='token_ids' + shape=(None, sequence_length), dtype="int32", name="token_ids" ), "padding_mask": keras.layers.InputSpec( - shape=(None, sequence_length), - dtype='bool', - name='padding_mask' - ) + shape=(None, sequence_length), dtype="bool", name="padding_mask" + ), } - + def _get_sequence_length(self) -> int: """Get sequence length from model or use default.""" - if hasattr(self.model, 'preprocessor') and self.model.preprocessor: - return getattr(self.model.preprocessor, 'sequence_length', self.DEFAULT_SEQUENCE_LENGTH) + if hasattr(self.model, "preprocessor") and self.model.preprocessor: + return getattr( + self.model.preprocessor, + "sequence_length", + self.DEFAULT_SEQUENCE_LENGTH, + ) return self.DEFAULT_SEQUENCE_LENGTH @keras_hub_export("keras_hub.export.Seq2SeqLMExporterConfig") class Seq2SeqLMExporterConfig(KerasHubExporterConfig): """Exporter configuration for Sequence-to-Sequence Language Models.""" - + MODEL_TYPE = "seq2seq_lm" - EXPECTED_INPUTS = ["encoder_token_ids", "encoder_padding_mask", "decoder_token_ids", "decoder_padding_mask"] + EXPECTED_INPUTS = [ + "encoder_token_ids", + "encoder_padding_mask", + "decoder_token_ids", + "decoder_padding_mask", + ] DEFAULT_SEQUENCE_LENGTH = 128 - + def _is_model_compatible(self) -> bool: """Check if model is a seq2seq language model.""" try: from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM + return isinstance(self.model, Seq2SeqLM) except ImportError: - return 'Seq2SeqLM' in self.model.__class__.__name__ - - def get_input_signature(self, sequence_length: Optional[int] = None) -> Dict[str, Any]: + return "Seq2SeqLM" in self.model.__class__.__name__ + + def get_input_signature( + self, sequence_length: Optional[int] = None + ) -> Dict[str, Any]: """Get input signature for seq2seq models. - + Args: - sequence_length: Optional sequence length. If None, will be inferred from model. - + sequence_length: Optional sequence length. If None, will be inferred + from model. 
+ Returns: Dictionary mapping input names to their specifications """ if sequence_length is None: sequence_length = self._get_sequence_length() - + import keras + return { "encoder_token_ids": keras.layers.InputSpec( - shape=(None, sequence_length), - dtype='int32', - name='encoder_token_ids' + shape=(None, sequence_length), + dtype="int32", + name="encoder_token_ids", ), "encoder_padding_mask": keras.layers.InputSpec( - shape=(None, sequence_length), - dtype='bool', - name='encoder_padding_mask' + shape=(None, sequence_length), + dtype="bool", + name="encoder_padding_mask", ), "decoder_token_ids": keras.layers.InputSpec( - shape=(None, sequence_length), - dtype='int32', - name='decoder_token_ids' + shape=(None, sequence_length), + dtype="int32", + name="decoder_token_ids", ), "decoder_padding_mask": keras.layers.InputSpec( - shape=(None, sequence_length), - dtype='bool', - name='decoder_padding_mask' - ) + shape=(None, sequence_length), + dtype="bool", + name="decoder_padding_mask", + ), } - + def _get_sequence_length(self) -> int: """Get sequence length from model or use default.""" - if hasattr(self.model, 'preprocessor') and self.model.preprocessor: - return getattr(self.model.preprocessor, 'sequence_length', self.DEFAULT_SEQUENCE_LENGTH) + if hasattr(self.model, "preprocessor") and self.model.preprocessor: + return getattr( + self.model.preprocessor, + "sequence_length", + self.DEFAULT_SEQUENCE_LENGTH, + ) return self.DEFAULT_SEQUENCE_LENGTH @keras_hub_export("keras_hub.export.TextModelExporterConfig") class TextModelExporterConfig(KerasHubExporterConfig): """Generic exporter configuration for text models.""" - + MODEL_TYPE = "text_model" EXPECTED_INPUTS = ["token_ids", "padding_mask"] DEFAULT_SEQUENCE_LENGTH = 128 - + def _is_model_compatible(self) -> bool: """Check if model is a text model (fallback).""" - # This is a fallback config for text models that don't fit other categories - return hasattr(self.model, 'preprocessor') and self.model.preprocessor and hasattr(self.model.preprocessor, 'tokenizer') - - def get_input_signature(self, sequence_length: Optional[int] = None) -> Dict[str, Any]: + # This is a fallback config for text models that don't fit other + # categories + return ( + hasattr(self.model, "preprocessor") + and self.model.preprocessor + and hasattr(self.model.preprocessor, "tokenizer") + ) + + def get_input_signature( + self, sequence_length: Optional[int] = None + ) -> Dict[str, Any]: """Get input signature for generic text models. - + Args: - sequence_length: Optional sequence length. If None, will be inferred from model. - + sequence_length: Optional sequence length. If None, will be inferred + from model. 
+ Returns: Dictionary mapping input names to their specifications """ if sequence_length is None: sequence_length = self._get_sequence_length() - + import keras + return { "token_ids": keras.layers.InputSpec( - shape=(None, sequence_length), - dtype='int32', - name='token_ids' + shape=(None, sequence_length), dtype="int32", name="token_ids" ), "padding_mask": keras.layers.InputSpec( - shape=(None, sequence_length), - dtype='bool', - name='padding_mask' - ) + shape=(None, sequence_length), dtype="bool", name="padding_mask" + ), } - + def _get_sequence_length(self) -> int: """Get sequence length from model or use default.""" - if hasattr(self.model, 'preprocessor') and self.model.preprocessor: - return getattr(self.model.preprocessor, 'sequence_length', self.DEFAULT_SEQUENCE_LENGTH) + if hasattr(self.model, "preprocessor") and self.model.preprocessor: + return getattr( + self.model.preprocessor, + "sequence_length", + self.DEFAULT_SEQUENCE_LENGTH, + ) return self.DEFAULT_SEQUENCE_LENGTH diff --git a/keras_hub/src/export/lite_rt.py b/keras_hub/src/export/lite_rt.py index 269c0fc4c6..359550ac59 100644 --- a/keras_hub/src/export/lite_rt.py +++ b/keras_hub/src/export/lite_rt.py @@ -1,16 +1,20 @@ """LiteRT exporter for Keras-Hub models. -This module provides LiteRT export functionality specifically designed for Keras-Hub models, -handling their unique input structures and requirements. +This module provides LiteRT export functionality specifically designed for +Keras-Hub models, handling their unique input structures and requirements. """ from typing import Optional -from keras_hub.src.export.base import KerasHubExporter, KerasHubExporterConfig from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.export.base import KerasHubExporter +from keras_hub.src.export.base import KerasHubExporterConfig try: - from keras.src.export.lite_rt_exporter import LiteRTExporter as KerasLiteRTExporter + from keras.src.export.lite_rt_exporter import ( + LiteRTExporter as KerasLiteRTExporter, + ) + KERAS_LITE_RT_AVAILABLE = True except ImportError: KERAS_LITE_RT_AVAILABLE = False @@ -20,18 +24,22 @@ @keras_hub_export("keras_hub.export.LiteRTExporter") class LiteRTExporter(KerasHubExporter): """LiteRT exporter for Keras-Hub models. - - This exporter handles the conversion of Keras-Hub models to TensorFlow Lite format, - properly managing the dictionary input structures that Keras-Hub models expect. + + This exporter handles the conversion of Keras-Hub models to TensorFlow Lite + format, properly managing the dictionary input structures that Keras-Hub + models expect. """ - - def __init__(self, config: KerasHubExporterConfig, - max_sequence_length: Optional[int] = None, - aot_compile_targets: Optional[list] = None, - verbose: bool = False, - **kwargs): + + def __init__( + self, + config: KerasHubExporterConfig, + max_sequence_length: Optional[int] = None, + aot_compile_targets: Optional[list] = None, + verbose: bool = False, + **kwargs, + ): """Initialize the LiteRT exporter. - + Args: config: Exporter configuration for the model max_sequence_length: Maximum sequence length for conversion @@ -40,46 +48,49 @@ def __init__(self, config: KerasHubExporterConfig, **kwargs: Additional arguments passed to the underlying exporter """ super().__init__(config, **kwargs) - + if not KERAS_LITE_RT_AVAILABLE: raise ImportError( "Keras LiteRT exporter is not available. " "Make sure you have Keras with LiteRT support installed." 
) - + self.max_sequence_length = max_sequence_length self.aot_compile_targets = aot_compile_targets self.verbose = verbose - + # Get sequence length from model if not provided if self.max_sequence_length is None: - if hasattr(self.model, 'preprocessor') and self.model.preprocessor: + if hasattr(self.model, "preprocessor") and self.model.preprocessor: self.max_sequence_length = getattr( - self.model.preprocessor, - 'sequence_length', - self.config.DEFAULT_SEQUENCE_LENGTH + self.model.preprocessor, + "sequence_length", + self.config.DEFAULT_SEQUENCE_LENGTH, ) else: self.max_sequence_length = self.config.DEFAULT_SEQUENCE_LENGTH - + def export(self, filepath: str) -> None: """Export the Keras-Hub model to LiteRT format. - + Args: filepath: Path where to save the exported model (without extension) """ if self.verbose: print(f"Starting LiteRT export for {self.config.MODEL_TYPE} model") - + # Ensure model is built with correct input structure self._ensure_model_built(self.max_sequence_length) - + # Get the proper input signature for this model type - input_signature = self.config.get_input_signature(self.max_sequence_length) - - # Create a wrapper that adapts the Keras-Hub model to work with Keras LiteRT exporter + input_signature = self.config.get_input_signature( + self.max_sequence_length + ) + + # Create a wrapper that adapts the Keras-Hub model to work with Keras + # LiteRT exporter wrapped_model = self._create_export_wrapper() - + # Create the Keras LiteRT exporter with the wrapped model keras_exporter = KerasLiteRTExporter( wrapped_model, @@ -87,46 +98,49 @@ def export(self, filepath: str) -> None: max_sequence_length=self.max_sequence_length, aot_compile_targets=self.aot_compile_targets, verbose=1 if self.verbose else 0, - **self.export_kwargs + **self.export_kwargs, ) - + try: # Export using the Keras exporter keras_exporter.export(filepath) - + if self.verbose: print(f"Export completed successfully to: {filepath}.tflite") - + except Exception as e: raise RuntimeError(f"LiteRT export failed: {e}") from e keras_exporter.export(filepath) - + if self.verbose: - print(f"โœ… Export completed successfully!") + print("โœ… Export completed successfully!") print(f"๐Ÿ“ Model saved to: {filepath}.tflite") - + except Exception as e: if self.verbose: print(f"โŒ Export failed: {e}") raise - + def _create_export_wrapper(self): """Create a wrapper model that handles the input structure conversion. - - This wrapper converts between the list-based inputs that Keras LiteRT exporter - provides and the dictionary-based inputs that Keras-Hub models expect. + + This wrapper converts between the list-based inputs that Keras LiteRT + exporter provides and the dictionary-based inputs that Keras-Hub models + expect. 
""" import keras - + class KerasHubModelWrapper(keras.Model): """Wrapper that adapts Keras-Hub models for export.""" - - def __init__(self, keras_hub_model, expected_inputs, input_signature): + + def __init__( + self, keras_hub_model, expected_inputs, input_signature + ): super().__init__() self.keras_hub_model = keras_hub_model self.expected_inputs = expected_inputs self.input_signature = input_signature - + # Create Input layers based on the input signature self._input_layers = [] for input_name in expected_inputs: @@ -136,42 +150,47 @@ def __init__(self, keras_hub_model, expected_inputs, input_signature): input_layer = keras.layers.Input( shape=spec.shape[1:], # Remove batch dimension dtype=spec.dtype, - name=input_name + name=input_name, ) self._input_layers.append(input_layer) - + # Store references to the original model's variables self._variables = keras_hub_model.variables self._trainable_variables = keras_hub_model.trainable_variables - self._non_trainable_variables = keras_hub_model.non_trainable_variables - - @property + self._non_trainable_variables = ( + keras_hub_model.non_trainable_variables + ) + + @property def variables(self): return self._variables - + @property def trainable_variables(self): return self._trainable_variables - + @property def non_trainable_variables(self): return self._non_trainable_variables - + @property def inputs(self): """Return the input layers for the Keras exporter to use.""" return self._input_layers - + def call(self, inputs, training=None, mask=None): - """Convert list inputs to dictionary format and call the original model.""" + """Convert list inputs to dictionary format and call the + original model.""" if isinstance(inputs, dict): # Already in dictionary format - return self.keras_hub_model(inputs, training=training, mask=mask) - + return self.keras_hub_model( + inputs, training=training, mask=mask + ) + # Convert list inputs to dictionary format if not isinstance(inputs, (list, tuple)): inputs = [inputs] - + # Map inputs to expected dictionary structure input_dict = {} for i, input_name in enumerate(self.expected_inputs): @@ -180,17 +199,19 @@ def call(self, inputs, training=None, mask=None): else: # Handle missing inputs raise ValueError(f"Missing input for {input_name}") - - return self.keras_hub_model(input_dict, training=training, mask=mask) - + + return self.keras_hub_model( + input_dict, training=training, mask=mask + ) + def get_config(self): """Return the configuration of the wrapped model.""" return self.keras_hub_model.get_config() - + return KerasHubModelWrapper( - self.model, - self.config.EXPECTED_INPUTS, - self.config.get_input_signature(self.max_sequence_length) + self.model, + self.config.EXPECTED_INPUTS, + self.config.get_input_signature(self.max_sequence_length), ) @@ -198,20 +219,20 @@ def get_config(self): @keras_hub_export("keras_hub.export.export_lite_rt") def export_lite_rt(model, filepath: str, **kwargs) -> None: """Export a Keras-Hub model to LiteRT format. - + This is a convenience function that automatically detects the model type and exports it using the appropriate configuration. 
- + Args: model: The Keras-Hub model to export filepath: Path where to save the exported model (without extension) **kwargs: Additional arguments passed to the exporter """ from keras_hub.src.export.base import ExporterRegistry - + # Get the appropriate configuration for this model config = ExporterRegistry.get_config_for_model(model) - + # Create and use the LiteRT exporter exporter = LiteRTExporter(config, **kwargs) exporter.export(filepath) diff --git a/keras_hub/src/export/registry.py b/keras_hub/src/export/registry.py index e125e220d3..45c2081250 100644 --- a/keras_hub/src/export/registry.py +++ b/keras_hub/src/export/registry.py @@ -1,28 +1,31 @@ """Registry initialization for Keras-Hub export functionality. -This module initializes the export registry with available configurations and exporters. +This module initializes the export registry with available configurations and +exporters. """ from keras_hub.src.export.base import ExporterRegistry -from keras_hub.src.export.configs import ( - CausalLMExporterConfig, - TextClassifierExporterConfig, - Seq2SeqLMExporterConfig, - TextModelExporterConfig -) +from keras_hub.src.export.configs import CausalLMExporterConfig +from keras_hub.src.export.configs import Seq2SeqLMExporterConfig +from keras_hub.src.export.configs import TextClassifierExporterConfig +from keras_hub.src.export.configs import TextModelExporterConfig def initialize_export_registry(): - """Initialize the export registry with available configurations and exporters.""" + """Initialize the export registry with available configurations and + exporters.""" # Register configurations for different model types ExporterRegistry.register_config("causal_lm", CausalLMExporterConfig) - ExporterRegistry.register_config("text_classifier", TextClassifierExporterConfig) + ExporterRegistry.register_config( + "text_classifier", TextClassifierExporterConfig + ) ExporterRegistry.register_config("seq2seq_lm", Seq2SeqLMExporterConfig) ExporterRegistry.register_config("text_model", TextModelExporterConfig) # Register exporters for different formats try: from keras_hub.src.export.lite_rt import LiteRTExporter + ExporterRegistry.register_exporter("lite_rt", LiteRTExporter) except ImportError: # LiteRT not available @@ -31,10 +34,10 @@ def initialize_export_registry(): def export_model(model, filepath: str, format: str = "lite_rt", **kwargs): """Export a Keras-Hub model to the specified format. - + This is the main export function that automatically detects the model type and uses the appropriate exporter configuration. 
- + Args: model: The Keras-Hub model to export filepath: Path where to save the exported model (without extension) @@ -43,82 +46,108 @@ def export_model(model, filepath: str, format: str = "lite_rt", **kwargs): """ # Ensure registry is initialized initialize_export_registry() - + # Get the appropriate configuration for this model config = ExporterRegistry.get_config_for_model(model) - + # Get the exporter for the specified format exporter = ExporterRegistry.get_exporter(format, config, **kwargs) - + # Export the model exporter.export(filepath) def extend_export_method_for_keras_hub(): - """Extend the export method for Keras-Hub models to handle dictionary inputs.""" + """Extend the export method for Keras-Hub models to handle dictionary + inputs.""" try: - from keras_hub.src.models.task import Task import keras - + + from keras_hub.src.models.task import Task + # Store the original export method if it exists - original_export = getattr(Task, 'export', None) or getattr(keras.Model, 'export', None) - - def keras_hub_export(self, filepath: str, format: str = "lite_rt", verbose: bool = False, **kwargs): + original_export = getattr(Task, "export", None) or getattr( + keras.Model, "export", None + ) + + def keras_hub_export( + self, + filepath: str, + format: str = "lite_rt", + verbose: bool = False, + **kwargs, + ): """Extended export method for Keras-Hub models. - + This method extends Keras' export functionality to properly handle Keras-Hub models that expect dictionary inputs. - + Args: - filepath: Path where to save the exported model (without extension) - format: Export format. Supports "lite_rt", "tf_saved_model", etc. + filepath: Path where to save the exported model (without + extension) + format: Export format. Supports "lite_rt", "tf_saved_model", + etc. 
verbose: Whether to print verbose output during export **kwargs: Additional arguments passed to the exporter """ # Check if this is a Keras-Hub model that needs special handling if format == "lite_rt" and self._is_keras_hub_model(): # Use our Keras-Hub specific export logic - kwargs['verbose'] = verbose + kwargs["verbose"] = verbose export_model(self, filepath, format=format, **kwargs) else: # Fall back to the original Keras export method if original_export: - original_export(self, filepath, format=format, verbose=verbose, **kwargs) + original_export( + self, filepath, format=format, verbose=verbose, **kwargs + ) else: - raise NotImplementedError(f"Export format '{format}' not supported for this model type") - + raise NotImplementedError( + f"Export format '{format}' not supported for this " + "model type" + ) + def _is_keras_hub_model(self): - """Check if this model is a Keras-Hub model that needs special handling.""" - if hasattr(self, '__class__'): + """Check if this model is a Keras-Hub model that needs special + handling.""" + if hasattr(self, "__class__"): class_name = self.__class__.__name__ module_name = self.__class__.__module__ - + # Check if it's from keras_hub package - if 'keras_hub' in module_name: + if "keras_hub" in module_name: return True - + # Check if it has keras-hub specific attributes - if hasattr(self, 'preprocessor') and hasattr(self, 'backbone'): + if hasattr(self, "preprocessor") and hasattr(self, "backbone"): return True - + # Check for common Keras-Hub model names - keras_hub_model_names = ['CausalLM', 'Seq2SeqLM', 'TextClassifier', 'ImageClassifier'] + keras_hub_model_names = [ + "CausalLM", + "Seq2SeqLM", + "TextClassifier", + "ImageClassifier", + ] if any(name in class_name for name in keras_hub_model_names): return True - + return False - + # Add the helper method and export method to the Task class Task._is_keras_hub_model = _is_keras_hub_model Task.export = keras_hub_export - + except ImportError: # Task class not available, skip extension pass except Exception as e: # Log error but don't fail import import warnings - warnings.warn(f"Failed to extend export method for Keras-Hub models: {e}") + + warnings.warn( + f"Failed to extend export method for Keras-Hub models: {e}" + ) # Initialize the registry when this module is imported diff --git a/keras_hub/src/models/__init__.py b/keras_hub/src/models/__init__.py index 896e87678e..c0ada3d741 100644 --- a/keras_hub/src/models/__init__.py +++ b/keras_hub/src/models/__init__.py @@ -8,6 +8,7 @@ try: from keras_hub.src.export.registry import extend_export_method_for_keras_hub from keras_hub.src.export.registry import initialize_export_registry + # Initialize export functionality initialize_export_registry() extend_export_method_for_keras_hub() diff --git a/keras_hub/src/models/backbone.py b/keras_hub/src/models/backbone.py index cc098bb1c6..a43f9d2582 100644 --- a/keras_hub/src/models/backbone.py +++ b/keras_hub/src/models/backbone.py @@ -344,38 +344,42 @@ def _make_spec(t): def _trackable_children(self, save_type=None, **kwargs): """Override to prevent _DictWrapper issues during TensorFlow export. - + This method filters out problematic _DictWrapper objects that cause TypeError during SavedModel introspection, while preserving all essential trackable components. 
""" children = super()._trackable_children(save_type, **kwargs) - + # Import _DictWrapper safely try: from tensorflow.python.trackable.data_structures import _DictWrapper except ImportError: return children - + clean_children = {} for name, child in children.items(): # Handle _DictWrapper objects if isinstance(child, _DictWrapper): try: # For list-like _DictWrapper (e.g., transformer_layers) - if hasattr(child, '_data') and isinstance(child._data, list): + if hasattr(child, "_data") and isinstance( + child._data, list + ): # Create a clean list of the trackable items clean_list = [] for item in child._data: - if hasattr(item, '_trackable_children'): + if hasattr(item, "_trackable_children"): clean_list.append(item) if clean_list: clean_children[name] = clean_list # For dict-like _DictWrapper - elif hasattr(child, '_data') and isinstance(child._data, dict): + elif hasattr(child, "_data") and isinstance( + child._data, dict + ): clean_dict = {} for k, v in child._data.items(): - if hasattr(v, '_trackable_children'): + if hasattr(v, "_trackable_children"): clean_dict[k] = v if clean_dict: clean_children[name] = clean_dict @@ -386,5 +390,5 @@ def _trackable_children(self, save_type=None, **kwargs): else: # Keep non-_DictWrapper children as-is clean_children[name] = child - + return clean_children From 5446e2a8d2a62fb17a2941521cc2e42987fc58ea Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Mon, 22 Sep 2025 11:09:37 +0530 Subject: [PATCH 12/19] Add export submodule to keras_hub API Introduces the keras_hub.api.export submodule and updates the main API to expose it. The new export module imports various exporter configs and functions from the internal export package, making them available through the public API. --- keras_hub/api/__init__.py | 1 + keras_hub/api/export/__init__.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+) create mode 100644 keras_hub/api/export/__init__.py diff --git a/keras_hub/api/__init__.py b/keras_hub/api/__init__.py index 2aa98bf3f9..810f8fa921 100644 --- a/keras_hub/api/__init__.py +++ b/keras_hub/api/__init__.py @@ -4,6 +4,7 @@ since your modifications would be overwritten. """ +from keras_hub import export as export from keras_hub import layers as layers from keras_hub import metrics as metrics from keras_hub import models as models diff --git a/keras_hub/api/export/__init__.py b/keras_hub/api/export/__init__.py new file mode 100644 index 0000000000..16e5e1817f --- /dev/null +++ b/keras_hub/api/export/__init__.py @@ -0,0 +1,20 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. 
+""" + +from keras_hub.src.export.configs import ( + CausalLMExporterConfig as CausalLMExporterConfig, +) +from keras_hub.src.export.configs import ( + Seq2SeqLMExporterConfig as Seq2SeqLMExporterConfig, +) +from keras_hub.src.export.configs import ( + TextClassifierExporterConfig as TextClassifierExporterConfig, +) +from keras_hub.src.export.configs import ( + TextModelExporterConfig as TextModelExporterConfig, +) +from keras_hub.src.export.lite_rt import LiteRTExporter as LiteRTExporter +from keras_hub.src.export.lite_rt import export_lite_rt as export_lite_rt From 5c31d88b020bd2519742fa1ef521aeb36024cf8a Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Mon, 22 Sep 2025 11:14:23 +0530 Subject: [PATCH 13/19] reformat --- keras_hub/api/export/__init__.py | 1 - keras_hub/src/export/base.py | 62 +++++++++------------ keras_hub/src/export/configs.py | 96 ++++++++++++++++++++------------ keras_hub/src/export/lite_rt.py | 16 ++---- keras_hub/src/export/registry.py | 8 +-- 5 files changed, 95 insertions(+), 88 deletions(-) diff --git a/keras_hub/api/export/__init__.py b/keras_hub/api/export/__init__.py index 16e5e1817f..311f8aa323 100644 --- a/keras_hub/api/export/__init__.py +++ b/keras_hub/api/export/__init__.py @@ -17,4 +17,3 @@ TextModelExporterConfig as TextModelExporterConfig, ) from keras_hub.src.export.lite_rt import LiteRTExporter as LiteRTExporter -from keras_hub.src.export.lite_rt import export_lite_rt as export_lite_rt diff --git a/keras_hub/src/export/base.py b/keras_hub/src/export/base.py index 5c58511192..57b1d06b12 100644 --- a/keras_hub/src/export/base.py +++ b/keras_hub/src/export/base.py @@ -7,11 +7,6 @@ from abc import ABC from abc import abstractmethod -from typing import Any -from typing import Dict -from typing import List -from typing import Optional -from typing import Type try: import keras @@ -32,13 +27,13 @@ class KerasHubExporterConfig(ABC): """ # Model type this exporter handles (e.g., "causal_lm", "text_classifier") - MODEL_TYPE: str = None + MODEL_TYPE = None # Expected input structure for this model type - EXPECTED_INPUTS: List[str] = [] + EXPECTED_INPUTS = [] # Default sequence length if not specified - DEFAULT_SEQUENCE_LENGTH: int = 128 + DEFAULT_SEQUENCE_LENGTH = 128 def __init__(self, model, **kwargs): """Initialize the exporter configuration. @@ -60,34 +55,34 @@ def _validate_model(self): ) @abstractmethod - def _is_model_compatible(self) -> bool: - """Check if the model is compatible with this exporter.""" + def _is_model_compatible(self): + """Check if the model is compatible with this exporter. + + Returns: + bool: True if compatible, False otherwise + """ pass @abstractmethod - def get_input_signature( - self, sequence_length: Optional[int] = None - ) -> Dict[str, Any]: + def get_input_signature(self, sequence_length=None): """Get the input signature for this model type. Args: sequence_length: Optional sequence length for input tensors Returns: - Dictionary mapping input names to their signatures + Dict[str, Any]: Dictionary mapping input names to their signatures """ pass - def get_dummy_inputs( - self, sequence_length: Optional[int] = None - ) -> Dict[str, Any]: + def get_dummy_inputs(self, sequence_length=None): """Generate dummy inputs for model building and testing. 
Args: sequence_length: Optional sequence length for dummy inputs Returns: - Dictionary of dummy inputs + Dict[str, Any]: Dictionary of dummy inputs """ if sequence_length is None: sequence_length = self.DEFAULT_SEQUENCE_LENGTH @@ -132,7 +127,7 @@ class KerasHubExporter(ABC): to different formats (LiteRT, ONNX, etc.). """ - def __init__(self, config: KerasHubExporterConfig, **kwargs): + def __init__(self, config, **kwargs): """Initialize the exporter. Args: @@ -144,7 +139,7 @@ def __init__(self, config: KerasHubExporterConfig, **kwargs): self.export_kwargs = kwargs @abstractmethod - def export(self, filepath: str) -> None: + def export(self, filepath): """Export the model to the specified filepath. Args: @@ -152,9 +147,7 @@ def export(self, filepath: str) -> None: """ pass - def _ensure_model_built( - self, sequence_length: Optional[int] = None - ) -> None: + def _ensure_model_built(self, sequence_length=None): """Ensure the model is properly built with correct input structure. Args: @@ -188,9 +181,7 @@ class ExporterRegistry: _exporters = {} @classmethod - def register_config( - cls, model_type: str, config_class: Type[KerasHubExporterConfig] - ) -> None: + def register_config(cls, model_type, config_class): """Register a configuration class for a model type. Args: @@ -200,9 +191,7 @@ def register_config( cls._configs[model_type] = config_class @classmethod - def register_exporter( - cls, format_name: str, exporter_class: Type[KerasHubExporter] - ) -> None: + def register_exporter(cls, format_name, exporter_class): """Register an exporter class for a format. Args: @@ -212,14 +201,15 @@ def register_exporter( cls._exporters[format_name] = exporter_class @classmethod - def get_config_for_model(cls, model) -> KerasHubExporterConfig: + def get_config_for_model(cls, model): """Get the appropriate configuration for a model. Args: model: The Keras-Hub model Returns: - An appropriate exporter configuration instance + KerasHubExporterConfig: An appropriate exporter configuration + instance Raises: ValueError: If no configuration is found for the model type @@ -235,9 +225,7 @@ def get_config_for_model(cls, model) -> KerasHubExporterConfig: return config_class(model) @classmethod - def get_exporter( - cls, format_name: str, config: KerasHubExporterConfig, **kwargs - ) -> KerasHubExporter: + def get_exporter(cls, format_name, config, **kwargs): """Get an exporter for the specified format. Args: @@ -246,7 +234,7 @@ def get_exporter( **kwargs: Additional parameters for the exporter Returns: - An appropriate exporter instance + KerasHubExporter: An appropriate exporter instance Raises: ValueError: If no exporter is found for the format @@ -258,14 +246,14 @@ def get_exporter( return exporter_class(config, **kwargs) @classmethod - def _detect_model_type(cls, model) -> str: + def _detect_model_type(cls, model): """Detect the model type from the model instance. Args: model: The Keras-Hub model Returns: - The detected model type + str: The detected model type """ # Import here to avoid circular imports try: diff --git a/keras_hub/src/export/configs.py b/keras_hub/src/export/configs.py index f933f0791a..94255dd2a2 100644 --- a/keras_hub/src/export/configs.py +++ b/keras_hub/src/export/configs.py @@ -4,10 +4,6 @@ of Keras-Hub models, following the Optimum pattern. 
""" -from typing import Any -from typing import Dict -from typing import Optional - from keras_hub.src.api_export import keras_hub_export from keras_hub.src.export.base import KerasHubExporterConfig @@ -20,8 +16,12 @@ class CausalLMExporterConfig(KerasHubExporterConfig): EXPECTED_INPUTS = ["token_ids", "padding_mask"] DEFAULT_SEQUENCE_LENGTH = 128 - def _is_model_compatible(self) -> bool: - """Check if model is a causal language model.""" + def _is_model_compatible(self): + """Check if model is a causal language model. + + Returns: + bool: True if compatible, False otherwise + """ try: from keras_hub.src.models.causal_lm import CausalLM @@ -30,9 +30,7 @@ def _is_model_compatible(self) -> bool: # Fallback to class name checking return "CausalLM" in self.model.__class__.__name__ - def get_input_signature( - self, sequence_length: Optional[int] = None - ) -> Dict[str, Any]: + def get_input_signature(self, sequence_length=None): """Get input signature for causal LM models. Args: @@ -40,7 +38,8 @@ def get_input_signature( from model. Returns: - Dictionary mapping input names to their specifications + Dict[str, Any]: Dictionary mapping input names to their + specifications """ if sequence_length is None: sequence_length = self._get_sequence_length() @@ -56,8 +55,12 @@ def get_input_signature( ), } - def _get_sequence_length(self) -> int: - """Get sequence length from model or use default.""" + def _get_sequence_length(self): + """Get sequence length from model or use default. + + Returns: + int: The sequence length + """ if hasattr(self.model, "preprocessor") and self.model.preprocessor: return getattr( self.model.preprocessor, @@ -75,13 +78,15 @@ class TextClassifierExporterConfig(KerasHubExporterConfig): EXPECTED_INPUTS = ["token_ids", "padding_mask"] DEFAULT_SEQUENCE_LENGTH = 128 - def _is_model_compatible(self) -> bool: - """Check if model is a text classifier.""" + def _is_model_compatible(self): + """Check if model is a text classifier. + + Returns: + bool: True if compatible, False otherwise + """ return "TextClassifier" in self.model.__class__.__name__ - def get_input_signature( - self, sequence_length: Optional[int] = None - ) -> Dict[str, Any]: + def get_input_signature(self, sequence_length=None): """Get input signature for text classifier models. Args: @@ -89,7 +94,8 @@ def get_input_signature( from model. Returns: - Dictionary mapping input names to their specifications + Dict[str, Any]: Dictionary mapping input names to their + specifications """ if sequence_length is None: sequence_length = self._get_sequence_length() @@ -105,8 +111,12 @@ def get_input_signature( ), } - def _get_sequence_length(self) -> int: - """Get sequence length from model or use default.""" + def _get_sequence_length(self): + """Get sequence length from model or use default. + + Returns: + int: The sequence length + """ if hasattr(self.model, "preprocessor") and self.model.preprocessor: return getattr( self.model.preprocessor, @@ -129,8 +139,12 @@ class Seq2SeqLMExporterConfig(KerasHubExporterConfig): ] DEFAULT_SEQUENCE_LENGTH = 128 - def _is_model_compatible(self) -> bool: - """Check if model is a seq2seq language model.""" + def _is_model_compatible(self): + """Check if model is a seq2seq language model. 
+ + Returns: + bool: True if compatible, False otherwise + """ try: from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM @@ -138,9 +152,7 @@ def _is_model_compatible(self) -> bool: except ImportError: return "Seq2SeqLM" in self.model.__class__.__name__ - def get_input_signature( - self, sequence_length: Optional[int] = None - ) -> Dict[str, Any]: + def get_input_signature(self, sequence_length=None): """Get input signature for seq2seq models. Args: @@ -148,7 +160,8 @@ def get_input_signature( from model. Returns: - Dictionary mapping input names to their specifications + Dict[str, Any]: Dictionary mapping input names to their + specifications """ if sequence_length is None: sequence_length = self._get_sequence_length() @@ -178,8 +191,12 @@ def get_input_signature( ), } - def _get_sequence_length(self) -> int: - """Get sequence length from model or use default.""" + def _get_sequence_length(self): + """Get sequence length from model or use default. + + Returns: + int: The sequence length + """ if hasattr(self.model, "preprocessor") and self.model.preprocessor: return getattr( self.model.preprocessor, @@ -197,8 +214,12 @@ class TextModelExporterConfig(KerasHubExporterConfig): EXPECTED_INPUTS = ["token_ids", "padding_mask"] DEFAULT_SEQUENCE_LENGTH = 128 - def _is_model_compatible(self) -> bool: - """Check if model is a text model (fallback).""" + def _is_model_compatible(self): + """Check if model is a text model (fallback). + + Returns: + bool: True if compatible, False otherwise + """ # This is a fallback config for text models that don't fit other # categories return ( @@ -207,9 +228,7 @@ def _is_model_compatible(self) -> bool: and hasattr(self.model.preprocessor, "tokenizer") ) - def get_input_signature( - self, sequence_length: Optional[int] = None - ) -> Dict[str, Any]: + def get_input_signature(self, sequence_length=None): """Get input signature for generic text models. Args: @@ -217,7 +236,8 @@ def get_input_signature( from model. Returns: - Dictionary mapping input names to their specifications + Dict[str, Any]: Dictionary mapping input names to their + specifications """ if sequence_length is None: sequence_length = self._get_sequence_length() @@ -233,8 +253,12 @@ def get_input_signature( ), } - def _get_sequence_length(self) -> int: - """Get sequence length from model or use default.""" + def _get_sequence_length(self): + """Get sequence length from model or use default. + + Returns: + int: The sequence length + """ if hasattr(self.model, "preprocessor") and self.model.preprocessor: return getattr( self.model.preprocessor, diff --git a/keras_hub/src/export/lite_rt.py b/keras_hub/src/export/lite_rt.py index 359550ac59..f890849745 100644 --- a/keras_hub/src/export/lite_rt.py +++ b/keras_hub/src/export/lite_rt.py @@ -4,11 +4,8 @@ Keras-Hub models, handling their unique input structures and requirements. """ -from typing import Optional - from keras_hub.src.api_export import keras_hub_export from keras_hub.src.export.base import KerasHubExporter -from keras_hub.src.export.base import KerasHubExporterConfig try: from keras.src.export.lite_rt_exporter import ( @@ -32,10 +29,10 @@ class LiteRTExporter(KerasHubExporter): def __init__( self, - config: KerasHubExporterConfig, - max_sequence_length: Optional[int] = None, - aot_compile_targets: Optional[list] = None, - verbose: bool = False, + config, + max_sequence_length=None, + aot_compile_targets=None, + verbose=False, **kwargs, ): """Initialize the LiteRT exporter. 
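As a point of reference for how the exporter configs and the LiteRT exporter above are meant to compose, here is a minimal usage sketch. The preset name, sequence length, and output path are illustrative assumptions rather than part of this patch, and the module still carries the lite_rt name at this point in the series:

    import keras_hub

    from keras_hub.src.export.configs import CausalLMExporterConfig
    from keras_hub.src.export.lite_rt import LiteRTExporter

    # Any causal LM task works; this preset is only an example.
    model = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

    # The config validates compatibility with the causal LM exporter and
    # describes the dictionary inputs the exporter traces.
    config = CausalLMExporterConfig(model)

    # Maps "token_ids" and "padding_mask" to (None, 128) int32 InputSpecs.
    signature = config.get_input_signature(sequence_length=128)

    # Hand the config to the exporter and write the LiteRT artifact.
    exporter = LiteRTExporter(config, max_sequence_length=128, verbose=True)
    exporter.export("gpt2_base_en_litert")
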
@@ -70,7 +67,7 @@ def __init__( else: self.max_sequence_length = self.config.DEFAULT_SEQUENCE_LENGTH - def export(self, filepath: str) -> None: + def export(self, filepath): """Export the Keras-Hub model to LiteRT format. Args: @@ -216,8 +213,7 @@ def get_config(self): # Convenience function for direct export -@keras_hub_export("keras_hub.export.export_lite_rt") -def export_lite_rt(model, filepath: str, **kwargs) -> None: +def export_lite_rt(model, filepath, **kwargs): """Export a Keras-Hub model to LiteRT format. This is a convenience function that automatically detects the model type diff --git a/keras_hub/src/export/registry.py b/keras_hub/src/export/registry.py index 45c2081250..c8a2500882 100644 --- a/keras_hub/src/export/registry.py +++ b/keras_hub/src/export/registry.py @@ -32,7 +32,7 @@ def initialize_export_registry(): pass -def export_model(model, filepath: str, format: str = "lite_rt", **kwargs): +def export_model(model, filepath, format="lite_rt", **kwargs): """Export a Keras-Hub model to the specified format. This is the main export function that automatically detects the model type @@ -72,9 +72,9 @@ def extend_export_method_for_keras_hub(): def keras_hub_export( self, - filepath: str, - format: str = "lite_rt", - verbose: bool = False, + filepath, + format="lite_rt", + verbose=False, **kwargs, ): """Extended export method for Keras-Hub models. From 3290d42afe419fe3f2cb6e42edf540cd69b41961 Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Tue, 23 Sep 2025 10:13:42 +0530 Subject: [PATCH 14/19] now supporting export for objectDetectors --- debug_object_detection.py | 159 +++++++++++++++++++++ keras_hub/src/export/base.py | 6 + keras_hub/src/export/configs.py | 238 +++++++++++++++++++++++++++++++ keras_hub/src/export/lite_rt.py | 102 ++++++++++--- keras_hub/src/export/registry.py | 14 ++ 5 files changed, 502 insertions(+), 17 deletions(-) create mode 100644 debug_object_detection.py diff --git a/debug_object_detection.py b/debug_object_detection.py new file mode 100644 index 0000000000..7edcce0479 --- /dev/null +++ b/debug_object_detection.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 +""" +Test script to understand object detection model outputs and investigate export issues. 
+""" + +import keras_hub +import keras +import numpy as np + +def test_object_detection_outputs(): + """Test what object detection models output in different modes.""" + + print("๐Ÿ” Testing object detection model outputs...") + + # Load a simple object detection model + model = keras_hub.models.DFineObjectDetector.from_preset( + "dfine_nano_coco", + # Remove NMS post-processing to see raw outputs + prediction_decoder=None # This should give us raw logits and boxes + ) + + print(f"โœ… Model loaded: {model.__class__.__name__}") + print(f"๐Ÿ“ Model inputs: {[inp.shape for inp in model.inputs] if hasattr(model, 'inputs') and model.inputs else 'Not built yet'}") + + # Create test input + test_input = np.random.random((1, 640, 640, 3)).astype(np.float32) + image_shape = np.array([[640, 640]], dtype=np.int32) + + print(f"๐ŸŽฏ Test input shapes:") + print(f" Images: {test_input.shape}") + print(f" Image shape: {image_shape.shape}") + + # Test raw outputs (without post-processing) + print(f"\n๐Ÿง  Testing raw model outputs...") + try: + # Try dictionary input format + raw_outputs = model({ + "images": test_input, + "image_shape": image_shape + }, training=False) + + print(f"โœ… Raw outputs (dict input):") + if isinstance(raw_outputs, dict): + for key, value in raw_outputs.items(): + print(f" {key}: {value.shape}") + else: + print(f" Output type: {type(raw_outputs)}") + if hasattr(raw_outputs, 'shape'): + print(f" Output shape: {raw_outputs.shape}") + + except Exception as e: + print(f"โŒ Dict input failed: {e}") + + # Try single tensor input + try: + raw_outputs = model(test_input, training=False) + print(f"โœ… Raw outputs (single tensor input):") + if isinstance(raw_outputs, dict): + for key, value in raw_outputs.items(): + print(f" {key}: {value.shape}") + else: + print(f" Output type: {type(raw_outputs)}") + if hasattr(raw_outputs, 'shape'): + print(f" Output shape: {raw_outputs.shape}") + except Exception as e2: + print(f"โŒ Single tensor input also failed: {e2}") + + # Now test with the default post-processing + print(f"\n๐ŸŽฏ Testing with default NMS post-processing...") + model_with_nms = keras_hub.models.DFineObjectDetector.from_preset("dfine_nano_coco") + + try: + # Try dictionary input format + nms_outputs = model_with_nms({ + "images": test_input, + "image_shape": image_shape + }, training=False) + + print(f"โœ… NMS outputs (dict input):") + if isinstance(nms_outputs, dict): + for key, value in nms_outputs.items(): + print(f" {key}: {value.shape} (dtype: {value.dtype})") + else: + print(f" Output type: {type(nms_outputs)}") + + except Exception as e: + print(f"โŒ Dict input failed with NMS: {e}") + + # Try single tensor input + try: + nms_outputs = model_with_nms(test_input, training=False) + print(f"โœ… NMS outputs (single tensor input):") + if isinstance(nms_outputs, dict): + for key, value in nms_outputs.items(): + print(f" {key}: {value.shape} (dtype: {value.dtype})") + else: + print(f" Output type: {type(nms_outputs)}") + except Exception as e2: + print(f"โŒ Single tensor input also failed with NMS: {e2}") + +def test_export_attempt(): + """Test the current export behavior that's failing.""" + print(f"\n๐Ÿš€ Testing current export behavior...") + + try: + model = keras_hub.models.DFineObjectDetector.from_preset("dfine_nano_coco") + + # Check what the export config expects + from keras_hub.src.export.base import ExporterRegistry + config = ExporterRegistry.get_config_for_model(model) + + print(f"๐Ÿ“‹ Export config:") + print(f" Model type: {config.MODEL_TYPE}") + print(f" 
Expected inputs: {config.EXPECTED_INPUTS}") + + # Try to get input signature + signature = config.get_input_signature() + print(f" Input signature:") + for name, spec in signature.items(): + print(f" {name}: shape={spec.shape}, dtype={spec.dtype}") + + # Try to create the export wrapper to see what fails + from keras_hub.src.export.lite_rt import LiteRTExporter + exporter = LiteRTExporter(config, verbose=True) + + # Try to build the wrapper (this is where it might fail) + print(f"\n๐Ÿ”ง Creating export wrapper...") + wrapper = exporter._create_export_wrapper() + print(f"โœ… Export wrapper created successfully") + print(f" Wrapper inputs: {[inp.shape for inp in wrapper.inputs]}") + + # Try a forward pass through the wrapper + print(f"\n๐Ÿงช Testing wrapper forward pass...") + test_inputs = [ + np.random.random((1, 640, 640, 3)).astype(np.float32), + np.array([[640, 640]], dtype=np.int32) + ] + + wrapper_output = wrapper(test_inputs) + print(f"โœ… Wrapper forward pass successful:") + if isinstance(wrapper_output, dict): + for key, value in wrapper_output.items(): + print(f" {key}: {value.shape}") + else: + print(f" Output shape: {wrapper_output.shape}") + + except Exception as e: + print(f"โŒ Export test failed: {e}") + import traceback + traceback.print_exc() + +if __name__ == "__main__": + try: + test_object_detection_outputs() + test_export_attempt() + except Exception as e: + print(f"โŒ Test failed: {e}") + import traceback + traceback.print_exc() \ No newline at end of file diff --git a/keras_hub/src/export/base.py b/keras_hub/src/export/base.py index 57b1d06b12..d31906e5d6 100644 --- a/keras_hub/src/export/base.py +++ b/keras_hub/src/export/base.py @@ -259,9 +259,11 @@ def _detect_model_type(cls, model): try: from keras_hub.src.models.causal_lm import CausalLM from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM + from keras_hub.src.models.object_detector import ObjectDetector except ImportError: CausalLM = None Seq2SeqLM = None + ObjectDetector = None model_class_name = model.__class__.__name__ @@ -273,6 +275,10 @@ def _detect_model_type(cls, model): return "seq2seq_lm" elif "ImageClassifier" in model_class_name: return "image_classifier" + elif ObjectDetector and isinstance(model, ObjectDetector): + return "object_detector" + elif "ObjectDetector" in model_class_name: + return "object_detector" else: # Default to text model for generic Keras-Hub models return "text_model" diff --git a/keras_hub/src/export/configs.py b/keras_hub/src/export/configs.py index 94255dd2a2..d9a70f6508 100644 --- a/keras_hub/src/export/configs.py +++ b/keras_hub/src/export/configs.py @@ -266,3 +266,241 @@ def _get_sequence_length(self): self.DEFAULT_SEQUENCE_LENGTH, ) return self.DEFAULT_SEQUENCE_LENGTH + + +@keras_hub_export("keras_hub.export.ImageClassifierExporterConfig") +class ImageClassifierExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Image Classification models.""" + + MODEL_TYPE = "image_classifier" + EXPECTED_INPUTS = ["images"] + + def _is_model_compatible(self): + """Check if model is an image classifier. + Returns: + bool: True if compatible, False otherwise + """ + return "ImageClassifier" in self.model.__class__.__name__ + + def get_input_signature(self, image_size=None): + """Get input signature for image classifier models. + Args: + image_size: Optional image size. If None, will be inferred + from model. 
+ Returns: + Dict[str, Any]: Dictionary mapping input names to their + specifications + """ + if image_size is None: + image_size = self._get_image_size() + if isinstance(image_size, int): + image_size = (image_size, image_size) + + import keras + + return { + "images": keras.layers.InputSpec( + shape=(None, *image_size, 3), dtype="float32", name="images" + ), + } + + def _get_image_size(self): + """Get image size from model preprocessor. + Returns: + tuple: The image size (height, width) + """ + if hasattr(self.model, "preprocessor") and self.model.preprocessor: + if hasattr(self.model.preprocessor, "image_size"): + return self.model.preprocessor.image_size + + # If no preprocessor image_size, try to infer from model inputs + if hasattr(self.model, "inputs") and self.model.inputs: + input_shape = self.model.inputs[0].shape + if len(input_shape) == 4 and input_shape[1] is not None and input_shape[2] is not None: + # Shape is (batch, height, width, channels) + return (input_shape[1], input_shape[2]) + + # Last resort: raise an error instead of using hardcoded values + raise ValueError( + "Could not determine image size from model. " + "Model should have a preprocessor with image_size attribute, " + "or model inputs should have concrete shapes." + ) + + def get_dummy_inputs(self, image_size=None): + """Generate dummy inputs for image classifier models. + + Args: + image_size: Optional image size. If None, will be inferred from model. + + Returns: + Dict[str, Any]: Dictionary of dummy inputs + """ + if image_size is None: + image_size = self._get_image_size() + if isinstance(image_size, int): + image_size = (image_size, image_size) + + import keras + + dummy_inputs = {} + if "images" in self.EXPECTED_INPUTS: + dummy_inputs["images"] = keras.ops.ones( + (1, *image_size, 3), dtype="float32" + ) + + return dummy_inputs + + +@keras_hub_export("keras_hub.export.ObjectDetectorExporterConfig") +class ObjectDetectorExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Object Detection models.""" + + MODEL_TYPE = "object_detector" + EXPECTED_INPUTS = ["images", "image_shape"] + + def _is_model_compatible(self): + """Check if model is an object detector. + Returns: + bool: True if compatible, False otherwise + """ + return "ObjectDetector" in self.model.__class__.__name__ + + def get_input_signature(self, image_size=None): + """Get input signature for object detector models. + Args: + image_size: Optional image size. If None, will be inferred + from model. + Returns: + Dict[str, Any]: Dictionary mapping input names to their + specifications + """ + if image_size is None: + image_size = self._get_image_size() + if isinstance(image_size, int): + image_size = (image_size, image_size) + + import keras + + return { + "images": keras.layers.InputSpec( + shape=(None, *image_size, 3), dtype="float32", name="images" + ), + "image_shape": keras.layers.InputSpec( + shape=(None, 2), dtype="int32", name="image_shape" + ), + } + + def _get_image_size(self): + """Get image size from model preprocessor. 
+ Returns: + tuple: The image size (height, width) + """ + if hasattr(self.model, "preprocessor") and self.model.preprocessor: + if hasattr(self.model.preprocessor, "image_size"): + return self.model.preprocessor.image_size + + # If no preprocessor image_size, try to infer from model inputs + if hasattr(self.model, "inputs") and self.model.inputs: + input_shape = self.model.inputs[0].shape + if len(input_shape) == 4 and input_shape[1] is not None and input_shape[2] is not None: + # Shape is (batch, height, width, channels) + return (input_shape[1], input_shape[2]) + + # Last resort: raise an error instead of using hardcoded values + raise ValueError( + "Could not determine image size from model. " + "Model should have a preprocessor with image_size attribute, " + "or model inputs should have concrete shapes." + ) + + def get_dummy_inputs(self, image_size=None): + """Generate dummy inputs for object detector models. + + Args: + image_size: Optional image size. If None, will be inferred + from model. + + Returns: + Dict[str, Any]: Dictionary of dummy inputs + """ + if image_size is None: + image_size = self._get_image_size() + if isinstance(image_size, int): + image_size = (image_size, image_size) + + import keras + + dummy_inputs = {} + + # Create dummy image input + dummy_inputs["images"] = keras.ops.random_uniform( + (1, *image_size, 3), dtype="float32" + ) + + # Create dummy image shape input + dummy_inputs["image_shape"] = keras.ops.constant( + [[image_size[0], image_size[1]]], dtype="int32" + ) + + return dummy_inputs + + +@keras_hub_export("keras_hub.export.ImageSegmenterExporterConfig") +class ImageSegmenterExporterConfig(KerasHubExporterConfig): + """Exporter configuration for Image Segmentation models.""" + + MODEL_TYPE = "image_segmenter" + EXPECTED_INPUTS = ["images"] + + def _is_model_compatible(self): + """Check if model is an image segmenter. + Returns: + bool: True if compatible, False otherwise + """ + return "ImageSegmenter" in self.model.__class__.__name__ + + def get_input_signature(self, image_size=None): + """Get input signature for image segmenter models. + Args: + image_size: Optional image size. If None, will be inferred + from model. + Returns: + Dict[str, Any]: Dictionary mapping input names to their + specifications + """ + if image_size is None: + image_size = self._get_image_size() + if isinstance(image_size, int): + image_size = (image_size, image_size) + + import keras + + return { + "images": keras.layers.InputSpec( + shape=(None, *image_size, 3), dtype="float32", name="images" + ), + } + + def _get_image_size(self): + """Get image size from model preprocessor. + Returns: + tuple: The image size (height, width) + """ + if hasattr(self.model, "preprocessor") and self.model.preprocessor: + if hasattr(self.model.preprocessor, "image_size"): + return self.model.preprocessor.image_size + + # If no preprocessor image_size, try to infer from model inputs + if hasattr(self.model, "inputs") and self.model.inputs: + input_shape = self.model.inputs[0].shape + if len(input_shape) == 4 and input_shape[1] is not None and input_shape[2] is not None: + # Shape is (batch, height, width, channels) + return (input_shape[1], input_shape[2]) + + # Last resort: raise an error instead of using hardcoded values + raise ValueError( + "Could not determine image size from model. " + "Model should have a preprocessor with image_size attribute, " + "or model inputs should have concrete shapes." 
+ ) diff --git a/keras_hub/src/export/lite_rt.py b/keras_hub/src/export/lite_rt.py index f890849745..89830fead3 100644 --- a/keras_hub/src/export/lite_rt.py +++ b/keras_hub/src/export/lite_rt.py @@ -77,12 +77,22 @@ def export(self, filepath): print(f"Starting LiteRT export for {self.config.MODEL_TYPE} model") # Ensure model is built with correct input structure - self._ensure_model_built(self.max_sequence_length) + # For text models, use sequence length; for image models, use None to auto-detect + if self.config.MODEL_TYPE in ["causal_lm", "text_classifier", "seq2seq_lm"]: + build_param = self.max_sequence_length + else: + build_param = None # Let image models auto-detect from preprocessor + + self._ensure_model_built(build_param) # Get the proper input signature for this model type - input_signature = self.config.get_input_signature( - self.max_sequence_length - ) + # For text models, pass sequence length; for image models, pass None to auto-detect + if self.config.MODEL_TYPE in ["causal_lm", "text_classifier", "seq2seq_lm"]: + signature_param = self.max_sequence_length + else: + signature_param = None # Let image models auto-detect from preprocessor + + input_signature = self.config.get_input_signature(signature_param) # Create a wrapper that adapts the Keras-Hub model to work with Keras # LiteRT exporter @@ -92,7 +102,6 @@ def export(self, filepath): keras_exporter = KerasLiteRTExporter( wrapped_model, input_signature=input_signature, - max_sequence_length=self.max_sequence_length, aot_compile_targets=self.aot_compile_targets, verbose=1 if self.verbose else 0, **self.export_kwargs, @@ -188,27 +197,86 @@ def call(self, inputs, training=None, mask=None): if not isinstance(inputs, (list, tuple)): inputs = [inputs] - # Map inputs to expected dictionary structure - input_dict = {} - for i, input_name in enumerate(self.expected_inputs): - if i < len(inputs): - input_dict[input_name] = inputs[i] + # For image classifiers, try the direct tensor approach first + # since most Keras-Hub vision models expect single tensor inputs + if len(self.expected_inputs) == 1 and self.expected_inputs[0] == "images": + try: + return self.keras_hub_model( + inputs[0], training=training, mask=mask + ) + except Exception: + # Fall back to dictionary approach if that fails + pass + + # For LiteRT export, we need to handle the fact that different + # Keras Hub models expect inputs in different formats. Some + # expect dictionaries, others expect single tensors. 
+ try: + # First, try mapping to the expected input names (dictionary format) + input_dict = {} + if len(self.expected_inputs) == 1: + input_dict[self.expected_inputs[0]] = inputs[0] else: - # Handle missing inputs - raise ValueError(f"Missing input for {input_name}") - - return self.keras_hub_model( - input_dict, training=training, mask=mask - ) + for i, input_name in enumerate(self.expected_inputs): + input_dict[input_name] = inputs[i] + + return self.keras_hub_model( + input_dict, training=training, mask=mask + ) + except ValueError as e: + error_msg = str(e) + # If that fails, try direct tensor input (positional format) + if ("doesn't match the expected structure" in error_msg and + "Expected: keras_tensor" in error_msg): + # The model expects a single tensor, not a dictionary + if len(inputs) == 1: + return self.keras_hub_model( + inputs[0], training=training, mask=mask + ) + else: + # Multiple inputs - try as positional arguments + return self.keras_hub_model( + *inputs, training=training, mask=mask + ) + elif "Missing data for input" in error_msg: + # Extract the actual expected input names from the error + if "Expected the following keys:" in error_msg: + # Parse the expected keys from error message + start = error_msg.find("Expected the following keys: [") + if start != -1: + start += len("Expected the following keys: [") + end = error_msg.find("]", start) + if end != -1: + keys_str = error_msg[start:end] + actual_input_names = [k.strip().strip("'\"") for k in keys_str.split(",")] + + # Map inputs to actual expected names + input_dict = {} + for i, actual_name in enumerate(actual_input_names): + if i < len(inputs): + input_dict[actual_name] = inputs[i] + + return self.keras_hub_model( + input_dict, training=training, mask=mask + ) + + # If we still can't figure it out, re-raise the original error + raise def get_config(self): """Return the configuration of the wrapped model.""" return self.keras_hub_model.get_config() + # Pass the correct parameter based on model type + if self.config.MODEL_TYPE in ["causal_lm", "text_classifier", "seq2seq_lm"]: + signature_param = self.max_sequence_length + else: + signature_param = None # Let image models auto-detect from preprocessor + return KerasHubModelWrapper( self.model, self.config.EXPECTED_INPUTS, - self.config.get_input_signature(self.max_sequence_length), + self.config.get_input_signature(signature_param), ) diff --git a/keras_hub/src/export/registry.py b/keras_hub/src/export/registry.py index c8a2500882..df9f5c6524 100644 --- a/keras_hub/src/export/registry.py +++ b/keras_hub/src/export/registry.py @@ -6,6 +6,9 @@ from keras_hub.src.export.base import ExporterRegistry from keras_hub.src.export.configs import CausalLMExporterConfig +from keras_hub.src.export.configs import ImageClassifierExporterConfig +from keras_hub.src.export.configs import ImageSegmenterExporterConfig +from keras_hub.src.export.configs import ObjectDetectorExporterConfig from keras_hub.src.export.configs import Seq2SeqLMExporterConfig from keras_hub.src.export.configs import TextClassifierExporterConfig from keras_hub.src.export.configs import TextModelExporterConfig @@ -22,6 +25,17 @@ def initialize_export_registry(): ExporterRegistry.register_config("seq2seq_lm", Seq2SeqLMExporterConfig) ExporterRegistry.register_config("text_model", TextModelExporterConfig) + # Register vision model configurations + ExporterRegistry.register_config( + "image_classifier", ImageClassifierExporterConfig + ) + ExporterRegistry.register_config( + "object_detector", 
ObjectDetectorExporterConfig + ) + ExporterRegistry.register_config( + "image_segmenter", ImageSegmenterExporterConfig + ) + # Register exporters for different formats try: from keras_hub.src.export.lite_rt import LiteRTExporter From 8b1024fddec833897d31aa872b2961d635fd62fb Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Tue, 23 Sep 2025 10:22:12 +0530 Subject: [PATCH 15/19] Add and refine image model exporter configs Added ImageClassifierExporterConfig, ImageSegmenterExporterConfig, and ObjectDetectorExporterConfig to the export API. Improved input shape inference and dummy input generation for image-related exporter configs. Refactored LiteRTExporter to better handle model type checks and input signature logic, with improved error handling for input mapping. --- keras_hub/api/export/__init__.py | 9 ++++ keras_hub/src/export/base.py | 2 +- keras_hub/src/export/configs.py | 51 +++++++++++++-------- keras_hub/src/export/lite_rt.py | 76 +++++++++++++++++++++++--------- 4 files changed, 96 insertions(+), 42 deletions(-) diff --git a/keras_hub/api/export/__init__.py b/keras_hub/api/export/__init__.py index 311f8aa323..6d961e5eba 100644 --- a/keras_hub/api/export/__init__.py +++ b/keras_hub/api/export/__init__.py @@ -7,6 +7,15 @@ from keras_hub.src.export.configs import ( CausalLMExporterConfig as CausalLMExporterConfig, ) +from keras_hub.src.export.configs import ( + ImageClassifierExporterConfig as ImageClassifierExporterConfig, +) +from keras_hub.src.export.configs import ( + ImageSegmenterExporterConfig as ImageSegmenterExporterConfig, +) +from keras_hub.src.export.configs import ( + ObjectDetectorExporterConfig as ObjectDetectorExporterConfig, +) from keras_hub.src.export.configs import ( Seq2SeqLMExporterConfig as Seq2SeqLMExporterConfig, ) diff --git a/keras_hub/src/export/base.py b/keras_hub/src/export/base.py index d31906e5d6..f214f354bd 100644 --- a/keras_hub/src/export/base.py +++ b/keras_hub/src/export/base.py @@ -258,8 +258,8 @@ def _detect_model_type(cls, model): # Import here to avoid circular imports try: from keras_hub.src.models.causal_lm import CausalLM - from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM from keras_hub.src.models.object_detector import ObjectDetector + from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM except ImportError: CausalLM = None Seq2SeqLM = None diff --git a/keras_hub/src/export/configs.py b/keras_hub/src/export/configs.py index d9a70f6508..103d25ccbf 100644 --- a/keras_hub/src/export/configs.py +++ b/keras_hub/src/export/configs.py @@ -312,14 +312,18 @@ def _get_image_size(self): if hasattr(self.model, "preprocessor") and self.model.preprocessor: if hasattr(self.model.preprocessor, "image_size"): return self.model.preprocessor.image_size - + # If no preprocessor image_size, try to infer from model inputs if hasattr(self.model, "inputs") and self.model.inputs: input_shape = self.model.inputs[0].shape - if len(input_shape) == 4 and input_shape[1] is not None and input_shape[2] is not None: + if ( + len(input_shape) == 4 + and input_shape[1] is not None + and input_shape[2] is not None + ): # Shape is (batch, height, width, channels) return (input_shape[1], input_shape[2]) - + # Last resort: raise an error instead of using hardcoded values raise ValueError( "Could not determine image size from model. " @@ -331,7 +335,8 @@ def get_dummy_inputs(self, image_size=None): """Generate dummy inputs for image classifier models. Args: - image_size: Optional image size. If None, will be inferred from model. + image_size: Optional image size. 
If None, will be inferred from + model. Returns: Dict[str, Any]: Dictionary of dummy inputs @@ -348,7 +353,7 @@ def get_dummy_inputs(self, image_size=None): dummy_inputs["images"] = keras.ops.ones( (1, *image_size, 3), dtype="float32" ) - + return dummy_inputs @@ -399,14 +404,18 @@ def _get_image_size(self): if hasattr(self.model, "preprocessor") and self.model.preprocessor: if hasattr(self.model.preprocessor, "image_size"): return self.model.preprocessor.image_size - + # If no preprocessor image_size, try to infer from model inputs if hasattr(self.model, "inputs") and self.model.inputs: input_shape = self.model.inputs[0].shape - if len(input_shape) == 4 and input_shape[1] is not None and input_shape[2] is not None: + if ( + len(input_shape) == 4 + and input_shape[1] is not None + and input_shape[2] is not None + ): # Shape is (batch, height, width, channels) return (input_shape[1], input_shape[2]) - + # Last resort: raise an error instead of using hardcoded values raise ValueError( "Could not determine image size from model. " @@ -416,11 +425,11 @@ def _get_image_size(self): def get_dummy_inputs(self, image_size=None): """Generate dummy inputs for object detector models. - + Args: image_size: Optional image size. If None, will be inferred from model. - + Returns: Dict[str, Any]: Dictionary of dummy inputs """ @@ -428,21 +437,21 @@ def get_dummy_inputs(self, image_size=None): image_size = self._get_image_size() if isinstance(image_size, int): image_size = (image_size, image_size) - + import keras - + dummy_inputs = {} - + # Create dummy image input dummy_inputs["images"] = keras.ops.random_uniform( (1, *image_size, 3), dtype="float32" ) - - # Create dummy image shape input + + # Create dummy image shape input dummy_inputs["image_shape"] = keras.ops.constant( [[image_size[0], image_size[1]]], dtype="int32" ) - + return dummy_inputs @@ -490,14 +499,18 @@ def _get_image_size(self): if hasattr(self.model, "preprocessor") and self.model.preprocessor: if hasattr(self.model.preprocessor, "image_size"): return self.model.preprocessor.image_size - + # If no preprocessor image_size, try to infer from model inputs if hasattr(self.model, "inputs") and self.model.inputs: input_shape = self.model.inputs[0].shape - if len(input_shape) == 4 and input_shape[1] is not None and input_shape[2] is not None: + if ( + len(input_shape) == 4 + and input_shape[1] is not None + and input_shape[2] is not None + ): # Shape is (batch, height, width, channels) return (input_shape[1], input_shape[2]) - + # Last resort: raise an error instead of using hardcoded values raise ValueError( "Could not determine image size from model. 
" diff --git a/keras_hub/src/export/lite_rt.py b/keras_hub/src/export/lite_rt.py index 89830fead3..e047f0929d 100644 --- a/keras_hub/src/export/lite_rt.py +++ b/keras_hub/src/export/lite_rt.py @@ -77,21 +77,33 @@ def export(self, filepath): print(f"Starting LiteRT export for {self.config.MODEL_TYPE} model") # Ensure model is built with correct input structure - # For text models, use sequence length; for image models, use None to auto-detect - if self.config.MODEL_TYPE in ["causal_lm", "text_classifier", "seq2seq_lm"]: + # For text models, use sequence length; for image models, use None to + # auto-detect + if self.config.MODEL_TYPE in [ + "causal_lm", + "text_classifier", + "seq2seq_lm", + ]: build_param = self.max_sequence_length else: build_param = None # Let image models auto-detect from preprocessor - + self._ensure_model_built(build_param) # Get the proper input signature for this model type - # For text models, pass sequence length; for image models, pass None to auto-detect - if self.config.MODEL_TYPE in ["causal_lm", "text_classifier", "seq2seq_lm"]: + # For text models, pass sequence length; for image models, pass None to + # auto-detect + if self.config.MODEL_TYPE in [ + "causal_lm", + "text_classifier", + "seq2seq_lm", + ]: signature_param = self.max_sequence_length else: - signature_param = None # Let image models auto-detect from preprocessor - + signature_param = ( + None # Let image models auto-detect from preprocessor + ) + input_signature = self.config.get_input_signature(signature_param) # Create a wrapper that adapts the Keras-Hub model to work with Keras @@ -199,7 +211,10 @@ def call(self, inputs, training=None, mask=None): # For image classifiers, try the direct tensor approach first # since most Keras-Hub vision models expect single tensor inputs - if len(self.expected_inputs) == 1 and self.expected_inputs[0] == "images": + if ( + len(self.expected_inputs) == 1 + and self.expected_inputs[0] == "images" + ): try: return self.keras_hub_model( inputs[0], training=training, mask=mask @@ -212,22 +227,25 @@ def call(self, inputs, training=None, mask=None): # Keras Hub models expect inputs in different formats. Some # expect dictionaries, others expect single tensors. 
try: - # First, try mapping to the expected input names (dictionary format) + # First, try mapping to the expected input names (dictionary + # format) input_dict = {} if len(self.expected_inputs) == 1: input_dict[self.expected_inputs[0]] = inputs[0] else: for i, input_name in enumerate(self.expected_inputs): input_dict[input_name] = inputs[i] - + return self.keras_hub_model( input_dict, training=training, mask=mask ) except ValueError as e: error_msg = str(e) # If that fails, try direct tensor input (positional format) - if ("doesn't match the expected structure" in error_msg and - "Expected: keras_tensor" in error_msg): + if ( + "doesn't match the expected structure" in error_msg + and "Expected: keras_tensor" in error_msg + ): # The model expects a single tensor, not a dictionary if len(inputs) == 1: return self.keras_hub_model( @@ -242,25 +260,33 @@ def call(self, inputs, training=None, mask=None): # Extract the actual expected input names from the error if "Expected the following keys:" in error_msg: # Parse the expected keys from error message - start = error_msg.find("Expected the following keys: [") + start = error_msg.find( + "Expected the following keys: [" + ) if start != -1: start += len("Expected the following keys: [") end = error_msg.find("]", start) if end != -1: keys_str = error_msg[start:end] - actual_input_names = [k.strip().strip("'\"") for k in keys_str.split(",")] - + actual_input_names = [ + k.strip().strip("'\"") + for k in keys_str.split(",") + ] + # Map inputs to actual expected names input_dict = {} - for i, actual_name in enumerate(actual_input_names): + for i, actual_name in enumerate( + actual_input_names + ): if i < len(inputs): input_dict[actual_name] = inputs[i] - + return self.keras_hub_model( input_dict, training=training, mask=mask ) - - # If we still can't figure it out, re-raise the original error + + # If we still can't figure it out, re-raise the original + # error raise def get_config(self): @@ -268,11 +294,17 @@ def get_config(self): return self.keras_hub_model.get_config() # Pass the correct parameter based on model type - if self.config.MODEL_TYPE in ["causal_lm", "text_classifier", "seq2seq_lm"]: + if self.config.MODEL_TYPE in [ + "causal_lm", + "text_classifier", + "seq2seq_lm", + ]: signature_param = self.max_sequence_length else: - signature_param = None # Let image models auto-detect from preprocessor - + signature_param = ( + None # Let image models auto-detect from preprocessor + ) + return KerasHubModelWrapper( self.model, self.config.EXPECTED_INPUTS, From 8df5a75843e7a4719f3a9a4be1f196e353345040 Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Wed, 24 Sep 2025 09:38:37 +0530 Subject: [PATCH 16/19] Refactor: move keras import to module level Moved the 'import keras' statement to the top of the module and removed redundant local imports within class methods. This improves code clarity and avoids repeated imports. --- keras_hub/src/export/configs.py | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/keras_hub/src/export/configs.py b/keras_hub/src/export/configs.py index 103d25ccbf..bc00d9b08f 100644 --- a/keras_hub/src/export/configs.py +++ b/keras_hub/src/export/configs.py @@ -4,6 +4,7 @@ of Keras-Hub models, following the Optimum pattern. 
""" +import keras from keras_hub.src.api_export import keras_hub_export from keras_hub.src.export.base import KerasHubExporterConfig @@ -44,8 +45,6 @@ def get_input_signature(self, sequence_length=None): if sequence_length is None: sequence_length = self._get_sequence_length() - import keras - return { "token_ids": keras.layers.InputSpec( shape=(None, sequence_length), dtype="int32", name="token_ids" @@ -100,8 +99,6 @@ def get_input_signature(self, sequence_length=None): if sequence_length is None: sequence_length = self._get_sequence_length() - import keras - return { "token_ids": keras.layers.InputSpec( shape=(None, sequence_length), dtype="int32", name="token_ids" @@ -166,8 +163,6 @@ def get_input_signature(self, sequence_length=None): if sequence_length is None: sequence_length = self._get_sequence_length() - import keras - return { "encoder_token_ids": keras.layers.InputSpec( shape=(None, sequence_length), @@ -242,8 +237,6 @@ def get_input_signature(self, sequence_length=None): if sequence_length is None: sequence_length = self._get_sequence_length() - import keras - return { "token_ids": keras.layers.InputSpec( shape=(None, sequence_length), dtype="int32", name="token_ids" @@ -296,8 +289,6 @@ def get_input_signature(self, image_size=None): if isinstance(image_size, int): image_size = (image_size, image_size) - import keras - return { "images": keras.layers.InputSpec( shape=(None, *image_size, 3), dtype="float32", name="images" @@ -346,8 +337,6 @@ def get_dummy_inputs(self, image_size=None): if isinstance(image_size, int): image_size = (image_size, image_size) - import keras - dummy_inputs = {} if "images" in self.EXPECTED_INPUTS: dummy_inputs["images"] = keras.ops.ones( @@ -385,8 +374,6 @@ def get_input_signature(self, image_size=None): if isinstance(image_size, int): image_size = (image_size, image_size) - import keras - return { "images": keras.layers.InputSpec( shape=(None, *image_size, 3), dtype="float32", name="images" @@ -438,8 +425,6 @@ def get_dummy_inputs(self, image_size=None): if isinstance(image_size, int): image_size = (image_size, image_size) - import keras - dummy_inputs = {} # Create dummy image input @@ -483,8 +468,6 @@ def get_input_signature(self, image_size=None): if isinstance(image_size, int): image_size = (image_size, image_size) - import keras - return { "images": keras.layers.InputSpec( shape=(None, *image_size, 3), dtype="float32", name="images" From 759d2232c05c269cc22af7640b621b6f623110c5 Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Wed, 24 Sep 2025 09:39:27 +0530 Subject: [PATCH 17/19] Remove debug_object_detection.py script Deleted the debug_object_detection.py script, which was used for testing object detection model outputs and export issues. This cleanup removes unused debugging code from the repository. --- debug_object_detection.py | 159 -------------------------------------- 1 file changed, 159 deletions(-) delete mode 100644 debug_object_detection.py diff --git a/debug_object_detection.py b/debug_object_detection.py deleted file mode 100644 index 7edcce0479..0000000000 --- a/debug_object_detection.py +++ /dev/null @@ -1,159 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script to understand object detection model outputs and investigate export issues. 
-""" - -import keras_hub -import keras -import numpy as np - -def test_object_detection_outputs(): - """Test what object detection models output in different modes.""" - - print("๐Ÿ” Testing object detection model outputs...") - - # Load a simple object detection model - model = keras_hub.models.DFineObjectDetector.from_preset( - "dfine_nano_coco", - # Remove NMS post-processing to see raw outputs - prediction_decoder=None # This should give us raw logits and boxes - ) - - print(f"โœ… Model loaded: {model.__class__.__name__}") - print(f"๐Ÿ“ Model inputs: {[inp.shape for inp in model.inputs] if hasattr(model, 'inputs') and model.inputs else 'Not built yet'}") - - # Create test input - test_input = np.random.random((1, 640, 640, 3)).astype(np.float32) - image_shape = np.array([[640, 640]], dtype=np.int32) - - print(f"๐ŸŽฏ Test input shapes:") - print(f" Images: {test_input.shape}") - print(f" Image shape: {image_shape.shape}") - - # Test raw outputs (without post-processing) - print(f"\n๐Ÿง  Testing raw model outputs...") - try: - # Try dictionary input format - raw_outputs = model({ - "images": test_input, - "image_shape": image_shape - }, training=False) - - print(f"โœ… Raw outputs (dict input):") - if isinstance(raw_outputs, dict): - for key, value in raw_outputs.items(): - print(f" {key}: {value.shape}") - else: - print(f" Output type: {type(raw_outputs)}") - if hasattr(raw_outputs, 'shape'): - print(f" Output shape: {raw_outputs.shape}") - - except Exception as e: - print(f"โŒ Dict input failed: {e}") - - # Try single tensor input - try: - raw_outputs = model(test_input, training=False) - print(f"โœ… Raw outputs (single tensor input):") - if isinstance(raw_outputs, dict): - for key, value in raw_outputs.items(): - print(f" {key}: {value.shape}") - else: - print(f" Output type: {type(raw_outputs)}") - if hasattr(raw_outputs, 'shape'): - print(f" Output shape: {raw_outputs.shape}") - except Exception as e2: - print(f"โŒ Single tensor input also failed: {e2}") - - # Now test with the default post-processing - print(f"\n๐ŸŽฏ Testing with default NMS post-processing...") - model_with_nms = keras_hub.models.DFineObjectDetector.from_preset("dfine_nano_coco") - - try: - # Try dictionary input format - nms_outputs = model_with_nms({ - "images": test_input, - "image_shape": image_shape - }, training=False) - - print(f"โœ… NMS outputs (dict input):") - if isinstance(nms_outputs, dict): - for key, value in nms_outputs.items(): - print(f" {key}: {value.shape} (dtype: {value.dtype})") - else: - print(f" Output type: {type(nms_outputs)}") - - except Exception as e: - print(f"โŒ Dict input failed with NMS: {e}") - - # Try single tensor input - try: - nms_outputs = model_with_nms(test_input, training=False) - print(f"โœ… NMS outputs (single tensor input):") - if isinstance(nms_outputs, dict): - for key, value in nms_outputs.items(): - print(f" {key}: {value.shape} (dtype: {value.dtype})") - else: - print(f" Output type: {type(nms_outputs)}") - except Exception as e2: - print(f"โŒ Single tensor input also failed with NMS: {e2}") - -def test_export_attempt(): - """Test the current export behavior that's failing.""" - print(f"\n๐Ÿš€ Testing current export behavior...") - - try: - model = keras_hub.models.DFineObjectDetector.from_preset("dfine_nano_coco") - - # Check what the export config expects - from keras_hub.src.export.base import ExporterRegistry - config = ExporterRegistry.get_config_for_model(model) - - print(f"๐Ÿ“‹ Export config:") - print(f" Model type: {config.MODEL_TYPE}") - print(f" 
Expected inputs: {config.EXPECTED_INPUTS}") - - # Try to get input signature - signature = config.get_input_signature() - print(f" Input signature:") - for name, spec in signature.items(): - print(f" {name}: shape={spec.shape}, dtype={spec.dtype}") - - # Try to create the export wrapper to see what fails - from keras_hub.src.export.lite_rt import LiteRTExporter - exporter = LiteRTExporter(config, verbose=True) - - # Try to build the wrapper (this is where it might fail) - print(f"\n๐Ÿ”ง Creating export wrapper...") - wrapper = exporter._create_export_wrapper() - print(f"โœ… Export wrapper created successfully") - print(f" Wrapper inputs: {[inp.shape for inp in wrapper.inputs]}") - - # Try a forward pass through the wrapper - print(f"\n๐Ÿงช Testing wrapper forward pass...") - test_inputs = [ - np.random.random((1, 640, 640, 3)).astype(np.float32), - np.array([[640, 640]], dtype=np.int32) - ] - - wrapper_output = wrapper(test_inputs) - print(f"โœ… Wrapper forward pass successful:") - if isinstance(wrapper_output, dict): - for key, value in wrapper_output.items(): - print(f" {key}: {value.shape}") - else: - print(f" Output shape: {wrapper_output.shape}") - - except Exception as e: - print(f"โŒ Export test failed: {e}") - import traceback - traceback.print_exc() - -if __name__ == "__main__": - try: - test_object_detection_outputs() - test_export_attempt() - except Exception as e: - print(f"โŒ Test failed: {e}") - import traceback - traceback.print_exc() \ No newline at end of file From 0737c93fb1ea529016fb292d0135b1a20cf800ec Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Fri, 3 Oct 2025 14:15:15 +0530 Subject: [PATCH 18/19] Rename LiteRT to Litert and update exporter configs Renames all references of 'LiteRT' to 'Litert' across the codebase, including file names, class names, and function names. Updates exporter registry and API imports to use the new 'litert' naming. Also improves image model exporter configs to dynamically determine input dtype from the model, enhancing flexibility for different input types. Adds support for ImageSegmenter model type detection in the exporter registry. 
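For readers skimming the log: the "dynamically determine input dtype" change boils down to reading the dtype off the model's first functional input and falling back to float32. A minimal sketch of the idea, with an illustrative helper name (the real change lands as a _get_input_dtype method on the image exporter configs in the diff below):

    def infer_input_dtype(model, default="float32"):
        # Mirrors the lookup the exporter configs perform: prefer the dtype
        # of the model's first symbolic input, otherwise fall back.
        inputs = getattr(model, "inputs", None)
        if inputs:
            return str(inputs[0].dtype)
        return default

    # A backbone built in float16 would then get an input signature with
    # dtype="float16" instead of a hard-coded "float32".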
--- keras_hub/api/export/__init__.py | 2 +- keras_hub/src/export/__init__.py | 4 +-- keras_hub/src/export/base.py | 8 ++++- keras_hub/src/export/configs.py | 36 +++++++++++++++++-- .../src/export/{lite_rt.py => litert.py} | 29 +++++++++------ keras_hub/src/export/registry.py | 16 ++++----- 6 files changed, 70 insertions(+), 25 deletions(-) rename keras_hub/src/export/{lite_rt.py => litert.py} (93%) diff --git a/keras_hub/api/export/__init__.py b/keras_hub/api/export/__init__.py index 6d961e5eba..25d1cc446a 100644 --- a/keras_hub/api/export/__init__.py +++ b/keras_hub/api/export/__init__.py @@ -25,4 +25,4 @@ from keras_hub.src.export.configs import ( TextModelExporterConfig as TextModelExporterConfig, ) -from keras_hub.src.export.lite_rt import LiteRTExporter as LiteRTExporter +from keras_hub.src.export.litert import LitertExporter as LitertExporter diff --git a/keras_hub/src/export/__init__.py b/keras_hub/src/export/__init__.py index 224ae3dec9..4c32e4411d 100644 --- a/keras_hub/src/export/__init__.py +++ b/keras_hub/src/export/__init__.py @@ -5,5 +5,5 @@ from keras_hub.src.export.configs import Seq2SeqLMExporterConfig from keras_hub.src.export.configs import TextClassifierExporterConfig from keras_hub.src.export.configs import TextModelExporterConfig -from keras_hub.src.export.lite_rt import LiteRTExporter -from keras_hub.src.export.lite_rt import export_lite_rt +from keras_hub.src.export.litert import LitertExporter +from keras_hub.src.export.litert import export_litert diff --git a/keras_hub/src/export/base.py b/keras_hub/src/export/base.py index f214f354bd..194baff76e 100644 --- a/keras_hub/src/export/base.py +++ b/keras_hub/src/export/base.py @@ -195,7 +195,7 @@ def register_exporter(cls, format_name, exporter_class): """Register an exporter class for a format. Args: - format_name: The export format (e.g., "lite_rt") + format_name: The export format (e.g., "litert") exporter_class: The exporter class """ cls._exporters[format_name] = exporter_class @@ -258,12 +258,14 @@ def _detect_model_type(cls, model): # Import here to avoid circular imports try: from keras_hub.src.models.causal_lm import CausalLM + from keras_hub.src.models.image_segmenter import ImageSegmenter from keras_hub.src.models.object_detector import ObjectDetector from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM except ImportError: CausalLM = None Seq2SeqLM = None ObjectDetector = None + ImageSegmenter = None model_class_name = model.__class__.__name__ @@ -279,6 +281,10 @@ def _detect_model_type(cls, model): return "object_detector" elif "ObjectDetector" in model_class_name: return "object_detector" + elif ImageSegmenter and isinstance(model, ImageSegmenter): + return "image_segmenter" + elif "ImageSegmenter" in model_class_name: + return "image_segmenter" else: # Default to text model for generic Keras-Hub models return "text_model" diff --git a/keras_hub/src/export/configs.py b/keras_hub/src/export/configs.py index bc00d9b08f..282a8aa23e 100644 --- a/keras_hub/src/export/configs.py +++ b/keras_hub/src/export/configs.py @@ -291,10 +291,20 @@ def get_input_signature(self, image_size=None): return { "images": keras.layers.InputSpec( - shape=(None, *image_size, 3), dtype="float32", name="images" + shape=(None, *image_size, 3), dtype=self._get_input_dtype(), name="images" ), } + def _get_input_dtype(self): + """Get input dtype from model. 
+ Returns: + str: The input dtype (e.g., 'float32', 'float16') + """ + if hasattr(self.model, "inputs") and self.model.inputs: + return str(self.model.inputs[0].dtype) + # Default fallback + return "float32" + def _get_image_size(self): """Get image size from model preprocessor. Returns: @@ -376,13 +386,23 @@ def get_input_signature(self, image_size=None): return { "images": keras.layers.InputSpec( - shape=(None, *image_size, 3), dtype="float32", name="images" + shape=(None, *image_size, 3), dtype=self._get_input_dtype(), name="images" ), "image_shape": keras.layers.InputSpec( shape=(None, 2), dtype="int32", name="image_shape" ), } + def _get_input_dtype(self): + """Get input dtype from model. + Returns: + str: The input dtype (e.g., 'float32', 'float16') + """ + if hasattr(self.model, "inputs") and self.model.inputs: + return str(self.model.inputs[0].dtype) + # Default fallback + return "float32" + def _get_image_size(self): """Get image size from model preprocessor. Returns: @@ -470,10 +490,20 @@ def get_input_signature(self, image_size=None): return { "images": keras.layers.InputSpec( - shape=(None, *image_size, 3), dtype="float32", name="images" + shape=(None, *image_size, 3), dtype=self._get_input_dtype(), name="images" ), } + def _get_input_dtype(self): + """Get input dtype from model. + Returns: + str: The input dtype (e.g., 'float32', 'float16') + """ + if hasattr(self.model, "inputs") and self.model.inputs: + return str(self.model.inputs[0].dtype) + # Default fallback + return "float32" + def _get_image_size(self): """Get image size from model preprocessor. Returns: diff --git a/keras_hub/src/export/lite_rt.py b/keras_hub/src/export/litert.py similarity index 93% rename from keras_hub/src/export/lite_rt.py rename to keras_hub/src/export/litert.py index e047f0929d..422f2b3068 100644 --- a/keras_hub/src/export/lite_rt.py +++ b/keras_hub/src/export/litert.py @@ -8,18 +8,18 @@ from keras_hub.src.export.base import KerasHubExporter try: - from keras.src.export.lite_rt_exporter import ( - LiteRTExporter as KerasLiteRTExporter, + from keras.src.export.litert_exporter import ( + LitertExporter as KerasLitertExporter, ) KERAS_LITE_RT_AVAILABLE = True except ImportError: KERAS_LITE_RT_AVAILABLE = False - KerasLiteRTExporter = None + KerasLitertExporter = None -@keras_hub_export("keras_hub.export.LiteRTExporter") -class LiteRTExporter(KerasHubExporter): +@keras_hub_export("keras_hub.export.LitertExporter") +class LitertExporter(KerasHubExporter): """LiteRT exporter for Keras-Hub models. This exporter handles the conversion of Keras-Hub models to TensorFlow Lite @@ -110,8 +110,17 @@ def export(self, filepath): # LiteRT exporter wrapped_model = self._create_export_wrapper() + # Convert input signature to list format expected by Keras exporter + if isinstance(input_signature, dict): + # Extract specs in the order expected by the model + signature_list = [] + for input_name in self.config.EXPECTED_INPUTS: + if input_name in input_signature: + signature_list.append(input_signature[input_name]) + input_signature = signature_list + # Create the Keras LiteRT exporter with the wrapped model - keras_exporter = KerasLiteRTExporter( + keras_exporter = KerasLitertExporter( wrapped_model, input_signature=input_signature, aot_compile_targets=self.aot_compile_targets, @@ -313,8 +322,8 @@ def get_config(self): # Convenience function for direct export -def export_lite_rt(model, filepath, **kwargs): - """Export a Keras-Hub model to LiteRT format. 
+def export_litert(model, filepath, **kwargs): + """Export a Keras-Hub model to Litert format. This is a convenience function that automatically detects the model type and exports it using the appropriate configuration. @@ -329,6 +338,6 @@ def export_lite_rt(model, filepath, **kwargs): # Get the appropriate configuration for this model config = ExporterRegistry.get_config_for_model(model) - # Create and use the LiteRT exporter - exporter = LiteRTExporter(config, **kwargs) + # Create and use the Litert exporter + exporter = LitertExporter(config, **kwargs) exporter.export(filepath) diff --git a/keras_hub/src/export/registry.py b/keras_hub/src/export/registry.py index df9f5c6524..652a863897 100644 --- a/keras_hub/src/export/registry.py +++ b/keras_hub/src/export/registry.py @@ -38,15 +38,15 @@ def initialize_export_registry(): # Register exporters for different formats try: - from keras_hub.src.export.lite_rt import LiteRTExporter + from keras_hub.src.export.litert import LitertExporter - ExporterRegistry.register_exporter("lite_rt", LiteRTExporter) + ExporterRegistry.register_exporter("litert", LitertExporter) except ImportError: - # LiteRT not available + # Litert not available pass -def export_model(model, filepath, format="lite_rt", **kwargs): +def export_model(model, filepath, format="litert", **kwargs): """Export a Keras-Hub model to the specified format. This is the main export function that automatically detects the model type @@ -55,7 +55,7 @@ def export_model(model, filepath, format="lite_rt", **kwargs): Args: model: The Keras-Hub model to export filepath: Path where to save the exported model (without extension) - format: Export format (currently supports "lite_rt") + format: Export format (currently supports "litert") **kwargs: Additional arguments passed to the exporter """ # Ensure registry is initialized @@ -87,7 +87,7 @@ def extend_export_method_for_keras_hub(): def keras_hub_export( self, filepath, - format="lite_rt", + format="litert", verbose=False, **kwargs, ): @@ -99,13 +99,13 @@ def keras_hub_export( Args: filepath: Path where to save the exported model (without extension) - format: Export format. Supports "lite_rt", "tf_saved_model", + format: Export format. Supports "litert", "tf_saved_model", etc. verbose: Whether to print verbose output during export **kwargs: Additional arguments passed to the exporter """ # Check if this is a Keras-Hub model that needs special handling - if format == "lite_rt" and self._is_keras_hub_model(): + if format == "litert" and self._is_keras_hub_model(): # Use our Keras-Hub specific export logic kwargs["verbose"] = verbose export_model(self, filepath, format=format, **kwargs) From c733e18806992951129bea02e23489c2ab182cad Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Mon, 6 Oct 2025 10:18:17 +0530 Subject: [PATCH 19/19] Refactor InputSpec formatting and fix import path Refactored InputSpec definitions in exporter configs for improved readability by placing each argument on a separate line. Updated import path in litert.py to import from keras.src.export.litert instead of keras.src.export.litert_exporter. 
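Concretely, only the Keras-side module path changes; the try/except guard and the KerasLitertExporter alias in litert.py are untouched, so Keras builds without LiteRT export support still fall through to the existing ImportError handling rather than failing outright:

    # Before this patch:
    from keras.src.export.litert_exporter import LitertExporter as KerasLitertExporter

    # After this patch:
    from keras.src.export.litert import LitertExporter as KerasLitertExporter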
--- keras_hub/src/export/configs.py | 13 ++++++++++--- keras_hub/src/export/litert.py | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/keras_hub/src/export/configs.py b/keras_hub/src/export/configs.py index 282a8aa23e..a8b761e251 100644 --- a/keras_hub/src/export/configs.py +++ b/keras_hub/src/export/configs.py @@ -5,6 +5,7 @@ """ import keras + from keras_hub.src.api_export import keras_hub_export from keras_hub.src.export.base import KerasHubExporterConfig @@ -291,7 +292,9 @@ def get_input_signature(self, image_size=None): return { "images": keras.layers.InputSpec( - shape=(None, *image_size, 3), dtype=self._get_input_dtype(), name="images" + shape=(None, *image_size, 3), + dtype=self._get_input_dtype(), + name="images", ), } @@ -386,7 +389,9 @@ def get_input_signature(self, image_size=None): return { "images": keras.layers.InputSpec( - shape=(None, *image_size, 3), dtype=self._get_input_dtype(), name="images" + shape=(None, *image_size, 3), + dtype=self._get_input_dtype(), + name="images", ), "image_shape": keras.layers.InputSpec( shape=(None, 2), dtype="int32", name="image_shape" @@ -490,7 +495,9 @@ def get_input_signature(self, image_size=None): return { "images": keras.layers.InputSpec( - shape=(None, *image_size, 3), dtype=self._get_input_dtype(), name="images" + shape=(None, *image_size, 3), + dtype=self._get_input_dtype(), + name="images", ), } diff --git a/keras_hub/src/export/litert.py b/keras_hub/src/export/litert.py index 422f2b3068..951b2be4da 100644 --- a/keras_hub/src/export/litert.py +++ b/keras_hub/src/export/litert.py @@ -8,7 +8,7 @@ from keras_hub.src.export.base import KerasHubExporter try: - from keras.src.export.litert_exporter import ( + from keras.src.export.litert import ( LitertExporter as KerasLitertExporter, )
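With the series applied, a rough end-to-end usage looks like the sketch below. It reuses the preset from the removed debug script purely as an example; the exact public surface is whatever keras_hub.api.export re-exports after the rename, and verbose is simply forwarded to the exporter:

    import keras_hub
    from keras_hub.src.export.registry import export_model

    model = keras_hub.models.DFineObjectDetector.from_preset("dfine_nano_coco")

    # Via the registry helper (filepath is given without an extension):
    export_model(model, "dfine_nano_coco_litert", format="litert", verbose=True)

    # Or via the patched Model.export(), once
    # extend_export_method_for_keras_hub() has been applied:
    model.export("dfine_nano_coco_litert", format="litert", verbose=True)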