diff --git a/src/configs/generate_config.py b/src/configs/generate_config.py index de7315f..feb9e88 100644 --- a/src/configs/generate_config.py +++ b/src/configs/generate_config.py @@ -11,6 +11,10 @@ from tensorflow_model_optimization.python.core.quantization.keras.quantizers import ( Quantizer, ) +from tensorflow_model_optimization.python.core.quantization.keras.utils import ( + deserialize_keras_object, + serialize_keras_object, +) AttributeQuantizerDict = dict[str, Quantizer | dict[str, Quantizer]] @@ -47,11 +51,29 @@ class GenerateConfig(QuantizeConfig): def __init__( self, - weights: AttributeQuantizerDict = {}, - activations: AttributeQuantizerDict = {}, + weights: AttributeQuantizerDict = None, + activations: AttributeQuantizerDict = None, ): - self.weights = flatten_nested_dict(weights) - self.activations = flatten_nested_dict(activations) + self._raw_weights = weights or {} + self._raw_activations = activations or {} + self.weights = flatten_nested_dict(self._raw_weights) + self.activations = flatten_nested_dict(self._raw_activations) + + def _serialize_recursively(self, nested_obj): + """Recursively traverses a nested dict and serializes any Keras object + (like a Quantizer) it finds.""" + if isinstance(nested_obj, dict): + # If it's a dict, recurse on its values + return { + key: self._serialize_recursively(value) + for key, value in nested_obj.items() + } + elif hasattr(nested_obj, "get_config"): + # Base case: If it's a serializable object (a Quantizer), serialize it. + return serialize_keras_object(nested_obj) + else: + # It's some other primitive type, return it as is. + return nested_obj def get_weights_and_quantizers( self, layer: Layer @@ -93,8 +115,45 @@ def get_output_quantizers(self, layer): return [] def get_config(self): - return {"weights": self.weights, "activations": self.activations} + """Correctly serialize by recursively traversing the nested + dictionaries.""" + return { + "weights": self._serialize_recursively(self._raw_weights), + "activations": self._serialize_recursively(self._raw_activations), + } @classmethod def from_config(cls, config): - return cls(**config) + """ + FIX: Use a recursive helper function to deserialize the nested structure + before passing it to the constructor. + """ + + def _deserialize_recursively(nested_config): + # Base Case: If the dict is a serialized Keras object, deserialize it. + if ( + isinstance(nested_config, dict) + and "class_name" in nested_config + ): + return deserialize_keras_object(nested_config) + + # Recursive Step: If it's a dict container, recurse on its values. + if isinstance(nested_config, dict): + return { + key: _deserialize_recursively(value) + for key, value in nested_config.items() + } + + # Handle lists of items + if isinstance(nested_config, list): + return [ + _deserialize_recursively(item) for item in nested_config + ] + + # It's a primitive type, return as is. + return nested_config + + return cls( + weights=_deserialize_recursively(config["weights"]), + activations=_deserialize_recursively(config["activations"]), + ) diff --git a/src/configs/generate_config_test.py b/src/configs/generate_config_test.py index 06580df..5aa2e9b 100755 --- a/src/configs/generate_config_test.py +++ b/src/configs/generate_config_test.py @@ -1,10 +1,13 @@ -#!/usr/bin/env python3 +#!/usr/bin/python3 import unittest from tensorflow_model_optimization.python.core.quantization.keras.quantizers import ( Quantizer, ) +from tensorflow_model_optimization.python.core.quantization.keras.utils import ( + serialize_keras_object, +) from configs.generate_config import ( GenerateConfig, @@ -19,6 +22,13 @@ class GenerateConfigTest(unittest.TestCase): def setUp(self): self.quantizer = unittest.mock.Mock(spec=Quantizer) + self.quantizer.get_config.return_value = { + "bits": 8, + "signed": True, + "name_suffix": "_asdasd", + "initializer": "Constant", + "regularizer": None, + } self.layer = unittest.mock.Mock() def test_can_instantiate_generate_config(self): @@ -103,13 +113,15 @@ def test_can_get_config(self): weights={"kernel": self.quantizer}, activations={"activation": self.quantizer}, ) - self.assertDictEqual( - config.get_config(), - { - "weights": {"kernel": self.quantizer}, - "activations": {"activation": self.quantizer}, + + expected_config = { + "weights": {"kernel": serialize_keras_object(self.quantizer)}, + "activations": { + "activation": serialize_keras_object(self.quantizer) }, - ) + } + + self.assertDictEqual(config.get_config(), expected_config) def test_can_get_config_with_dict(self): """Test that verifies that the GenerateConfig can get the @@ -121,16 +133,19 @@ def test_can_get_config_with_dict(self): "other": self.quantizer, }, ) - self.assertDictEqual( - config.get_config(), - { - "weights": {"kernel": self.quantizer, "bias": self.quantizer}, - "activations": { - "activation": self.quantizer, - "other": self.quantizer, - }, + + expected_config = { + "weights": { + "kernel": serialize_keras_object(self.quantizer), + "bias": serialize_keras_object(self.quantizer), }, - ) + "activations": { + "activation": serialize_keras_object(self.quantizer), + "other": serialize_keras_object(self.quantizer), + }, + } + + self.assertDictEqual(config.get_config(), expected_config) def test_nested_weight_config(self): """Test that verifies that the GenerateConfig can be instantiated with diff --git a/src/configs/qmodel.py b/src/configs/qmodel.py index 21d683b..64784f9 100644 --- a/src/configs/qmodel.py +++ b/src/configs/qmodel.py @@ -107,11 +107,14 @@ def apply_quantization(model: Model, quantizers: LayerQuantizerDict): ) # TODO(Fran): get below dict objects from the quantizers passed in add method (Constant comes from UniformQuantizer) # So maybe if there are custom objects to register each class should have a method to return them - custom_objects = {} - custom_objects["GenerateConfig"] = GenerateConfig - custom_objects["UniformQuantizer"] = UniformQuantizer - custom_objects["FlexQuantizer"] = FlexQuantizer - custom_objects["Constant"] = Constant with quantize_scope(custom_objects): return quantize_apply(quantize_model(model, quantizers)) + + +custom_objects = { + "GenerateConfig": GenerateConfig, + "UniformQuantizer": UniformQuantizer, + "FlexQuantizer": FlexQuantizer, + "Constant": Constant, +} diff --git a/src/configs/serialization/__init__.py b/src/configs/serialization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/configs/serialization/serialization.py b/src/configs/serialization/serialization.py new file mode 100644 index 0000000..0d0650f --- /dev/null +++ b/src/configs/serialization/serialization.py @@ -0,0 +1,27 @@ +from tensorflow.keras.models import load_model +from tensorflow_model_optimization.quantization.keras import quantize_scope + +from configs.qmodel import custom_objects + + +def save_qmodel(model, model_dir): + """Save the quantized model to the specified directory. + + Args: + model: The quantized Keras model to save. + model_dir: Directory where the model will be saved. + """ + model.save(model_dir, save_format="tf") + + +def load_qmodel(model_dir): + """Load a quantized Keras model from the specified directory. + + Args: + model_dir: Directory from which to load the model. + + Returns: + The loaded Keras model. + """ + with quantize_scope(custom_objects): + return load_model(model_dir, custom_objects=custom_objects) diff --git a/src/configs/serialization/serialization_test.py b/src/configs/serialization/serialization_test.py new file mode 100755 index 0000000..d7def43 --- /dev/null +++ b/src/configs/serialization/serialization_test.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python3 +import tempfile +import unittest + +from tensorflow import keras + +from configs.qmodel import apply_quantization +from configs.serialization.serialization import load_qmodel, save_qmodel +from quantizers.flex_quantizer import FlexQuantizer +from quantizers.uniform_quantizer import UniformQuantizer + + +class TestSerialization(unittest.TestCase): + def test_single_uniform(self): + input_shape = [20] + + model = keras.Sequential( + [ + keras.layers.Dense( + 20, input_shape=input_shape, name="dense_1" + ), + keras.layers.Flatten(), + ] + ) + + qconfig = { + "dense_1": { + "weights": { + "kernel": UniformQuantizer(8, name_suffix="_asdasd"), + }, + } + } + quant_aware_model = apply_quantization(model, qconfig) + + quant_aware_model.compile( + optimizer="adam", + loss="sparse_categorical_crossentropy", + metrics=["accuracy"], + ) + + model_dir = tempfile.mkdtemp() + save_qmodel(quant_aware_model, model_dir) + + load_qmodel(model_dir) + + def test_single_flex(self): + input_shape = [20] + + model = keras.Sequential( + [ + keras.layers.Dense( + 20, input_shape=input_shape, name="dense_1" + ), + keras.layers.Flatten(), + ] + ) + + qconfig = { + "dense_1": { + "weights": { + "kernel": FlexQuantizer( + bits=8, n_levels=10, signed=True, name_suffix="_asdasd" + ), + }, + } + } + quant_aware_model = apply_quantization(model, qconfig) + + quant_aware_model.compile( + optimizer="adam", + loss="sparse_categorical_crossentropy", + metrics=["accuracy"], + ) + + model_dir = tempfile.mkdtemp() + save_qmodel(quant_aware_model, model_dir) + + load_qmodel(model_dir) + + def test_lenet_5(self): + model = keras.models.Sequential( + [ + keras.layers.Conv2D( + filters=6, + kernel_size=(5, 5), + activation="relu", + padding="same", + input_shape=(28, 28, 1), + ), + keras.layers.AveragePooling2D(pool_size=(2, 2), strides=2), + keras.layers.Conv2D( + filters=16, kernel_size=(5, 5), activation="relu" + ), + keras.layers.AveragePooling2D(pool_size=(2, 2), strides=2), + keras.layers.Flatten(), + keras.layers.Dense(120, activation="relu"), + keras.layers.Dense(84, activation="relu"), + keras.layers.Dense( + 10, activation="softmax" + ), # 10 classes (digits 0-9) + ] + ) + + qconfig = { + "conv2d": { + "weights": { + "kernel": FlexQuantizer(bits=4, n_levels=10, signed=True) + }, + "activations": { + "activation": UniformQuantizer(bits=4, signed=False) + }, + }, + "conv2d_1": { + "weights": { + "kernel": FlexQuantizer(bits=4, n_levels=10, signed=True) + }, + "activations": { + "activation": UniformQuantizer(bits=4, signed=False) + }, + }, + "dense": { + "weights": { + "kernel": FlexQuantizer(bits=4, n_levels=10, signed=True) + }, + "activations": { + "activation": UniformQuantizer(bits=4, signed=False) + }, + }, + "dense_1": { + "weights": { + "kernel": FlexQuantizer(bits=4, n_levels=10, signed=True) + }, + "activations": { + "activation": UniformQuantizer(bits=4, signed=False) + }, + }, + "dense_2": { + "weights": { + "kernel": FlexQuantizer(bits=4, n_levels=10, signed=True) + }, + "activations": { + "activation": UniformQuantizer(bits=4, signed=False) + }, + }, + } + + quant_aware_model = apply_quantization(model, qconfig) + + quant_aware_model.compile( + optimizer="adam", + loss="sparse_categorical_crossentropy", + metrics=["accuracy"], + ) + + model_dir = tempfile.mkdtemp() + save_qmodel(quant_aware_model, model_dir) + + load_qmodel(model_dir) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/quantizers/flex_quantizer.py b/src/quantizers/flex_quantizer.py index ea2883b..bec60cc 100644 --- a/src/quantizers/flex_quantizer.py +++ b/src/quantizers/flex_quantizer.py @@ -238,7 +238,6 @@ def get_config(self): return { "bits": self.bits, "signed": self.signed, + "n_levels": self.n_levels, "name_suffix": self.name_suffix, - "initializer": self.initializer, - "regularizer": self.regularizer, } diff --git a/src/quantizers/uniform_quantizer.py b/src/quantizers/uniform_quantizer.py index 1467119..384b848 100755 --- a/src/quantizers/uniform_quantizer.py +++ b/src/quantizers/uniform_quantizer.py @@ -69,6 +69,7 @@ def __call__(self, w): alpha = layer.add_weight( name=f"{name}{self.name_suffix}_alpha", initializer=self.initializer, + # shape=(1,), trainable=True, dtype=tf.float32, regularizer=self.regularizer, @@ -150,3 +151,17 @@ def get_config(self): "initializer": tf.keras.initializers.serialize(self.initializer), "regularizer": tf.keras.regularizers.serialize(self.regularizer), } + + @classmethod + def from_config(cls, config): + return cls( + bits=config["bits"], + signed=config["signed"], + name_suffix=config["name_suffix"], + initializer=tf.keras.initializers.deserialize( + config["initializer"] + ), + regularizer=tf.keras.regularizers.deserialize( + config["regularizer"] + ), + )