Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 65 additions & 6 deletions src/configs/generate_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
from tensorflow_model_optimization.python.core.quantization.keras.quantizers import (
Quantizer,
)
from tensorflow_model_optimization.python.core.quantization.keras.utils import (
deserialize_keras_object,
serialize_keras_object,
)

AttributeQuantizerDict = dict[str, Quantizer | dict[str, Quantizer]]

Expand Down Expand Up @@ -47,11 +51,29 @@ class GenerateConfig(QuantizeConfig):

def __init__(
    self,
    # NOTE(review): diff-scrape artifact — the `= {}` parameter pair below is
    # the pre-change version (mutable-default anti-pattern) and the `= None`
    # pair is the post-change version; only the `= None` pair exists in the
    # merged file. As scraped, the duplicate parameter names are not valid
    # Python.
    weights: AttributeQuantizerDict = {},
    activations: AttributeQuantizerDict = {},
    weights: AttributeQuantizerDict = None,
    activations: AttributeQuantizerDict = None,
):
    # NOTE(review): the next two assignments are the pre-change body; the
    # four assignments after them are the merged replacement.
    self.weights = flatten_nested_dict(weights)
    self.activations = flatten_nested_dict(activations)
    # Keep the raw (possibly nested) dicts for serialization in get_config();
    # `or {}` normalizes a None argument to an empty dict.
    self._raw_weights = weights or {}
    self._raw_activations = activations or {}
    # Flattened views used by the attribute/quantizer lookup methods.
    self.weights = flatten_nested_dict(self._raw_weights)
    self.activations = flatten_nested_dict(self._raw_activations)

def _serialize_recursively(self, nested_obj):
    """Walk a (possibly nested) dict and serialize every Keras object
    (e.g. a Quantizer) found along the way.

    Dicts are rebuilt with their values processed recursively; any value
    exposing ``get_config`` is converted via ``serialize_keras_object``;
    every other value passes through unchanged.
    """
    if not isinstance(nested_obj, dict):
        # Leaf: a Keras-serializable object becomes its config payload,
        # plain primitives are returned untouched.
        if hasattr(nested_obj, "get_config"):
            return serialize_keras_object(nested_obj)
        return nested_obj
    # Container: rebuild the dict, serializing each value in turn.
    serialized = {}
    for key, value in nested_obj.items():
        serialized[key] = self._serialize_recursively(value)
    return serialized

def get_weights_and_quantizers(
self, layer: Layer
Expand Down Expand Up @@ -93,8 +115,45 @@ def get_output_quantizers(self, layer):
return []

def get_config(self):
    # NOTE(review): diff-scrape artifact — the first `return` below is the
    # pre-change one-liner; the docstring and recursive version after it are
    # the merged replacement. As scraped, the early return makes the rest
    # unreachable; only the recursive form exists in the real file.
    return {"weights": self.weights, "activations": self.activations}
    """Correctly serialize by recursively traversing the nested
    dictionaries."""
    return {
        "weights": self._serialize_recursively(self._raw_weights),
        "activations": self._serialize_recursively(self._raw_activations),
    }

@classmethod
def from_config(cls, config):
    # NOTE(review): diff-scrape artifact — the bare `return cls(**config)` is
    # the pre-change body; everything after it is the merged replacement. As
    # scraped, the early return would make the replacement unreachable.
    return cls(**config)
    """
    FIX: Use a recursive helper function to deserialize the nested structure
    before passing it to the constructor.
    """

    def _deserialize_recursively(nested_config):
        # Base Case: If the dict is a serialized Keras object, deserialize it.
        # (Keras-serialized objects are dicts carrying a "class_name" key.)
        if (
            isinstance(nested_config, dict)
            and "class_name" in nested_config
        ):
            return deserialize_keras_object(nested_config)

        # Recursive Step: If it's a dict container, recurse on its values.
        if isinstance(nested_config, dict):
            return {
                key: _deserialize_recursively(value)
                for key, value in nested_config.items()
            }

        # Handle lists of items
        if isinstance(nested_config, list):
            return [
                _deserialize_recursively(item) for item in nested_config
            ]

        # It's a primitive type, return as is.
        return nested_config

    # Rebuild live Quantizer objects before handing the dicts to __init__.
    return cls(
        weights=_deserialize_recursively(config["weights"]),
        activations=_deserialize_recursively(config["activations"]),
    )
47 changes: 31 additions & 16 deletions src/configs/generate_config_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
#!/usr/bin/env python3
#!/usr/bin/python3

import unittest

from tensorflow_model_optimization.python.core.quantization.keras.quantizers import (
Quantizer,
)
from tensorflow_model_optimization.python.core.quantization.keras.utils import (
serialize_keras_object,
)

from configs.generate_config import (
GenerateConfig,
Expand All @@ -19,6 +22,13 @@ class GenerateConfigTest(unittest.TestCase):

def setUp(self):
    """Create a mocked Quantizer (with a realistic get_config payload) and a
    mocked layer shared by every test in this class."""
    # NOTE(review): uses `unittest.mock`, but the only visible import in this
    # diff is `import unittest` — confirm `unittest.mock` (or `mock`) is
    # actually imported in the full file.
    self.quantizer = unittest.mock.Mock(spec=Quantizer)
    # get_config must return a plain dict so serialize_keras_object can embed
    # it in the serialized representation checked by the tests below.
    self.quantizer.get_config.return_value = {
        "bits": 8,
        "signed": True,
        "name_suffix": "_asdasd",
        "initializer": "Constant",
        "regularizer": None,
    }
    self.layer = unittest.mock.Mock()

def test_can_instantiate_generate_config(self):
Expand Down Expand Up @@ -103,13 +113,15 @@ def test_can_get_config(self):
weights={"kernel": self.quantizer},
activations={"activation": self.quantizer},
)
self.assertDictEqual(
config.get_config(),
{
"weights": {"kernel": self.quantizer},
"activations": {"activation": self.quantizer},

expected_config = {
"weights": {"kernel": serialize_keras_object(self.quantizer)},
"activations": {
"activation": serialize_keras_object(self.quantizer)
},
)
}

self.assertDictEqual(config.get_config(), expected_config)

def test_can_get_config_with_dict(self):
"""Test that verifies that the GenerateConfig can get the
Expand All @@ -121,16 +133,19 @@ def test_can_get_config_with_dict(self):
"other": self.quantizer,
},
)
self.assertDictEqual(
config.get_config(),
{
"weights": {"kernel": self.quantizer, "bias": self.quantizer},
"activations": {
"activation": self.quantizer,
"other": self.quantizer,
},

expected_config = {
"weights": {
"kernel": serialize_keras_object(self.quantizer),
"bias": serialize_keras_object(self.quantizer),
},
)
"activations": {
"activation": serialize_keras_object(self.quantizer),
"other": serialize_keras_object(self.quantizer),
},
}

self.assertDictEqual(config.get_config(), expected_config)

def test_nested_weight_config(self):
"""Test that verifies that the GenerateConfig can be instantiated with
Expand Down
13 changes: 8 additions & 5 deletions src/configs/qmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,14 @@ def apply_quantization(model: Model, quantizers: LayerQuantizerDict):
)
# TODO(Fran): get below dict objects from the quantizers passed in add method (Constant comes from UniformQuantizer)
# So maybe if there are custom objects to register each class should have a method to return them
custom_objects = {}
custom_objects["GenerateConfig"] = GenerateConfig
custom_objects["UniformQuantizer"] = UniformQuantizer
custom_objects["FlexQuantizer"] = FlexQuantizer
custom_objects["Constant"] = Constant

with quantize_scope(custom_objects):
return quantize_apply(quantize_model(model, quantizers))


custom_objects = {
"GenerateConfig": GenerateConfig,
"UniformQuantizer": UniformQuantizer,
"FlexQuantizer": FlexQuantizer,
"Constant": Constant,
}
Empty file.
27 changes: 27 additions & 0 deletions src/configs/serialization/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from tensorflow.keras.models import load_model
from tensorflow_model_optimization.quantization.keras import quantize_scope

from configs.qmodel import custom_objects


def save_qmodel(model, model_dir):
    """Persist a quantized Keras model to disk.

    The model is written in the TensorFlow SavedModel format so that the
    quantization wrappers can be restored later by ``load_qmodel``.

    Args:
        model: Quantize-aware Keras model to persist.
        model_dir: Destination directory for the SavedModel.
    """
    model.save(model_dir, save_format="tf")


def load_qmodel(model_dir):
    """Restore a quantized Keras model previously saved with ``save_qmodel``.

    Args:
        model_dir: Directory containing the saved model.

    Returns:
        The loaded Keras model.
    """
    # The custom quantizers/configs are registered both with TFMOT's
    # quantize scope and with Keras' loader so deserialization can resolve
    # the project-specific classes.
    with quantize_scope(custom_objects):
        loaded = load_model(model_dir, custom_objects=custom_objects)
    return loaded
162 changes: 162 additions & 0 deletions src/configs/serialization/serialization_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
#!/usr/bin/env python3
import tempfile
import unittest

from tensorflow import keras

from configs.qmodel import apply_quantization
from configs.serialization.serialization import load_qmodel, save_qmodel
from quantizers.flex_quantizer import FlexQuantizer
from quantizers.uniform_quantizer import UniformQuantizer


class TestSerialization(unittest.TestCase):
    """Save/load round-trip smoke tests for quantize-aware Keras models.

    Each test builds a model, applies quantization, saves it with
    ``save_qmodel`` and re-loads it with ``load_qmodel``; the test passes
    when the round trip completes without raising.
    """

    def _quantize_save_load(self, model, qconfig):
        """Quantize *model* with *qconfig*, compile it, write it to a fresh
        temporary directory and return the re-loaded model."""
        quant_aware_model = apply_quantization(model, qconfig)
        quant_aware_model.compile(
            optimizer="adam",
            loss="sparse_categorical_crossentropy",
            metrics=["accuracy"],
        )
        model_dir = tempfile.mkdtemp()
        save_qmodel(quant_aware_model, model_dir)
        return load_qmodel(model_dir)

    def test_single_uniform(self):
        """A Dense layer with a UniformQuantizer on its kernel round-trips."""
        model = keras.Sequential(
            [
                keras.layers.Dense(20, input_shape=[20], name="dense_1"),
                keras.layers.Flatten(),
            ]
        )
        qconfig = {
            "dense_1": {
                "weights": {
                    "kernel": UniformQuantizer(8, name_suffix="_asdasd"),
                },
            }
        }
        self.assertIsNotNone(self._quantize_save_load(model, qconfig))

    def test_single_flex(self):
        """A Dense layer with a FlexQuantizer on its kernel round-trips."""
        model = keras.Sequential(
            [
                keras.layers.Dense(20, input_shape=[20], name="dense_1"),
                keras.layers.Flatten(),
            ]
        )
        qconfig = {
            "dense_1": {
                "weights": {
                    "kernel": FlexQuantizer(
                        bits=8, n_levels=10, signed=True, name_suffix="_asdasd"
                    ),
                },
            }
        }
        self.assertIsNotNone(self._quantize_save_load(model, qconfig))

    def test_lenet_5(self):
        """A LeNet-5-style model with every Conv/Dense layer quantized
        round-trips."""
        model = keras.models.Sequential(
            [
                keras.layers.Conv2D(
                    filters=6,
                    kernel_size=(5, 5),
                    activation="relu",
                    padding="same",
                    input_shape=(28, 28, 1),
                ),
                keras.layers.AveragePooling2D(pool_size=(2, 2), strides=2),
                keras.layers.Conv2D(
                    filters=16, kernel_size=(5, 5), activation="relu"
                ),
                keras.layers.AveragePooling2D(pool_size=(2, 2), strides=2),
                keras.layers.Flatten(),
                keras.layers.Dense(120, activation="relu"),
                keras.layers.Dense(84, activation="relu"),
                keras.layers.Dense(
                    10, activation="softmax"
                ),  # 10 classes (digits 0-9)
            ]
        )

        def _layer_qconfig():
            # Fresh quantizer instances per layer, matching the original
            # hand-written config which never shared instances across layers.
            return {
                "weights": {
                    "kernel": FlexQuantizer(bits=4, n_levels=10, signed=True)
                },
                "activations": {
                    "activation": UniformQuantizer(bits=4, signed=False)
                },
            }

        qconfig = {
            name: _layer_qconfig()
            for name in ("conv2d", "conv2d_1", "dense", "dense_1", "dense_2")
        }
        self.assertIsNotNone(self._quantize_save_load(model, qconfig))


if __name__ == "__main__":
    unittest.main()
3 changes: 1 addition & 2 deletions src/quantizers/flex_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,6 @@ def get_config(self):
return {
"bits": self.bits,
"signed": self.signed,
"n_levels": self.n_levels,
"name_suffix": self.name_suffix,
"initializer": self.initializer,
"regularizer": self.regularizer,
}
Loading
Loading