From d3c075614d4de07c21c5c386d9f0642adf1ed642 Mon Sep 17 00:00:00 2001 From: nroope Date: Fri, 19 Dec 2025 16:42:06 +0100 Subject: [PATCH 1/7] Update README.md --- README.md | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c3c6eaf..d5d245b 100644 --- a/README.md +++ b/README.md @@ -3,8 +3,11 @@ ## Prune and Quantize ML models PQuant is a library for training compressed machine learning models, developed at CERN as part of the [Next Generation Triggers](https://nextgentriggers.web.cern.ch/t13/) project. -Installation via pip: ```pip install pquant-ml```. -To run the code, [HGQ2](https://github.com/calad0i/HGQ2) is also needed. +Installation via pip: ```pip install pquant-ml```. + +With TensorFlow ```pip install pquant-ml[tensorflow]```. + +With PyTorch ```pip install pquant-ml[torch]```. PQuant replaces the layers and activations it finds with a Compressed (in the case of layers) or Quantized (in the case of activations) variant. These automatically handle the quantization of the weights, biases and activations, and the pruning of the weights. Both PyTorch and TensorFlow models are supported. 
@@ -47,6 +50,12 @@ For detailed documentation check this page: [PQuantML documentation](https://pqu ### Authors - Roope Niemi (CERN) - Anastasiia Petrovych (CERN) + - Arghya Das (Purdue University) + - Enrico Lupi (CERN) - Chang Sun (Caltech) + - Dimitrios Danopoulos (CERN) + - Marlon Joshua Helbing + - Mia Liu (Purdue University) - Michael Kagan (SLAC National Accelerator Laboratory) - Vladimir Loncar (CERN) + - Maurizio Pierini (CERN) From b31bf2c99e023c759b47f5712fed57b49bbbfc58 Mon Sep 17 00:00:00 2001 From: Anastasiia Date: Mon, 12 Jan 2026 11:47:48 +0100 Subject: [PATCH 2/7] Add removed property at training model (#22) --- src/pquant/data_models/training_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/pquant/data_models/training_model.py b/src/pquant/data_models/training_model.py index 78d0c37..228c1f6 100644 --- a/src/pquant/data_models/training_model.py +++ b/src/pquant/data_models/training_model.py @@ -11,3 +11,5 @@ class BaseTrainingModel(BaseModel): rewind: str = Field(default="never") rounds: int = Field(default=1) save_weights_epoch: int = Field(default=-1) + pruning_first: bool = Field(default=False) + \ No newline at end of file From b37020941b96f58be5089531a5e17cb731242fbe Mon Sep 17 00:00:00 2001 From: Anastasiia Date: Mon, 12 Jan 2026 15:44:08 +0100 Subject: [PATCH 3/7] Modified 'post_round' function condition (#23) --- src/pquant/core/keras/layers.py | 5 +++-- src/pquant/core/torch/layers.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/pquant/core/keras/layers.py b/src/pquant/core/keras/layers.py index b615cbe..823ff49 100644 --- a/src/pquant/core/keras/layers.py +++ b/src/pquant/core/keras/layers.py @@ -1347,11 +1347,12 @@ def call(self, x, training=None): def call_post_round_functions(model, rewind, rounds, r): + last_round = (r == rounds - 1) if rewind == "round": rewind_weights_functions(model) - elif rewind == "post-ticket-search" and r == rounds - 1: + elif rewind == "post-ticket-search" and 
last_round: rewind_weights_functions(model) - else: + elif not last_round: post_round_functions(model) diff --git a/src/pquant/core/torch/layers.py b/src/pquant/core/torch/layers.py index 6e67655..c13228f 100644 --- a/src/pquant/core/torch/layers.py +++ b/src/pquant/core/torch/layers.py @@ -1404,11 +1404,12 @@ def apply_final_compression(module): def call_post_round_functions(model, rewind, rounds, r): + last_round = (r == rounds - 1) if rewind == "round": rewind_weights_functions(model) - elif rewind == "post-ticket-search" and r == rounds - 1: + elif rewind == "post-ticket-search" and last_round: rewind_weights_functions(model) - elif r != rounds - 1: + elif not last_round: post_round_functions(model) From 0daad897c72a7c281cbfafeabfe1e07872cb3186 Mon Sep 17 00:00:00 2001 From: Roope Niemi Date: Thu, 26 Feb 2026 16:57:39 +0100 Subject: [PATCH 4/7] fix pdp unstructured shape bug --- src/pquant/pruning_methods/pdp.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/pquant/pruning_methods/pdp.py b/src/pquant/pruning_methods/pdp.py index dfe7b92..48799f8 100644 --- a/src/pquant/pruning_methods/pdp.py +++ b/src/pquant/pruning_methods/pdp.py @@ -31,6 +31,8 @@ def build(self, input_shape): shape = (input_shape[0], 1, 1) else: shape = (input_shape[0], 1, 1, 1) + else: + shape = input_shape self.mask = self.add_weight(shape=shape, initializer="ones", name="mask", trainable=False) self.flat_weight_size = ops.cast(ops.size(self.mask), self.mask.dtype) super().build(input_shape) From bea784279b6aa05b24180f6f8f2cbbe00113a633 Mon Sep 17 00:00:00 2001 From: Anastasiia Petrovych Date: Tue, 3 Mar 2026 17:05:33 +0100 Subject: [PATCH 5/7] Add serialization and refactor errors --- src/pquant/core/keras/layers.py | 164 +++++++++++++++++++++----------- 1 file changed, 111 insertions(+), 53 deletions(-) diff --git a/src/pquant/core/keras/layers.py b/src/pquant/core/keras/layers.py index 6a30fc5..d41f3b8 100644 --- a/src/pquant/core/keras/layers.py +++ 
b/src/pquant/core/keras/layers.py @@ -27,7 +27,7 @@ T = TypeVar("T") -@keras.saving.register_keras_serializable(package="PQuant") +@keras.saving.register_keras_serializable(package="PQuantML") class PQWeightBiasBase(keras.layers.Layer): def __init__( self, @@ -73,7 +73,8 @@ def __init__( self.i_output = config.quantization_parameters.default_data_integer_bits self.f_output = config.quantization_parameters.default_data_fractional_bits - self.pruning_layer = get_pruning_layer(config=config, layer_type=layer_type) + self.layer_type = layer_type + self.pruning_layer = get_pruning_layer(config=config, layer_type=self.layer_type) self.pruning_method = config.pruning_parameters.pruning_method self.quantize_input = quantize_input self.quantize_output = quantize_output @@ -234,44 +235,45 @@ def collect_output(self, x, training): collect_x = self.handle_transpose(x, self.data_transpose, self.do_transpose_data) self.pruning_layer.collect_output(collect_x, training) - @classmethod - def from_config(cls, config): - # Deserialize all sublayers first - input_quantizer = keras.saving.deserialize_keras_object(config.pop("input_quantizer")) - weight_quantizer = keras.saving.deserialize_keras_object(config.pop("weight_quantizer")) - bias_quantizer = keras.saving.deserialize_keras_object(config.pop("bias_quantizer")) - output_quantizer = keras.saving.deserialize_keras_object(config.pop("output_quantizer")) - - instance = cls(**config) - instance.input_quantizer = input_quantizer - instance.weight_quantizer = weight_quantizer - instance.bias_quantizer = bias_quantizer - - if True: - instance.output_quantizer = output_quantizer - return instance def get_config(self): config = super().get_config() + config.update( { - "config": self.config, - "input_quantizer": keras.saving.serialize_keras_object(self.input_quantizer), - "weight_quantizer": keras.saving.serialize_keras_object(self.weight_quantizer), - "bias_quantizer": keras.saving.serialize_keras_object(self.bias_quantizer), + "config": 
self.config.model_dump(), + "layer_type": getattr(self, "layer_type", None), + "quantize_input": self.quantize_input, "quantize_output": self.quantize_output, - "in_quant_bits": self.in_quant_bits, - "weight_quant_bits": self.weight_quant_bits, - "bias_quant_bits": self.bias_quant_bits, - "out_quant_bits": self.out_quant_bits, + "enable_pruning": self.enable_pruning, + + "in_quant_bits": ( + float(self.k_input), + float(self.i_input), + float(self.f_input), + ), + "weight_quant_bits": ( + float(self.k_weight), + float(self.i_weight), + float(self.f_weight), + ), + "bias_quant_bits": ( + float(self.k_bias), + float(self.i_bias), + float(self.f_bias), + ), + "out_quant_bits": ( + float(self.k_output), + float(self.i_output), + float(self.f_output), + ), } ) - config.update({"output_quantizer": keras.saving.serialize_keras_object(self.output_quantizer)}) return config -@keras.saving.register_keras_serializable(package="PQuant") +@keras.saving.register_keras_serializable(package="PQuantML") class PQDepthwiseConv2d(PQWeightBiasBase, keras.layers.DepthwiseConv2D): def __init__( self, @@ -313,7 +315,7 @@ def __init__( activation=None, use_bias=use_bias, depthwise_initializer=depthwise_initializer, - bias_initializer=bias_regularizer, + bias_initializer=bias_initializer, depthwise_regularizer=depthwise_regularizer, bias_regularizer=bias_regularizer, activity_regularizer=activity_regularizer, @@ -464,7 +466,7 @@ def call(self, x, training=None): def apply_final_compression(self): self._kernel.assign(self.kernel) if self._bias is not None: - self._bias.assign = self.bias + self._bias.assign(self.bias) self.final_compression_done = True def extra_repr(self) -> str: @@ -480,7 +482,7 @@ def extra_repr(self) -> str: ) -@keras.saving.register_keras_serializable(package="PQuant") +@keras.saving.register_keras_serializable(package="PQuantML") class PQConv2d(PQWeightBiasBase, keras.layers.Conv2D): def __init__( self, @@ -543,7 +545,7 @@ def __init__( self.weight_transpose_back = (2, 
3, 1, 0) self.data_transpose = (0, 3, 1, 2) self.do_transpose_data = self.data_format == "channels_last" - self.use_biase = use_bias + self.use_bias = use_bias def build(self, input_shape): super().build(input_shape) @@ -641,7 +643,7 @@ def call(self, x, training=None): return x -@keras.saving.register_keras_serializable(package="PQuant") +@keras.saving.register_keras_serializable(package="PQuantML") class PQSeparableConv2d(Layer): def __init__( self, @@ -667,7 +669,7 @@ def __init__( quantize_output=False, **kwargs, ): - super().__init__() + super().__init__(**kwargs) self.weight_transpose = (3, 2, 0, 1) self.weight_transpose_back = (2, 3, 1, 0) self.data_transpose = (0, 3, 1, 2) @@ -719,9 +721,28 @@ def call(self, x, training=None): x = self.depthwise_conv(x, training=training) x = self.pointwise_conv(x, training=training) return x + + def get_config(self): + config = super().get_config() + config.update( + { + "config": self.depthwise_conv.config.model_dump(), + "filters": self.pointwise_conv.filters, + "kernel_size": self.depthwise_conv.kernel_size, + "strides": self.depthwise_conv.strides, + "padding": self.depthwise_conv.padding, + "data_format": self.depthwise_conv.data_format, + "dilation_rate": self.depthwise_conv.dilation_rate, + "depth_multiplier": self.depthwise_conv.depth_multiplier, + "use_bias": self.pointwise_conv.use_bias, + "quantize_input": self.depthwise_conv.quantize_input, + "quantize_output": self.pointwise_conv.quantize_output, + } + ) + return config + - -@keras.saving.register_keras_serializable(package="PQuant") +@keras.saving.register_keras_serializable(package="PQuantML") class PQConv1d(PQWeightBiasBase, keras.layers.Conv1D): def __init__( self, @@ -882,7 +903,7 @@ def call(self, x, training=None): return x -@keras.saving.register_keras_serializable(package="PQuant") +@keras.saving.register_keras_serializable(package="PQuantML") class PQDense(PQWeightBiasBase): def __init__( self, @@ -996,7 +1017,7 @@ def ebops(self, 
include_mask=False): def apply_final_compression(self): self._kernel.assign(self.kernel) if self._bias is not None: - self._bias.assign = self.bias + self._bias.assign(self.bias) self.final_compression_done = True def call(self, x, training=None): @@ -1020,6 +1041,7 @@ def get_config(self): return config +@keras.saving.register_keras_serializable(package="PQuantML") class PQBatchNormalization(keras.layers.BatchNormalization): def __init__( self, @@ -1215,8 +1237,18 @@ def get_bias_quantization_bits(self): def post_pre_train_function(self): self.is_pretraining = False + + def get_config(self): + config = super().get_config() + config.update( + { + "config": self.config.model_dump(), + "quantize_input": self.quantize_input, + } + ) + return config - +@keras.saving.register_keras_serializable(package="PQuantML") class PQAvgPoolBase(keras.layers.Layer): def __init__( self, @@ -1333,20 +1365,25 @@ def get_config(self): config = super().get_config() config.update( { - "i_input": self.i_input, - "f_input": self.f_input, - "i_output": self.i_output, - "f_output": self.f_output, - "is_pretraining": self.is_pretraining, - "overflow": self.overflow_mode_data, - "hgq_gamma": self.hgq_gamma, - "hgq_heterogeneous": self.hgq_heterogeneous, - "pooling": self.pooling, - } + "config": self.config.model_dump(), + "quantize_input": self.quantize_input, + "quantize_output": self.quantize_output, + "in_quant_bits": ( + float(self.k_input), + float(self.i_input), + float(self.f_input), + ), + "out_quant_bits": ( + float(self.k_output), + float(self.i_output), + float(self.f_output), + ), + } ) return config + - +@keras.saving.register_keras_serializable(package="PQuantML") class PQAvgPool1d(PQAvgPoolBase, keras.layers.AveragePooling1D): def __init__( self, @@ -1384,7 +1421,18 @@ def call(self, x, training=None): self.add_loss(self.hgq_loss()) return x + def get_config(self): + config = super().get_config() + config.update({ + "pool_size": self.pool_size, + "strides": self.strides, + 
"padding": self.padding, + "data_format": self.data_format, + }) + return config + +@keras.saving.register_keras_serializable(package="PQuantML") class PQAvgPool2d(PQAvgPoolBase, keras.layers.AveragePooling2D): def __init__( self, @@ -1420,6 +1468,16 @@ def call(self, x, training=None): if self.use_hgq and self.enable_quantization: self.add_loss(self.hgq_loss()) return x + + def get_config(self): + config = super().get_config() + config.update({ + "pool_size": self.pool_size, + "strides": self.strides, + "padding": self.padding, + "data_format": self.data_format, + }) + return config def call_post_round_functions(model, rewind, rounds, r): @@ -2152,7 +2210,7 @@ def get_enable_pruning(layer, config): enable_pruning_depthwise = enable_pruning_pointwise = True if layer.name + "_depthwise" in config.pruning_parameters.disable_pruning_for_layers: enable_pruning_depthwise = False - if layer.name + "pointwise" in config.pruning_parameters.disable_pruning_for_layers: + if layer.name + "_pointwise" in config.pruning_parameters.disable_pruning_for_layers: enable_pruning_pointwise = False return enable_pruning_depthwise, enable_pruning_pointwise else: @@ -2223,7 +2281,7 @@ def populate_config_with_all_layers(model, config): elif isinstance( layer, (Activation, ReLU, AveragePooling1D, AveragePooling2D, AveragePooling3D, PQActivation, PQAvgPoolBase) ): - custom_scheme.layer_specific[layer.name] = { + custom_scheme["layer_specific"][layer.name] = { "input": {"quantize": True, "integer_bits": 0.0, "fractional_bits": 7.0}, "output": {"quantize": True, "integer_bits": 0.0, "fractional_bits": 7.0}, } @@ -2248,7 +2306,7 @@ def post_training_prune(model, config, calibration_data): model = add_compression_layers(model, config, inputs.shape) post_pretrain_functions(model, config) model(inputs, training=True) # True so pruning works - return apply_final_compression(model, config) + return apply_final_compression(model) def get_ebops(model, **kwargs): From 
f8c9fa8f282458f1c05d1659f3935ccc81559d96 Mon Sep 17 00:00:00 2001 From: Anastasiia Petrovych Date: Tue, 3 Mar 2026 17:09:56 +0100 Subject: [PATCH 6/7] Modified layer type in the get_config method --- src/pquant/core/keras/layers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/pquant/core/keras/layers.py b/src/pquant/core/keras/layers.py index d41f3b8..8fb5abd 100644 --- a/src/pquant/core/keras/layers.py +++ b/src/pquant/core/keras/layers.py @@ -242,8 +242,7 @@ def get_config(self): config.update( { "config": self.config.model_dump(), - "layer_type": getattr(self, "layer_type", None), - + "layer_type": self.layer_type, "quantize_input": self.quantize_input, "quantize_output": self.quantize_output, "enable_pruning": self.enable_pruning, From a9fe4ba70a1a57b9a89c8f6a3f3c85edc9a8a369 Mon Sep 17 00:00:00 2001 From: nroope Date: Fri, 20 Mar 2026 15:07:48 +0100 Subject: [PATCH 7/7] Add support for model fit (#33) Enable model.fit support * Refactor dynamic branching in forward pass from if-else to ops.where * Implement Callback that handles training stage switching during fit * Fix issues with serialization and deserialization --- src/pquant/configs/config_mdmm.yaml | 1 + src/pquant/core/keras/activations.py | 6 +- src/pquant/core/keras/layers.py | 601 ++++++++++++------ src/pquant/core/keras/quantizer.py | 49 +- src/pquant/core/keras/train.py | 108 ++++ src/pquant/core/torch/layers.py | 26 +- src/pquant/core/torch/quantizer.py | 7 +- src/pquant/data_models/pruning_model.py | 1 + .../pruning_methods/activation_pruning.py | 150 +++-- src/pquant/pruning_methods/autosparse.py | 84 ++- .../pruning_methods/constraint_functions.py | 7 +- src/pquant/pruning_methods/cs.py | 59 +- src/pquant/pruning_methods/dst.py | 65 +- src/pquant/pruning_methods/mdmm.py | 69 +- src/pquant/pruning_methods/pdp.py | 229 ++++--- src/pquant/pruning_methods/wanda.py | 225 ++++--- tests/test_ap.py | 2 + tests/test_keras_compression_layers.py | 583 ++++++++++++----- 
tests/test_pdp.py | 44 +- tests/test_torch_compression_layers.py | 85 +-- 20 files changed, 1599 insertions(+), 802 deletions(-) diff --git a/src/pquant/configs/config_mdmm.yaml b/src/pquant/configs/config_mdmm.yaml index 33e3587..abbae3c 100644 --- a/src/pquant/configs/config_mdmm.yaml +++ b/src/pquant/configs/config_mdmm.yaml @@ -14,6 +14,7 @@ pruning_parameters: damping: 1.0 use_grad: false l0_mode: "coarse" # 'coarse' or 'smooth' + constraint_lr: 1.0e-3 quantization_parameters: enable_quantization: true diff --git a/src/pquant/core/keras/activations.py b/src/pquant/core/keras/activations.py index ddc3a27..dcf5f6b 100644 --- a/src/pquant/core/keras/activations.py +++ b/src/pquant/core/keras/activations.py @@ -23,7 +23,7 @@ def hard_tanh(x): activation_registry = {"relu": relu, "tanh": tanh, "hard_tanh": hard_tanh} -@keras.saving.register_keras_serializable(package="PQuant") +@keras.saving.register_keras_serializable(package="PQuantML") class PQActivation(keras.layers.Layer): def __init__( self, @@ -121,6 +121,10 @@ def set_output_quantization_bits(self, i, f): def post_pre_train_function(self): self.is_pretraining = False + if self.quantize_input: + self.input_quantizer.post_pre_train_function() + if self.quantize_output: + self.output_quantizer.post_pre_train_function() def ebops(self): if self.quantize_input and self.quantize_output: diff --git a/src/pquant/core/keras/layers.py b/src/pquant/core/keras/layers.py index 8fb5abd..aabf69f 100644 --- a/src/pquant/core/keras/layers.py +++ b/src/pquant/core/keras/layers.py @@ -17,7 +17,10 @@ SeparableConv2D, ) from keras.src.layers.input_spec import InputSpec -from keras.src.ops.operation_utils import compute_pooling_output_shape +from keras.src.ops.operation_utils import ( + compute_conv_output_shape, + compute_pooling_output_shape, +) from pquant.core.hyperparameter_optimization import PQConfig from pquant.core.keras.activations import PQActivation @@ -98,7 +101,8 @@ def __init__( self.parallelization_factor = -1 
self.hgq_beta = config.quantization_parameters.hgq_beta self.input_shape = None - self.is_pretraining = True + self._is_pretraining = True + self._is_finetuning = False self.config = config self.weight_quantizer = Quantizer( @@ -168,28 +172,72 @@ def build(self, input_shape): self.input_shape = (1,) + tuple(input_shape[1:]) self.n_parallel = ops.prod(input_shape[1:-1]) self.parallelization_factor = self.parallelization_factor if self.parallelization_factor > 0 else self.n_parallel + self.is_pretraining = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_pretraining else ops.zeros(shape), dtype), + name="is_pretraining", + trainable=False, + dtype="float32", + ) + self.is_finetuning = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_finetuning else ops.zeros(shape), dtype), + name="is_finetuning", + trainable=False, + dtype="float32", + ) super().build(input_shape=input_shape) def apply_final_compression(self): pass + def save_own_variables(self, store): + if not self.built: + return + all_vars = self._trainable_variables + self._non_trainable_variables + for i, v in enumerate(all_vars): + store[str(i)] = v + + def load_own_variables(self, store): + all_vars = self._trainable_variables + self._non_trainable_variables + if len(store.keys()) != len(all_vars): + raise ValueError( + f"Layer '{self.name}' expected {len(all_vars)} variables, " + f"but received {len(store.keys())} variables during loading. 
" + f"Expected: {[v.name for v in all_vars]}" + ) + for i, v in enumerate(all_vars): + v.assign(store[str(i)]) + def post_pre_train_function(self): - self.is_pretraining = False + self._is_pretraining = False + if hasattr(self, "is_pretraining"): + self.is_pretraining.assign(0.0) if self.pruning_layer is not None: self.pruning_layer.post_pre_train_function() + self.input_quantizer.post_pre_train_function() + self.weight_quantizer.post_pre_train_function() + self.bias_quantizer.post_pre_train_function() + self.output_quantizer.post_pre_train_function() + + def pre_finetune_function(self): + self._is_finetuning = True + if hasattr(self, "is_finetuning"): + self.is_finetuning.assign(1.0) def save_weights(self): - self.init_weight = self.weight.value + self.init_weight = ops.copy(self._kernel) def rewind_weights(self): - self.weight.assign(self.init_weight) + self._kernel.assign(self.init_weight) def ebops(self): return 0.0 def hgq_loss(self): - if self.pruning_layer.is_pretraining or not self.use_hgq: + if not self.use_hgq: return ops.convert_to_tensor(0.0) + loss = self.hgq_beta * self.ebops() loss += self.weight_quantizer.hgq_loss() if self._bias is not None: @@ -198,7 +246,7 @@ def hgq_loss(self): loss += self.input_quantizer.hgq_loss() if self.quantize_output: loss += self.output_quantizer.hgq_loss() - return loss + return ops.where(ops.cast(self.is_pretraining, "bool"), ops.zeros_like(loss), loss) def handle_transpose(self, x, transpose, do_transpose=False): if do_transpose: @@ -212,17 +260,17 @@ def prune(self, weight): weight = self.handle_transpose(weight, self.weight_transpose_back, True) return weight - def pre_forward(self, x, training=None): + def pre_forward(self, x, training): if self.quantize_input and self.enable_quantization: x = self.input_quantizer(x, training=training) - if self.pruning_method == "wanda": + if self.pruning_method == "wanda" and self.enable_pruning: self.collect_input(x, self._kernel, training) return x - def post_forward(self, x, 
training=None): + def post_forward(self, x, training): if self.quantize_output and self.enable_quantization: x = self.output_quantizer(x, training=training) - if self.pruning_method == "activation_pruning": + if self.pruning_method == "activation_pruning" and self.enable_pruning: self.collect_output(x, training) return x @@ -235,38 +283,37 @@ def collect_output(self, x, training): collect_x = self.handle_transpose(x, self.data_transpose, self.do_transpose_data) self.pruning_layer.collect_output(collect_x, training) + @classmethod + def from_config(cls, config): + # Quantizer objects are recreated by __init__ from the parent config; + # their variable values are restored from the h5 weights file by attribute name. + config.pop("input_quantizer", None) + config.pop("weight_quantizer", None) + config.pop("bias_quantizer", None) + config.pop("output_quantizer", None) + final_compression_done = config.pop("final_compression_done", False) + instance = cls(**config) + instance.final_compression_done = final_compression_done + return instance def get_config(self): config = super().get_config() config.update( { - "config": self.config.model_dump(), - "layer_type": self.layer_type, + "config": self.config.get_dict(), + "input_quantizer": keras.saving.serialize_keras_object(self.input_quantizer), + "weight_quantizer": keras.saving.serialize_keras_object(self.weight_quantizer), + "bias_quantizer": keras.saving.serialize_keras_object(self.bias_quantizer), + "output_quantizer": keras.saving.serialize_keras_object(self.output_quantizer), "quantize_input": self.quantize_input, "quantize_output": self.quantize_output, + "in_quant_bits": self.in_quant_bits, + "weight_quant_bits": self.weight_quant_bits, + "bias_quant_bits": self.bias_quant_bits, + "out_quant_bits": self.out_quant_bits, "enable_pruning": self.enable_pruning, - - "in_quant_bits": ( - float(self.k_input), - float(self.i_input), - float(self.f_input), - ), - "weight_quant_bits": ( - float(self.k_weight), - 
float(self.i_weight), - float(self.f_weight), - ), - "bias_quant_bits": ( - float(self.k_bias), - float(self.i_bias), - float(self.f_bias), - ), - "out_quant_bits": ( - float(self.k_output), - float(self.i_output), - float(self.f_output), - ), + "final_compression_done": self.final_compression_done, } ) return config @@ -321,7 +368,7 @@ def __init__( depthwise_constraint=depthwise_constraint, bias_constraint=bias_constraint, config=config, - layer_type="conv", + layer_type="depthwise_conv", quantize_input=quantize_input, quantize_output=quantize_output, in_quant_bits=in_quant_bits, @@ -382,7 +429,19 @@ def build(self, input_shape): if self.use_bias: self.bias_quantizer.build(self._bias.shape) self.output_quantizer.build(self.compute_output_shape(input_shape)) + else: + if not self.input_quantizer.built: + self.input_quantizer.build(input_shape) + if not self.weight_quantizer.built: + self.weight_quantizer.build(self._kernel.shape) + if self.use_bias and not self.bias_quantizer.built: + self.bias_quantizer.build(self._bias.shape) + if self.quantize_output and not self.output_quantizer.built: + self.output_quantizer.build(self.compute_output_shape(input_shape)) self.input_shape = (1,) + input_shape[1:] + if self.enable_pruning and self.pruning_layer is not None and not self.pruning_layer.built: + pruning_shape = tuple(self._kernel.shape[i] for i in self.weight_transpose) + self.pruning_layer.build(pruning_shape) @property def kernel(self): @@ -481,8 +540,14 @@ def extra_repr(self) -> str: ) -@keras.saving.register_keras_serializable(package="PQuantML") -class PQConv2d(PQWeightBiasBase, keras.layers.Conv2D): +def _normalize_tuple(value, n): + if isinstance(value, int): + return (value,) * n + return tuple(value) + + +@keras.saving.register_keras_serializable(package="PQuant") +class PQConv2d(PQWeightBiasBase): def __init__( self, config, @@ -512,22 +577,6 @@ def __init__( **kwargs, ): super().__init__( - filters=filters, - kernel_size=kernel_size, - strides=strides, - 
padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - groups=groups, - activation=None, - use_bias=use_bias, - kernel_initializer=kernel_initializer, - bias_initializer=bias_initializer, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - kernel_constraint=kernel_constraint, - bias_constraint=bias_constraint, config=config, layer_type="conv", quantize_input=quantize_input, @@ -537,17 +586,38 @@ def __init__( bias_quant_bits=bias_quant_bits, out_quant_bits=out_quant_bits, enable_pruning=enable_pruning, + activity_regularizer=activity_regularizer, **kwargs, ) - + self.filters = filters + self.kernel_size = _normalize_tuple(kernel_size, 2) + self.strides = _normalize_tuple(strides, 2) + self.padding = padding.lower() + self.data_format = keras.backend.image_data_format() if data_format is None else data_format + self.dilation_rate = _normalize_tuple(dilation_rate, 2) + self.groups = groups + self.use_bias = use_bias + self.kernel_initializer = initializers.get(kernel_initializer) + self.bias_initializer = initializers.get(bias_initializer) + self.kernel_regularizer = regularizers.get(kernel_regularizer) + self.bias_regularizer = regularizers.get(bias_regularizer) + self.kernel_constraint = constraints.get(kernel_constraint) + self.bias_constraint = constraints.get(bias_constraint) self.weight_transpose = (3, 2, 0, 1) self.weight_transpose_back = (2, 3, 1, 0) self.data_transpose = (0, 3, 1, 2) self.do_transpose_data = self.data_format == "channels_last" - self.use_bias = use_bias def build(self, input_shape): - super().build(input_shape) + in_channels = input_shape[-1] if self.data_format == "channels_last" else input_shape[1] + kernel_shape = self.kernel_size + (in_channels // self.groups, self.filters) + self._kernel = self.add_weight( + name="kernel", + shape=kernel_shape, + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + 
constraint=self.kernel_constraint, + ) if self.use_bias: self._bias = self.add_weight( name="bias", @@ -560,12 +630,25 @@ def build(self, input_shape): ) else: self._bias = None + super().build(input_shape) if self.use_hgq: self.input_quantizer.build(input_shape) self.weight_quantizer.build(self._kernel.shape) if self.use_bias: self.bias_quantizer.build(self._bias.shape) self.output_quantizer.build(self.compute_output_shape(input_shape)) + else: + if not self.input_quantizer.built: + self.input_quantizer.build(input_shape) + if not self.weight_quantizer.built: + self.weight_quantizer.build(self._kernel.shape) + if self.use_bias and not self.bias_quantizer.built: + self.bias_quantizer.build(self._bias.shape) + if self.quantize_output and not self.output_quantizer.built: + self.output_quantizer.build(self.compute_output_shape(input_shape)) + if self.enable_pruning and self.pruning_layer is not None and not self.pruning_layer.built: + pruning_shape = tuple(self._kernel.shape[i] for i in self.weight_transpose) + self.pruning_layer.build(pruning_shape) @property def kernel(self): @@ -633,14 +716,63 @@ def ebops(self, include_mask=False): ebops += ops.mean(bw_bias) * size return ebops + def compute_output_shape(self, input_shape): + return compute_conv_output_shape( + input_shape, + self.filters, + self.kernel_size, + strides=self.strides, + padding=self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) + + def apply_final_compression(self): + self._kernel.assign(self.kernel) + if self._bias is not None: + self._bias.assign(self.bias) + self.final_compression_done = True + def call(self, x, training=None): x = self.pre_forward(x, training) - x = super().call(x) + x = ops.conv( + x, + self.kernel, + strides=self.strides, + padding=self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) + if self.use_bias: + bias_shape = (1, 1, 1, self.filters) if self.data_format == "channels_last" else (1, self.filters, 1, 1) + 
x = x + ops.reshape(self.bias, bias_shape) x = self.post_forward(x, training) if self.use_hgq and self.enable_quantization: self.add_loss(self.hgq_loss()) return x + def get_config(self): + config = super().get_config() + config.update( + { + "filters": self.filters, + "kernel_size": self.kernel_size, + "strides": self.strides, + "padding": self.padding, + "data_format": self.data_format, + "dilation_rate": self.dilation_rate, + "groups": self.groups, + "use_bias": self.use_bias, + "kernel_initializer": initializers.serialize(self.kernel_initializer), + "bias_initializer": initializers.serialize(self.bias_initializer), + "kernel_regularizer": regularizers.serialize(self.kernel_regularizer), + "bias_regularizer": regularizers.serialize(self.bias_regularizer), + "kernel_constraint": constraints.serialize(self.kernel_constraint), + "bias_constraint": constraints.serialize(self.bias_constraint), + } + ) + return config + @keras.saving.register_keras_serializable(package="PQuantML") class PQSeparableConv2d(Layer): @@ -720,7 +852,7 @@ def call(self, x, training=None): x = self.depthwise_conv(x, training=training) x = self.pointwise_conv(x, training=training) return x - + def get_config(self): config = super().get_config() config.update( @@ -739,10 +871,10 @@ def get_config(self): } ) return config - -@keras.saving.register_keras_serializable(package="PQuantML") -class PQConv1d(PQWeightBiasBase, keras.layers.Conv1D): + +@keras.saving.register_keras_serializable(package="PQuant") +class PQConv1d(PQWeightBiasBase): def __init__( self, config, @@ -771,24 +903,7 @@ def __init__( bias_constraint=None, **kwargs, ): - super().__init__( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - groups=groups, - activation=None, - use_bias=use_bias, - kernel_initializer=kernel_initializer, - bias_initializer=bias_initializer, - kernel_regularizer=kernel_regularizer, - 
bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - kernel_constraint=kernel_regularizer, - bias_constraint=bias_constraint, config=config, layer_type="conv", quantize_input=quantize_input, @@ -798,17 +913,38 @@ def __init__( bias_quant_bits=bias_quant_bits, out_quant_bits=out_quant_bits, enable_pruning=enable_pruning, + activity_regularizer=activity_regularizer, **kwargs, ) - + self.filters = filters + self.kernel_size = _normalize_tuple(kernel_size, 1) + self.strides = _normalize_tuple(strides, 1) + self.padding = padding.lower() + self.data_format = keras.backend.image_data_format() if data_format is None else data_format + self.dilation_rate = _normalize_tuple(dilation_rate, 1) + self.groups = groups + self.use_bias = use_bias + self.kernel_initializer = initializers.get(kernel_initializer) + self.bias_initializer = initializers.get(bias_initializer) + self.kernel_regularizer = regularizers.get(kernel_regularizer) + self.bias_regularizer = regularizers.get(bias_regularizer) + self.kernel_constraint = constraints.get(kernel_constraint) + self.bias_constraint = constraints.get(bias_constraint) self.weight_transpose = (2, 1, 0) self.weight_transpose_back = (2, 1, 0) self.data_transpose = (0, 2, 1) self.do_transpose_data = self.data_format == "channels_last" - self.use_bias = use_bias def build(self, input_shape): - super().build(input_shape) + in_channels = input_shape[-1] if self.data_format == "channels_last" else input_shape[1] + kernel_shape = self.kernel_size + (in_channels // self.groups, self.filters) + self._kernel = self.add_weight( + name="kernel", + shape=kernel_shape, + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + ) if self.use_bias: self._bias = self.add_weight( name="bias", @@ -821,12 +957,25 @@ def build(self, input_shape): ) else: self._bias = None + super().build(input_shape) if self.use_hgq: self.input_quantizer.build(input_shape) 
self.weight_quantizer.build(self._kernel.shape) if self.use_bias: self.bias_quantizer.build(self._bias.shape) self.output_quantizer.build(self.compute_output_shape(input_shape)) + else: + if not self.input_quantizer.built: + self.input_quantizer.build(input_shape) + if not self.weight_quantizer.built: + self.weight_quantizer.build(self._kernel.shape) + if self.use_bias and not self.bias_quantizer.built: + self.bias_quantizer.build(self._bias.shape) + if self.quantize_output and not self.output_quantizer.built: + self.output_quantizer.build(self.compute_output_shape(input_shape)) + if self.enable_pruning and self.pruning_layer is not None and not self.pruning_layer.built: + pruning_shape = tuple(self._kernel.shape[i] for i in self.weight_transpose) + self.pruning_layer.build(pruning_shape) @property def kernel(self): @@ -893,14 +1042,63 @@ def ebops(self, include_mask=False): ebops += ops.mean(bw_bias) * size return ebops + def compute_output_shape(self, input_shape): + return compute_conv_output_shape( + input_shape, + self.filters, + self.kernel_size, + strides=self.strides, + padding=self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) + + def apply_final_compression(self): + self._kernel.assign(self.kernel) + if self._bias is not None: + self._bias.assign(self.bias) + self.final_compression_done = True + def call(self, x, training=None): x = self.pre_forward(x, training) - x = super().call(x) + x = ops.conv( + x, + self.kernel, + strides=self.strides, + padding=self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) + if self.use_bias: + bias_shape = (1, 1, self.filters) if self.data_format == "channels_last" else (1, self.filters, 1) + x = x + ops.reshape(self.bias, bias_shape) x = self.post_forward(x, training) if self.use_hgq and self.enable_quantization: self.add_loss(self.hgq_loss()) return x + def get_config(self): + config = super().get_config() + config.update( + { + "filters": self.filters, 
+ "kernel_size": self.kernel_size, + "strides": self.strides, + "padding": self.padding, + "data_format": self.data_format, + "dilation_rate": self.dilation_rate, + "groups": self.groups, + "use_bias": self.use_bias, + "kernel_initializer": initializers.serialize(self.kernel_initializer), + "bias_initializer": initializers.serialize(self.bias_initializer), + "kernel_regularizer": regularizers.serialize(self.kernel_regularizer), + "bias_regularizer": regularizers.serialize(self.bias_regularizer), + "kernel_constraint": constraints.serialize(self.kernel_constraint), + "bias_constraint": constraints.serialize(self.bias_constraint), + } + ) + return config + @keras.saving.register_keras_serializable(package="PQuantML") class PQDense(PQWeightBiasBase): @@ -949,6 +1147,7 @@ def __init__( self.kernel_constraint = constraints.get(kernel_constraint) self.bias_constraint = constraints.get(bias_constraint) self.input_spec = InputSpec(min_ndim=2) + self._ebops = self.add_variable(shape=(), initializer="zeros", trainable=False) def build(self, input_shape): input_dim = input_shape[-1] @@ -970,6 +1169,18 @@ def build(self, input_shape): else: self._bias = None super().build(input_shape) + if not self.input_quantizer.built: + self.input_quantizer.build(input_shape) + if not self.weight_quantizer.built: + self.weight_quantizer.build(self._kernel.shape) + if self.use_bias and not self.bias_quantizer.built: + self.bias_quantizer.build(self._bias.shape) + if self.quantize_output and not self.output_quantizer.built: + output_shape = input_shape[:-1] + (self.units,) + self.output_quantizer.build(output_shape) + if self.enable_pruning and self.pruning_layer is not None and not self.pruning_layer.built: + pruning_shape = tuple(self._kernel.shape[i] for i in self.weight_transpose) + self.pruning_layer.build(pruning_shape) @property def kernel(self): @@ -1006,11 +1217,11 @@ def ebops(self, include_mask=False): step_size_mask = ops.cast((ops.abs(self._kernel) > quantization_step_size), 
self._kernel.dtype) bw_ker = bw_ker * step_size_mask ebops = ops.sum(ops.matmul(bw_inp, bw_ker)) - ebops = ebops * self.parallelization_factor / self.n_parallel if self.use_bias: bw_bias = self.bias_quantizer.get_total_bits(ops.shape(self._bias)) - size = ops.cast(ops.prod(self.input_shape), self.dtype) + size = ops.cast(ops.prod(self.input_shape[:-1]) * self.units, self.dtype) ebops += ops.mean(bw_bias) * size + ebops = ebops * self.parallelization_factor / self.n_parallel return ebops def apply_final_compression(self): @@ -1019,28 +1230,30 @@ def apply_final_compression(self): self._bias.assign(self.bias) self.final_compression_done = True + def compute_output_shape(self, input_shape): + output_shape = list(input_shape) + output_shape[-1] = self.units + return tuple(output_shape) + def call(self, x, training=None): + self.training = training x = self.pre_forward(x, training) x = ops.matmul(x, self.kernel) bias = self.bias - if bias is not None: + if self.use_bias: x = ops.add(x, bias) x = self.post_forward(x, training) + if self.use_hgq: + self.add_loss(self.hgq_loss()) return x def get_config(self): config = super().get_config() - config.update( - { - "config": self.config.model_dump(), - "units": self.units, - "use_bias": self.use_bias, - } - ) + config.update({"units": self.units, "use_bias": self.use_bias}) return config -@keras.saving.register_keras_serializable(package="PQuantML") +@keras.saving.register_keras_serializable(package="PQuant") class PQBatchNormalization(keras.layers.BatchNormalization): def __init__( self, @@ -1060,8 +1273,11 @@ def __init__( gamma_constraint=None, synchronized=False, quantize_input=True, + quantize_parameters=True, **kwargs, ): + if isinstance(config, dict): + config = PQConfig.load_from_config(config) super().__init__( axis, momentum, @@ -1089,6 +1305,7 @@ def __init__( self.use_hgq = config.quantization_parameters.use_high_granularity_quantization self.hgq_beta = config.quantization_parameters.hgq_beta self.quantize_input = 
quantize_input + self.quantize_parameters = quantize_parameters self.granularity = config.quantization_parameters.granularity self.config = config self.f_weight = self.f_bias = ops.convert_to_tensor(config.quantization_parameters.default_weight_fractional_bits) @@ -1096,10 +1313,17 @@ def __init__( self.i_input = ops.convert_to_tensor(config.quantization_parameters.default_data_integer_bits) self.f_input = ops.convert_to_tensor(config.quantization_parameters.default_data_fractional_bits) self.final_compression_done = False - self.is_pretraining = True + self._is_pretraining = True def build(self, input_shape): super().build(input_shape) + self.is_pretraining = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape), dtype), + name="is_pretraining", + trainable=False, + dtype="float32", + ) self.input_quantizer = Quantizer( k=1.0, i=self.i_input, @@ -1137,18 +1361,15 @@ def build(self, input_shape): shape = [1] * len(input_shape) shape[self.axis] = input_shape[self.axis] self._shape = tuple(shape) - self.input_shape = (1,) + input_shape[1:] + self.input_shape = (1,) + tuple(input_shape[1:]) def apply_final_compression(self): self.final_compression_done = True - gamma, beta = self.gamma, self.beta - if self.enable_quantization: - if gamma is not None: - gamma = self.weight_quantizer(gamma) - self.gamma.assign(gamma) - if beta is not None: - beta = self.bias_quantizer(beta) - self.beta.assign(beta) + if self.enable_quantization and self.quantize_parameters: + if self.gamma is not None: + self.gamma.assign(self.weight_quantizer(self.gamma)) + if self.beta is not None: + self.beta.assign(self.bias_quantizer(self.beta)) def ebops(self): bw_inp = self.input_quantizer.get_total_bits(self.input_shape) @@ -1159,14 +1380,14 @@ def ebops(self): return ebops def hgq_loss(self): - if self.is_pretraining or not self.use_hgq: + if not self.use_hgq: return ops.convert_to_tensor(0.0) loss = self.hgq_beta * self.ebops() loss += 
self.weight_quantizer.hgq_loss() loss += self.bias_quantizer.hgq_loss() if self.quantize_input: loss += self.input_quantizer.hgq_loss() - return loss + return ops.where(ops.cast(self.is_pretraining, "bool"), ops.zeros_like(loss), loss) def call(self, inputs, training=None, mask=None): # Check if the mask has one less dimension than the inputs. @@ -1199,7 +1420,7 @@ def call(self, inputs, training=None, mask=None): if self.scale: gamma = self.gamma - if self.enable_quantization and not self.final_compression_done: + if self.enable_quantization and self.quantize_parameters and not self.final_compression_done: gamma = self.weight_quantizer(self.gamma) gamma = ops.cast(gamma, inputs.dtype) else: @@ -1207,7 +1428,7 @@ def call(self, inputs, training=None, mask=None): if self.center: beta = self.beta - if self.enable_quantization and not self.final_compression_done: + if self.enable_quantization and self.quantize_parameters and not self.final_compression_done: beta = self.bias_quantizer(self.beta) beta = ops.cast(beta, inputs.dtype) else: @@ -1235,18 +1456,30 @@ def get_bias_quantization_bits(self): return self.bias_quantizer.get_quantization_bits() def post_pre_train_function(self): - self.is_pretraining = False - + self._is_pretraining = False + if hasattr(self, "is_pretraining"): + self.is_pretraining.assign(0.0) + + @classmethod + def from_config(cls, config): + final_compression_done = config.pop("final_compression_done", False) + instance = cls(**config) + instance.final_compression_done = final_compression_done + return instance + def get_config(self): config = super().get_config() config.update( { - "config": self.config.model_dump(), + "config": self.config.get_dict(), "quantize_input": self.quantize_input, + "quantize_parameters": self.quantize_parameters, + "final_compression_done": self.final_compression_done, } ) return config + @keras.saving.register_keras_serializable(package="PQuantML") class PQAvgPoolBase(keras.layers.Layer): def __init__( @@ -1259,8 
+1492,13 @@ def __init__( **kwargs, ): + if isinstance(config, dict): + config = PQConfig.load_from_config(config) super().__init__(**kwargs) + self.in_quant_bits = in_quant_bits + self.out_quant_bits = out_quant_bits + if in_quant_bits is not None: self.k_input, self.i_input, self.f_input = in_quant_bits else: @@ -1276,7 +1514,6 @@ def __init__( self.f_output = config.quantization_parameters.default_data_fractional_bits self.overflow_mode_data = config.quantization_parameters.overflow_mode_data self.config = config - self.is_pretraining = True self.round_mode = config.quantization_parameters.round_mode self.data_k = config.quantization_parameters.default_data_keep_negatives self.use_hgq = config.quantization_parameters.use_high_granularity_quantization @@ -1284,14 +1521,26 @@ def __init__( self.hgq_gamma = config.quantization_parameters.hgq_gamma self.hgq_beta = config.quantization_parameters.hgq_beta self.hgq_heterogeneous = config.quantization_parameters.hgq_heterogeneous - self.saved_inputs = [] + self._is_pretraining = True self.quantize_input = quantize_input self.quantize_output = quantize_output + # BasePooling.__init__ sets built=True to skip the standard Keras build + # call, but we need build() to run so quantizers are created. 
+ self.built = False def post_pre_train_function(self): - self.is_pretraining = False + self._is_pretraining = False + if hasattr(self, "is_pretraining"): + self.is_pretraining.assign(0.0) def build(self, input_shape): + self.is_pretraining = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape), dtype), + name="is_pretraining", + trainable=False, + dtype="float32", + ) self.input_quantizer = Quantizer( k=1.0, i=self.i_input, @@ -1314,10 +1563,9 @@ def build(self, input_shape): hgq_gamma=self.hgq_gamma, place="datalane", ) - if self.use_hgq: - self.input_quantizer.build(input_shape) - self.output_quantizer.build(self.compute_output_shape(input_shape)) - self.input_shape = (1,) + input_shape[1:] + self.input_quantizer.build(input_shape) + self.output_quantizer.build(self.compute_output_shape(input_shape)) + self.input_shape = (1,) + tuple(input_shape[1:]) def get_input_quantization_bits(self): return self.input_quantizer.get_quantization_bits() @@ -1335,8 +1583,6 @@ def compute_output_shape(self, input_shape): ) def pre_pooling(self, x, training): - if not hasattr(self, "input_quantizer"): - self.build(x.shape) if self.quantize_input and self.enable_quantization: x = self.input_quantizer(x, training=training) return x @@ -1351,38 +1597,30 @@ def ebops(self): return ops.sum(bw_inp) def hgq_loss(self): - if self.is_pretraining or not self.use_hgq: + if not self.use_hgq: return ops.convert_to_tensor(0.0) loss = self.hgq_beta * self.ebops() if self.quantize_input: loss += self.input_quantizer.hgq_loss() if self.quantize_output: loss += self.output_quantizer.hgq_loss() - return loss + return ops.where(ops.cast(self.is_pretraining, "bool"), ops.zeros_like(loss), loss) def get_config(self): config = super().get_config() config.update( { - "config": self.config.model_dump(), - "quantize_input": self.quantize_input, - "quantize_output": self.quantize_output, - "in_quant_bits": ( - float(self.k_input), - float(self.i_input), - 
float(self.f_input), - ), - "out_quant_bits": ( - float(self.k_output), - float(self.i_output), - float(self.f_output), - ), - } + "config": self.config.get_dict(), + "quantize_input": self.quantize_input, + "quantize_output": self.quantize_output, + "in_quant_bits": self.in_quant_bits, + "out_quant_bits": self.out_quant_bits, + } ) return config - -@keras.saving.register_keras_serializable(package="PQuantML") + +@keras.saving.register_keras_serializable(package="PQuant") class PQAvgPool1d(PQAvgPoolBase, keras.layers.AveragePooling1D): def __init__( self, @@ -1421,17 +1659,10 @@ def call(self, x, training=None): return x def get_config(self): - config = super().get_config() - config.update({ - "pool_size": self.pool_size, - "strides": self.strides, - "padding": self.padding, - "data_format": self.data_format, - }) - return config + return super().get_config() -@keras.saving.register_keras_serializable(package="PQuantML") +@keras.saving.register_keras_serializable(package="PQuant") class PQAvgPool2d(PQAvgPoolBase, keras.layers.AveragePooling2D): def __init__( self, @@ -1467,16 +1698,9 @@ def call(self, x, training=None): if self.use_hgq and self.enable_quantization: self.add_loss(self.hgq_loss()) return x - + def get_config(self): - config = super().get_config() - config.update({ - "pool_size": self.pool_size, - "strides": self.strides, - "padding": self.padding, - "data_format": self.data_format, - }) - return config + return super().get_config() def call_post_round_functions(model, rewind, rounds, r): @@ -1490,15 +1714,20 @@ def call_post_round_functions(model, rewind, rounds, r): def apply_final_compression(model): - x = model.layers[0].output - for layer in model.layers[1:]: + for layer in model.layers: if isinstance(layer, (PQWeightBiasBase, PQSeparableConv2d, PQBatchNormalization, PQDepthwiseConv2d)): layer.apply_final_compression() - x = layer(x) - else: - x = layer(x) - replaced_model = keras.Model(inputs=model.inputs, outputs=x) - return replaced_model + if 
hasattr(layer, "input_quantizer"): + layer.input_quantizer.apply_final_compression() + if hasattr(layer, "output_quantizer"): + layer.output_quantizer.apply_final_compression() + return model + + +def _update_pruning_mask(layer): + if layer.enable_pruning and hasattr(layer.pruning_layer, "update_mask"): + kernel = layer.handle_transpose(layer._kernel, layer.weight_transpose, True) + layer.pruning_layer.update_mask(kernel) def post_epoch_functions(model, epoch, total_epochs, **kwargs): @@ -1512,10 +1741,15 @@ def post_epoch_functions(model, epoch, total_epochs, **kwargs): PQDense, ), ): - layer.pruning_layer.post_epoch_function(epoch, total_epochs, **kwargs) + if layer.enable_pruning: + layer.pruning_layer.post_epoch_function(epoch, total_epochs, **kwargs) + _update_pruning_mask(layer) elif isinstance(layer, PQSeparableConv2d): - layer.depthwise_conv.pruning_layer.post_epoch_function(epoch, total_epochs, **kwargs) - layer.pointwise_conv.pruning_layer.post_epoch_function(epoch, total_epochs, **kwargs) + if layer.enable_pruning: + layer.depthwise_conv.pruning_layer.post_epoch_function(epoch, total_epochs, **kwargs) + _update_pruning_mask(layer.depthwise_conv) + layer.pointwise_conv.pruning_layer.post_epoch_function(epoch, total_epochs, **kwargs) + _update_pruning_mask(layer.pointwise_conv) def pre_epoch_functions(model, epoch, total_epochs): @@ -1529,10 +1763,12 @@ def pre_epoch_functions(model, epoch, total_epochs): PQDense, ), ): - layer.pruning_layer.pre_epoch_function(epoch, total_epochs) + if layer.enable_pruning: + layer.pruning_layer.pre_epoch_function(epoch, total_epochs) elif isinstance(layer, PQSeparableConv2d): - layer.depthwise_conv.pruning_layer.pre_epoch_function(epoch, total_epochs) - layer.pointwise_conv.pruning_layer.pre_epoch_function(epoch, total_epochs) + if layer.enable_pruning: + layer.depthwise_conv.pruning_layer.pre_epoch_function(epoch, total_epochs) + layer.pointwise_conv.pruning_layer.pre_epoch_function(epoch, total_epochs) def 
post_round_functions(model): @@ -1597,9 +1833,12 @@ def pre_finetune_functions(model): PQDense, ), ): + layer.pre_finetune_function() layer.pruning_layer.pre_finetune_function() elif isinstance(layer, PQSeparableConv2d): + layer.depthwise_conv.pre_finetune_function() layer.depthwise_conv.pruning_layer.pre_finetune_function() + layer.pointwise_conv.pre_finetune_function() layer.pointwise_conv.pruning_layer.pre_finetune_function() @@ -1614,10 +1853,10 @@ def post_pretrain_functions(model, config): PQDense, ), ): - layer.pruning_layer.post_pre_train_function() + layer.post_pre_train_function() elif isinstance(layer, PQSeparableConv2d): - layer.depthwise_conv.pruning_layer.post_pre_train_function() - layer.pointwise_conv.pruning_layer.post_pre_train_function() + layer.depthwise_conv.post_pre_train_function() + layer.pointwise_conv.post_pre_train_function() elif isinstance(layer, (PQActivation, PQAvgPoolBase, PQBatchNormalization)): layer.post_pre_train_function() if config.pruning_parameters.pruning_method == "pdp" or ( @@ -1839,7 +2078,6 @@ def add_compression_layers(model, config, input_shape=None): depth_multiplier=layer.depth_multiplier, data_format=layer.data_format, dilation_rate=layer.dilation_rate, - activation=layer.activation, use_bias=layer.use_bias, bias_initializer=layer.bias_initializer, depthwise_initializer=layer.depthwise_initializer, @@ -1873,7 +2111,6 @@ def add_compression_layers(model, config, input_shape=None): data_format=layer.data_format, dilation_rate=layer.dilation_rate, groups=layer.groups, - activation=layer.activation, use_bias=layer.use_bias, kernel_initializer=layer.kernel_initializer, bias_initializer=layer.bias_initializer, @@ -1928,13 +2165,13 @@ def add_compression_layers(model, config, input_shape=None): new_layer.pointwise_conv.set_enable_pruning(enable_pruning_pointwise) pruning_layer_input = layer.depthwise_kernel - transpose_shape = new_layer.weight_transpose - pruning_layer_input = ops.transpose(pruning_layer_input, 
transpose_shape) + pruning_layer_input = ops.transpose(pruning_layer_input, new_layer.depthwise_conv.weight_transpose) new_layer.depthwise_conv.pruning_layer.build(pruning_layer_input.shape) pointwise_pruning_layer_input = layer.pointwise_kernel - transpose_shape = new_layer.weight_transpose - pointwise_pruning_layer_input = ops.transpose(pointwise_pruning_layer_input, transpose_shape) + pointwise_pruning_layer_input = ops.transpose( + pointwise_pruning_layer_input, new_layer.pointwise_conv.weight_transpose + ) new_layer.pointwise_conv.pruning_layer.build(pointwise_pruning_layer_input.shape) new_layer.depthwise_conv.build(x.shape) y = new_layer.depthwise_conv(x).shape @@ -1973,7 +2210,6 @@ def add_compression_layers(model, config, input_shape=None): new_layer = PQDense( config=config, units=layer.units, - activation=layer.activation, use_bias=layer.use_bias, kernel_initializer=layer.kernel_initializer, bias_initializer=layer.bias_initializer, @@ -2075,8 +2311,9 @@ def set_quantization_bits_activations(config, layer, new_layer): if isinstance(layer, ReLU): f_input += 1 f_output += 1 # Unsigned, add 1 bit to default value only - if layer.name in config.quantization_parameters.layer_specific: - layer_config = config.quantization_parameters.layer_specific[layer.name] + layer_specific = config.quantization_parameters.layer_specific + if layer.name in layer_specific: + layer_config = layer_specific[layer.name] if hasattr(layer, "activation") and layer.activation.__name__ in layer_config: if "input" in layer_config[layer.activation.__name__]: if "integer_bits" in layer_config[layer.activation.__name__]["input"]: @@ -2201,6 +2438,10 @@ def set_quantization_bits_weight_layers(config, layer, new_layer): new_layer.f_weight = f_bits_w new_layer.i_bias = i_bits_b new_layer.f_bias = f_bits_b + new_layer.weight_quantizer.i_init = float(i_bits_w) + new_layer.weight_quantizer.f_init = float(f_bits_w) + new_layer.bias_quantizer.i_init = float(i_bits_b) + 
new_layer.bias_quantizer.f_init = float(f_bits_b) def get_enable_pruning(layer, config): diff --git a/src/pquant/core/keras/quantizer.py b/src/pquant/core/keras/quantizer.py index d4398a0..38b96af 100644 --- a/src/pquant/core/keras/quantizer.py +++ b/src/pquant/core/keras/quantizer.py @@ -6,7 +6,7 @@ from pquant.core.quantizer_functions import create_quantizer -@keras.saving.register_keras_serializable(package="PQuant") +@keras.saving.register_keras_serializable(package="PQuantML") class Quantizer(keras.layers.Layer): # HGQ quantizer wrapper def __init__( @@ -35,7 +35,7 @@ def __init__( self.quantizer = create_quantizer( self.k_init, self.i_init, self.f_init, self.overflow, self.round_mode, self.use_hgq, self.is_data, place ) - self.is_pretraining = False + self.is_pretraining = True self.hgq_gamma = hgq_gamma if isinstance(granularity, Enum): self.granularity = granularity.value @@ -63,15 +63,37 @@ def compute_dynamic_bits(self, x): return int_bits, frac_bits def build(self, input_shape): - if self.granularity == "per_tensor": + if self.use_hgq: + shape = tuple(input_shape) if not self.is_data else (1,) + tuple(input_shape[1:]) + self.k = self.add_weight(shape=shape, initializer=keras.initializers.Constant(self.k_init), trainable=False) + self.i = self.add_weight(shape=shape, initializer=keras.initializers.Constant(self.i_init), trainable=False) + self.f = self.add_weight(shape=shape, initializer=keras.initializers.Constant(self.f_init), trainable=False) + self.b = self.add_weight( + shape=shape, + initializer=keras.initializers.Constant(self.k_init + self.i_init + self.f_init), + trainable=False, + ) + if not self.quantizer.built: + self.quantizer.build(shape) + self.set_quantization_bits(self.i_init, self.f_init) + elif self.granularity == "per_tensor": self.k = self.add_weight(shape=(), initializer=keras.initializers.Constant(self.k_init), trainable=False) self.i = self.add_weight(shape=(), initializer=keras.initializers.Constant(self.i_init), trainable=False) 
self.f = self.add_weight(shape=(), initializer=keras.initializers.Constant(self.f_init), trainable=False) + self.b = self.add_weight( + shape=(), initializer=keras.initializers.Constant(self.k_init + self.i_init + self.f_init), trainable=False + ) else: i, _ = self.compute_dynamic_bits(keras.ops.ones(input_shape)) self.k = self.add_weight(shape=i.shape, initializer=keras.initializers.Constant(self.k_init), trainable=False) self.i = self.add_weight(shape=i.shape, initializer=keras.initializers.Constant(self.i_init), trainable=False) self.f = self.add_weight(shape=i.shape, initializer=keras.initializers.Constant(self.f_init), trainable=False) + self.b = self.add_weight( + shape=i.shape, + initializer=keras.initializers.Constant(self.k_init + self.i_init + self.f_init), + trainable=False, + ) + super().build(input_shape) def get_total_bits(self, shape): @@ -84,8 +106,7 @@ def get_total_bits(self, shape): def get_quantization_bits(self): if self.use_hgq: return self.quantizer.quantizer.k, self.quantizer.quantizer.i, self.quantizer.quantizer.f - else: - return self.k, self.i, self.f + return self.k, self.i, self.f def set_quantization_bits(self, i, f): if self.use_hgq: @@ -94,8 +115,17 @@ def set_quantization_bits(self, i, f): self.i = i self.f = f - def post_pretrain(self): - self.is_pretraining = True + def apply_final_compression(self): + if self.use_hgq and not self.quantizer.built or not self.built: + return + k, i, f = self.get_quantization_bits() + self.i.assign(i) + self.f.assign(f) + self.b.assign(k + i + f) + self.final_compression_done = True + + def post_pre_train_function(self): + self.is_pretraining = False def call(self, x, training=None): if self.use_hgq: @@ -113,10 +143,7 @@ def call(self, x, training=None): def hgq_loss(self): if self.is_pretraining or not self.use_hgq: return 0.0 - loss = 0 - for layer_loss in self.quantizer.quantizer.losses: - loss += layer_loss - return loss + return sum(self.quantizer.losses) @classmethod def from_config(cls, 
config): diff --git a/src/pquant/core/keras/train.py b/src/pquant/core/keras/train.py index 59de0dc..2aabb3a 100644 --- a/src/pquant/core/keras/train.py +++ b/src/pquant/core/keras/train.py @@ -1,7 +1,10 @@ import keras from pquant.core.keras.layers import ( + apply_final_compression, call_post_round_functions, + get_ebops, + get_layer_keep_ratio, post_epoch_functions, post_pretrain_functions, pre_epoch_functions, @@ -10,6 +13,111 @@ ) +class PQuantCallback(keras.callbacks.Callback): + """ + Keras callback equivalent of train_model(). + + Call model.fit(epochs=callback.total_epochs, callbacks=[callback], ...). + Phase boundaries: + [0, pretraining_epochs) → pretraining + [pretraining_epochs, pretraining_epochs + rounds*epochs) → main rounds + [pretraining_epochs + rounds*epochs, total_epochs) → fine-tuning + """ + + def __init__( + self, + config, + log_ebops=True, + log_keep_ratio=True, + apply_final_compression=True, + pretraining_epochs=None, + epochs=None, + fine_tuning_epochs=None, + ): + super().__init__() + tc = config.training_parameters + self.config = config + self.pretraining_epochs = pretraining_epochs if pretraining_epochs is not None else tc.pretraining_epochs + self.rounds = tc.rounds + self.epochs_per_round = epochs if epochs is not None else tc.epochs + self.fine_tuning_epochs = fine_tuning_epochs if fine_tuning_epochs is not None else tc.fine_tuning_epochs + self.rewind = tc.rewind + self.save_weights_epoch = tc.save_weights_epoch + self.log_ebops = log_ebops + self.log_keep_ratio = log_keep_ratio + self.apply_final_compression = apply_final_compression + + self._main_end = self.pretraining_epochs + self.rounds * self.epochs_per_round + self._stage = "pretrain" if self.pretraining_epochs > 0 else "train" + + @property + def total_epochs(self): + return self._main_end + self.fine_tuning_epochs + + def on_train_begin(self, logs=None): + # post_pretrain_functions is always called; if there are no pretraining + # epochs the transition happens 
immediately. + if self.pretraining_epochs == 0: + post_pretrain_functions(self.model, self.config) + # pre_finetune_functions is also always called; if there are no + # main or fine-tuning epochs the transition also happens now. + if self.epochs_per_round == 0 and self.fine_tuning_epochs == 0: + pre_finetune_functions(self.model) + + def on_epoch_begin(self, epoch, logs=None): + if epoch < self.pretraining_epochs: + pre_epoch_functions(self.model, epoch, self.pretraining_epochs) + elif epoch < self._main_end: + rel = epoch - self.pretraining_epochs + r, e = divmod(rel, self.epochs_per_round) + if r == 0 and e == self.save_weights_epoch: + save_weights_functions(self.model) + pre_epoch_functions(self.model, e, self.epochs_per_round) + else: + e = epoch - self._main_end + if e == 0: + pre_finetune_functions(self.model) + pre_epoch_functions(self.model, e, self.fine_tuning_epochs) + + @property + def stage(self): + """Current training stage: 'pretrain', 'train', or 'finetune'.""" + return self._stage + + def on_epoch_end(self, epoch, logs=None): + if epoch < self.pretraining_epochs: + self._stage = "pretrain" + e = epoch + post_epoch_functions(self.model, e, self.pretraining_epochs) + if e == self.pretraining_epochs - 1: + post_pretrain_functions(self.model, self.config) + if self.epochs_per_round == 0 and self.fine_tuning_epochs == 0: + pre_finetune_functions(self.model) + elif epoch < self._main_end: + self._stage = "train" + rel = epoch - self.pretraining_epochs + r, e = divmod(rel, self.epochs_per_round) + post_epoch_functions(self.model, e, self.epochs_per_round) + if e == self.epochs_per_round - 1: + call_post_round_functions(self.model, self.rewind, self.rounds, r) + if r == self.rounds - 1 and self.fine_tuning_epochs == 0: + pre_finetune_functions(self.model) + else: + self._stage = "finetune" + e = epoch - self._main_end + post_epoch_functions(self.model, e, self.fine_tuning_epochs) + if logs is not None: + logs["stage"] = self._stage + if self.log_ebops: + 
logs["ebops"] = get_ebops(self.model) + if self.log_keep_ratio: + logs["remaining_weights"] = get_layer_keep_ratio(self.model) + + def on_train_end(self, logs=None): # noqa: ARG002 + if self.apply_final_compression: + apply_final_compression(self.model) + + def train_model(model, config, train_func, valid_func, **kwargs): """ Generic training loop, user provides training and validation functions diff --git a/src/pquant/core/torch/layers.py b/src/pquant/core/torch/layers.py index 114a912..87ced7c 100644 --- a/src/pquant/core/torch/layers.py +++ b/src/pquant/core/torch/layers.py @@ -1127,19 +1127,19 @@ def add_layer_specific_quantization_to_model(name, layer, config): if name in config.quantization_parameters.layer_specific: layer_config = config.quantization_parameters.layer_specific[name] if "weight" in layer_config: - weight_k_bits = layer_config["weight"]["keep_negatives"] - weight_int_bits = layer_config["weight"]["integer_bits"] - weight_fractional_bits = layer_config["weight"]["fractional_bits"] - layer.k_weight = torch.tensor(weight_k_bits) - layer.i_weight = torch.tensor(weight_int_bits) - layer.f_weight = torch.tensor(weight_fractional_bits) + if "keep_negatives" in layer_config["weight"]: + layer.k_weight = torch.tensor(layer_config["weight"]["keep_negatives"]) + if "integer_bits" in layer_config["weight"]: + layer.i_weight = torch.tensor(layer_config["weight"]["integer_bits"]) + if "fractional_bits" in layer_config["weight"]: + layer.f_weight = torch.tensor(layer_config["weight"]["fractional_bits"]) if "bias" in layer_config: - bias_k_bits = layer_config["bias"]["keep_negatives"] - bias_int_bits = layer_config["bias"]["integer_bits"] - bias_fractional_bits = layer_config["bias"]["fractional_bits"] - layer.k_bias = torch.tensor(bias_k_bits) - layer.i_bias = torch.tensor(bias_int_bits) - layer.f_bias = torch.tensor(bias_fractional_bits) + if "keep_negatives" in layer_config["bias"]: + layer.k_bias = torch.tensor(layer_config["bias"]["keep_negatives"]) + if 
"integer_bits" in layer_config["bias"]: + layer.i_bias = torch.tensor(layer_config["bias"]["integer_bits"]) + if "fractional_bits" in layer_config["bias"]: + layer.f_bias = torch.tensor(layer_config["bias"]["fractional_bits"]) if "input" in layer_config: if "keep_negatives" in layer_config["input"]: input_keep_negatives = torch.tensor(layer_config["input"]["keep_negatives"]) @@ -1576,7 +1576,7 @@ def get_layer_keep_ratio(model): def is_training_stage(layer): - return False if layer.pruning_layer.is_finetuning or layer.pruning_layer.is_pretraining else True + return False if layer.pruning_layer._is_finetuning or layer.pruning_layer._is_pretraining else True def get_model_losses(model, losses): diff --git a/src/pquant/core/torch/quantizer.py b/src/pquant/core/torch/quantizer.py index 078de57..7fd57f8 100644 --- a/src/pquant/core/torch/quantizer.py +++ b/src/pquant/core/torch/quantizer.py @@ -29,8 +29,11 @@ def __init__( self.is_data = is_data self.i_init = i self.f_init = f + self.i = torch.nn.Parameter(torch.tensor(i), requires_grad=False) + self.f = torch.nn.Parameter(torch.tensor(f), requires_grad=False) + self.b = torch.nn.Parameter(torch.tensor(i + k + f), requires_grad=False) self.quantizer = create_quantizer(self.k, i, f, self.overflow, self.round_mode, self.use_hgq, self.is_data, place) - self.is_pretraining = False + self.is_pretraining = True self.hgq_gamma = hgq_gamma if isinstance(granularity, Enum): self.granularity = granularity.value @@ -90,7 +93,9 @@ def forward(self, x): return x else: if self.granularity == 'per_tensor': + self.initialize_quantization_parameters(self.i_init, self.f_init) _, i, f = self.get_quantization_bits() + return self.quantizer(x, k=self.k, i=i, f=f, training=self.training) else: i, f = self.compute_dynamic_bits(x) self.initialize_quantization_parameters(i, f) diff --git a/src/pquant/data_models/pruning_model.py b/src/pquant/data_models/pruning_model.py index 8b44f49..21e6607 100644 --- a/src/pquant/data_models/pruning_model.py 
+++ b/src/pquant/data_models/pruning_model.py @@ -88,3 +88,4 @@ class MDMMPruningModel(BasePruningModel): use_grad: bool = Field(default=False) l0_mode: Literal["coarse", "smooth"] = Field(default="coarse") scale_mode: Literal["mean", "sum"] = Field(default="mean") + constraint_lr: float = Field(default=1.0e-3) diff --git a/src/pquant/pruning_methods/activation_pruning.py b/src/pquant/pruning_methods/activation_pruning.py index 0940f88..1840129 100644 --- a/src/pquant/pruning_methods/activation_pruning.py +++ b/src/pquant/pruning_methods/activation_pruning.py @@ -12,86 +12,120 @@ def __init__(self, config, layer_type, *args, **kwargs): super().__init__(*args, **kwargs) self.config = config self.act_type = "relu" - self.t = 0 - self.batches_collected = 0 self.layer_type = layer_type - self.activations = None - self.total = 0.0 - self.is_pretraining = True - self.is_finetuning = False + self._is_pretraining = True + self._is_finetuning = False self.threshold = ops.convert_to_tensor(config.pruning_parameters.threshold) self.t_start_collecting_batch = self.config.pruning_parameters.t_start_collecting_batch def build(self, input_shape): self.shape = (input_shape[0], 1) - if self.layer_type == "conv": + if self.layer_type in ("conv", "depthwise_conv"): if len(input_shape) == 3: self.shape = (input_shape[0], 1, 1) else: self.shape = (input_shape[0], 1, 1, 1) + n_channels = input_shape[0] self.mask = self.add_weight(shape=self.shape, initializer="ones", trainable=False) - self.mask_placeholder = ops.ones(self.shape) + self.mask_placeholder = self.add_weight(shape=self.shape, initializer="ones", trainable=False) + self.activations = self.add_weight(shape=(n_channels,), initializer="zeros", trainable=False) + self.batches_collected = self.add_weight(shape=(), initializer="zeros", trainable=False, dtype="int32") + self.t = self.add_weight(shape=(), initializer="zeros", trainable=False, dtype="int32") + self.is_pretraining = self.add_weight( + shape=(), + initializer=lambda 
shape, dtype: ops.cast(ops.ones(shape) if self._is_pretraining else ops.zeros(shape), dtype), + name="is_pretraining", + trainable=False, + dtype="bool", + ) + self.is_finetuning = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_finetuning else ops.zeros(shape), dtype), + name="is_finetuning", + trainable=False, + dtype="bool", + ) + super().build(input_shape) def collect_output(self, output, training): """ - Accumulates values for how often the outputs of the neurons and channels of - linear/convolution layer are over 0. Every t_delta steps, uses these values to update - the mask to prune those channels and neurons that are active less than a given threshold + Accumulates per-channel activity fractions. Every t_delta batches, updates + mask_placeholder. The actual mask used in call() is updated from mask_placeholder + in post_epoch_function (outside the compiled graph, no step-to-step dependency). """ - if not training or self.is_pretraining or self.is_finetuning: - # Don't collect during validation + if not training: return - if self.activations is None: - # Initialize activations dynamically - self.activations = ops.zeros(shape=output.shape[1:], dtype=output.dtype) - if self.t < self.t_start_collecting_batch: - return - self.batches_collected += 1 - self.total += output.shape[0] - gt_zero = ops.cast((output > 0), output.dtype) - gt_zero = ops.sum(gt_zero, axis=0) # Sum over batch, take average during mask update - self.activations += gt_zero - if self.batches_collected % self.config.pruning_parameters.t_delta == 0: - pct_active = self.activations / self.total - self.t = 0 - self.total = 0 - self.batches_collected = 0 - if self.layer_type == "linear": - self.mask_placeholder = ops.expand_dims(ops.cast((pct_active > self.threshold), pct_active.dtype), 1) - else: - pct_active = ops.reshape(pct_active, (pct_active.shape[0], -1)) - pct_active_avg = ops.mean(pct_active, axis=-1) - pct_active_above_threshold = 
ops.cast((pct_active_avg > self.threshold), pct_active_avg.dtype) - if len(output.shape) == 3: - self.mask_placeholder = ops.reshape( - pct_active_above_threshold, list(pct_active_above_threshold.shape) + [1, 1] - ) - else: - self.mask_placeholder = ops.reshape( - pct_active_above_threshold, list(pct_active_above_threshold.shape) + [1, 1, 1] - ) - self.activations *= 0.0 - - def call(self, weight): # Mask is only updated every t_delta step, using collect_output - if self.is_pretraining: - return weight + + t_delta = self.config.pruning_parameters.t_delta + # Collect only when training and if above the starting point of collecting + should_collect = ops.logical_not( + ops.logical_or( + ops.logical_or(self.is_pretraining, self.is_finetuning), + self.t < self.t_start_collecting_batch, + ) + ) + + # Per-channel mean activity fraction + gt_zero = ops.cast(output > 0, output.dtype) + if self.layer_type == "linear": + per_channel = ops.mean(gt_zero, axis=0) else: - return self.mask * weight + # output is channels-first (batch, channels, ...); average over batch + spatial + axes = (0,) + tuple(range(2, len(output.shape))) + per_channel = ops.mean(gt_zero, axis=axes) + + # Snapshot current state + activations_cur = ops.convert_to_tensor(self.activations) + batches_cur = ops.convert_to_tensor(self.batches_collected) + mask_ph_cur = ops.convert_to_tensor(self.mask_placeholder) + + # Accumulate (gated by should_collect) + new_activations = activations_cur + ops.where(should_collect, per_channel, ops.zeros_like(per_channel)) + new_batches = batches_cur + ops.cast(should_collect, "int32") + + # Update mask_placeholder every t_delta batches + should_update = ops.logical_and( + should_collect, + ops.equal(new_batches % t_delta, 0), + ) + + safe_batches = ops.cast(ops.maximum(new_batches, 1), new_activations.dtype) + pct_active = new_activations / safe_batches + new_mask_ph = self._compute_mask(pct_active) - def get_hard_mask(self, weight=None): - return self.mask + 
self.mask_placeholder.assign(ops.where(should_update, new_mask_ph, mask_ph_cur)) + + # Reset accumulators after mask update, else keep accumulated values + self.activations.assign(ops.where(should_update, ops.zeros_like(new_activations), new_activations)) + self.batches_collected.assign(ops.where(should_update, ops.zeros_like(new_batches), new_batches)) + self.t.assign(ops.where(should_update, ops.zeros_like(self.t), self.t)) + + def _compute_mask(self, pct_active): + binary = ops.cast(pct_active > self.threshold, pct_active.dtype) + return ops.reshape(binary, self.shape) + + def call(self, weight): + stored_mask = ops.convert_to_tensor(self.mask) + return ops.where(self.is_pretraining, weight, stored_mask * weight) + + def get_hard_mask(self, weight=None): # noqa: ARG002 + return ops.convert_to_tensor(self.mask) def post_pre_train_function(self): - self.is_pretraining = False + self._is_pretraining = False + if hasattr(self, "is_pretraining"): + self.is_pretraining.assign(False) - def pre_epoch_function(self, epoch, total_epochs): + def pre_epoch_function(self, epoch, total_epochs, **kwargs): # noqa: ARG002 pass def post_round_function(self): pass def pre_finetune_function(self): - self.is_finetuning = True + self._is_finetuning = True + if hasattr(self, "is_finetuning"): + self.is_finetuning.assign(True) def calculate_additional_loss(self): return 0 @@ -99,14 +133,12 @@ def calculate_additional_loss(self): def get_layer_sparsity(self, weight): pass - def post_epoch_function(self, epoch, total_epochs): - if self.is_pretraining is False: - self.t += 1 - self.mask.assign(self.mask_placeholder) - pass + def post_epoch_function(self, epoch, total_epochs, **kwargs): # noqa: ARG002 + if not self._is_pretraining: + self.t.assign_add(1) + self.mask.assign(ops.convert_to_tensor(self.mask_placeholder)) def get_config(self): config = super().get_config() - config.update({"config": self.config.get_dict(), "layer_type": self.layer_type}) return config diff --git 
a/src/pquant/pruning_methods/autosparse.py b/src/pquant/pruning_methods/autosparse.py index 552b2c0..07c1252 100644 --- a/src/pquant/pruning_methods/autosparse.py +++ b/src/pquant/pruning_methods/autosparse.py @@ -63,10 +63,11 @@ def __init__(self, config, layer_type, *args, **kwargs): self.g = ops.sigmoid self.config = config self.layer_type = layer_type + self._alpha_init = float(config.pruning_parameters.alpha) global BACKWARD_SPARSITY BACKWARD_SPARSITY = config.pruning_parameters.backward_sparsity - self.is_pretraining = True - self.is_finetuning = False + self._is_pretraining = True + self._is_finetuning = False def build(self, input_shape): self.threshold_size = get_threshold_size(self.config, input_shape) @@ -76,30 +77,57 @@ def build(self, input_shape): initializer=Constant(self.config.pruning_parameters.threshold_init), trainable=True, ) - self.alpha = ops.convert_to_tensor(self.config.pruning_parameters.alpha, dtype="float32") + self.mask = self.add_weight( + name="mask", + shape=input_shape, + initializer="ones", + trainable=False, + ) + self.alpha = self.add_weight( + name="alpha", + shape=(), + initializer=Constant(self._alpha_init), + trainable=False, + ) + self.is_pretraining = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_pretraining else ops.zeros(shape), dtype), + name="is_pretraining", + trainable=False, + dtype="bool", + ) + self.is_finetuning = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_finetuning else ops.zeros(shape), dtype), + name="is_finetuning", + trainable=False, + dtype="bool", + ) super().build(input_shape) def call(self, weight): - """ - sign(W) * ReLu(X), where X = |W| - sigmoid(threshold), with gradient: - 1 if W > 0 else alpha. Alpha is decayed after each epoch. 
- """ - if self.is_pretraining: - return weight - if self.is_finetuning: - return self.mask * weight - else: - mask = self.get_mask(weight) - self.mask = ops.reshape(mask, weight.shape) - return ops.sign(weight) * ops.reshape(mask, weight.shape) - - def get_hard_mask(self, weight=None): - return self.mask + weight_reshaped = ops.reshape(weight, (weight.shape[0], -1)) + w_t = ops.abs(weight_reshaped) - self.g(self.threshold) + + new_binary_mask = ops.cast(ops.reshape(w_t > 0, weight.shape), weight.dtype) + is_training = ops.logical_not(ops.logical_or(self.is_pretraining, self.is_finetuning)) + self.mask.assign(ops.where(is_training, new_binary_mask, ops.convert_to_tensor(self.mask))) + + sparse_weight = ops.sign(weight) * ops.reshape(autosparse_prune(w_t, self.alpha), weight.shape) + + return ops.where( + self.is_pretraining, + weight, + ops.where(self.is_finetuning, ops.convert_to_tensor(self.mask) * weight, sparse_weight), + ) + + def get_hard_mask(self, weight=None): # noqa: ARG002 + return ops.convert_to_tensor(self.mask) def get_mask(self, weight): weight_reshaped = ops.reshape(weight, (weight.shape[0], -1)) w_t = ops.abs(weight_reshaped) - self.g(self.threshold) - return autosparse_prune(w_t, self.alpha) + return ops.cast(ops.reshape(w_t > 0, weight.shape), weight.dtype) def get_layer_sparsity(self, weight): masked_weight = self.get_mask(weight) @@ -109,26 +137,30 @@ def get_layer_sparsity(self, weight): def pre_epoch_function(self, epoch, total_epochs): pass - def calculate_additional_loss(*args, **kwargs): + def calculate_additional_loss(self): return 0 def pre_finetune_function(self): - self.is_finetuning = True + self._is_finetuning = True + if hasattr(self, "is_finetuning"): + self.is_finetuning.assign(True) def post_round_function(self): pass def post_pre_train_function(self): - self.is_pretraining = False + self._is_pretraining = False + if hasattr(self, "is_pretraining"): + self.is_pretraining.assign(False) def post_epoch_function(self, epoch, 
total_epochs): - self.alpha *= cosine_sigmoid_decay(epoch, total_epochs) - if epoch == self.config.pruning_parameters.alpha_reset_epoch: - self.alpha *= 0.0 + decay = cosine_sigmoid_decay(epoch, total_epochs) + self.alpha.assign(self._alpha_init * decay) + if epoch >= self.config.pruning_parameters.alpha_reset_epoch: + self.alpha.assign(ops.zeros_like(self.alpha)) def get_config(self): config = super().get_config() - config.update( { "config": self.config.get_dict(), diff --git a/src/pquant/pruning_methods/constraint_functions.py b/src/pquant/pruning_methods/constraint_functions.py index 0753996..431cdd2 100644 --- a/src/pquant/pruning_methods/constraint_functions.py +++ b/src/pquant/pruning_methods/constraint_functions.py @@ -50,7 +50,7 @@ def __init__(self, lmbda_init=1.0, scale=1.0, damping=1.0, **kwargs): trainable=False, ) - def call(self, weight): + def call(self, weight, training=None): """Calculates the penalty from a given infeasibility measure.""" raw_infeasibility = self.get_infeasibility(weight) infeasibility = self.pipe_infeasibility(raw_infeasibility) @@ -61,8 +61,9 @@ def call(self, weight): else: lmbda_step = self.lr_ * self.scale * self.prev_infs ascent_lmbda = self.lmbda + lmbda_step - self.lmbda.assign_add(lmbda_step) - self.prev_infs.assign(infeasibility) + if training: + self.lmbda.assign_add(lmbda_step) + self.prev_infs.assign(infeasibility) l_term = ascent_lmbda * infeasibility damp_term = self.damping * ops.square(infeasibility) / 2 diff --git a/src/pquant/pruning_methods/cs.py b/src/pquant/pruning_methods/cs.py index 6d5c69a..97e64ff 100644 --- a/src/pquant/pruning_methods/cs.py +++ b/src/pquant/pruning_methods/cs.py @@ -13,9 +13,9 @@ def __init__(self, config, layer_type, *args, **kwargs): config = PQConfig.load_from_config(config) self.config = config self.final_temp = config.pruning_parameters.final_temp - self.is_finetuning = False + self._is_finetuning = False self.layer_type = layer_type - self.is_pretraining = True + 
self._is_pretraining = True def build(self, input_shape): self.s_init = ops.convert_to_tensor(self.config.pruning_parameters.threshold_init * ops.ones(input_shape)) @@ -23,37 +23,55 @@ def build(self, input_shape): self.scaling = 1.0 / ops.sigmoid(self.s_init) self.beta = self.add_weight(name="beta", shape=(), initializer=Constant(1.0), trainable=False) self.mask = self.add_weight(name="mask", shape=input_shape, initializer=Constant(1.0), trainable=False) + self.is_pretraining = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_pretraining else ops.zeros(shape), dtype), + name="is_pretraining", + trainable=False, + dtype="bool", + ) + self.is_finetuning = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_finetuning else ops.zeros(shape), dtype), + name="is_finetuning", + trainable=False, + dtype="bool", + ) super().build(input_shape) def call(self, weight): - if self.is_pretraining: - return weight - mask = self.get_mask() - self.mask.assign(mask) - return mask * weight + stored_mask = ops.convert_to_tensor(self.mask) + new_mask = self.get_mask() + use_current_mask = ops.logical_or(self.is_pretraining, self.is_finetuning) + updated_mask = ops.where(use_current_mask, stored_mask, new_mask) + self.mask.assign(updated_mask) + return updated_mask * weight def pre_finetune_function(self): - self.is_finetuning = True + self._is_finetuning = True + if hasattr(self, "is_finetuning"): + self.is_finetuning.assign(True) + if hasattr(self, "mask"): + self.mask.assign(self.get_hard_mask()) def get_mask(self): - if self.is_finetuning: - mask = self.get_hard_mask() - return mask - else: - mask = ops.sigmoid(self.beta * self.s) - mask = mask * self.scaling - return mask + return ops.sigmoid(self.beta * self.s) * self.scaling def post_pre_train_function(self): - self.is_pretraining = False + self._is_pretraining = False + if hasattr(self, "is_pretraining"): + 
self.is_pretraining.assign(False) - def pre_epoch_function(self, epoch, total_epochs): + def pre_epoch_function(self, epoch, total_epochs): # noqa: ARG002 pass - def post_epoch_function(self, epoch, total_epochs): - self.beta.assign(self.beta * self.final_temp ** (1 / (total_epochs - 1))) + def post_epoch_function(self, epoch, total_epochs): # noqa: ARG002 + if total_epochs <= 1: + self.beta.assign(self.beta * self.final_temp) + else: + self.beta.assign(self.beta * self.final_temp ** (1 / (total_epochs - 1))) - def get_hard_mask(self, weight=None): + def get_hard_mask(self, weight=None): # noqa: ARG002 if self.config.pruning_parameters.enable_pruning: return ops.cast((self.s > 0), self.s.dtype) return ops.convert_to_tensor(1.0) @@ -73,7 +91,6 @@ def get_layer_sparsity(self, weight): def get_config(self): config = super().get_config() - config.update( { "config": self.config.get_dict(), diff --git a/src/pquant/pruning_methods/dst.py b/src/pquant/pruning_methods/dst.py index e8aac5f..e2d5ca9 100644 --- a/src/pquant/pruning_methods/dst.py +++ b/src/pquant/pruning_methods/dst.py @@ -38,14 +38,28 @@ def __init__(self, config, layer_type, *args, **kwargs): config = PQConfig.load_from_config(config) self.config = config - self.is_pretraining = True self.layer_type = layer_type - self.is_finetuning = False + self._is_pretraining = True + self._is_finetuning = False def build(self, input_shape): self.threshold_size = get_threshold_size(self.config, input_shape) self.threshold = self.add_weight(shape=self.threshold_size, initializer="zeros", trainable=True) self.mask = self.add_weight(shape=input_shape, initializer="ones", trainable=False) + self.is_pretraining = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_pretraining else ops.zeros(shape), dtype), + name="is_pretraining", + trainable=False, + dtype="bool", + ) + self.is_finetuning = self.add_weight( + shape=(), + initializer=lambda shape, dtype: 
ops.cast(ops.ones(shape) if self._is_finetuning else ops.zeros(shape), dtype), + name="is_finetuning", + trainable=False, + dtype="bool", + ) def call(self, weight): """ @@ -54,19 +68,27 @@ def call(self, weight): 0.4 if 0.4 < |W| <= 1 0 if |W| > 1 """ - if self.is_pretraining: - return weight - if self.is_finetuning: - return weight * self.mask - mask = self.get_mask(weight) - ratio = 1.0 - ops.sum(mask) / ops.cast(ops.size(mask), mask.dtype) - flag = ratio >= self.config.pruning_parameters.max_pruning_pct - self.threshold.assign(ops.where(flag, ops.ones(self.threshold.shape), self.threshold)) - mask = self.get_mask(weight) - self.mask.assign(mask) - masked_weight = weight * mask + use_current_mask = ops.logical_or(self.is_pretraining, self.is_finetuning) + + def use_existing(): + return weight * ops.convert_to_tensor(self.mask) + + def compute_new(): + mask = self.get_mask(weight) + ratio = 1.0 - ops.sum(mask) / ops.cast(ops.size(mask), mask.dtype) + flag = ratio >= self.config.pruning_parameters.max_pruning_pct + + def reset_and_recalculate(): + self.threshold.assign(ops.zeros(self.threshold.shape)) + return self.get_mask(weight) + + mask = ops.cond(flag, reset_and_recalculate, lambda: mask) + self.mask.assign(mask) + return weight * mask + + result = ops.cond(use_current_mask, use_existing, compute_new) self.add_loss(self.calculate_additional_loss()) - return masked_weight + return result def get_hard_mask(self, weight=None): return self.mask @@ -86,19 +108,22 @@ def get_layer_sparsity(self, weight): return ops.sum(self.get_mask(weight)) / ops.size(weight) def calculate_additional_loss(self): - if self.is_finetuning: - return 0.0 - loss = self.config.pruning_parameters.alpha * ops.sum(ops.exp(-self.threshold)) - return loss + if self._is_pretraining or self._is_finetuning: + return ops.cast(0.0, self.threshold.dtype) + return self.config.pruning_parameters.alpha * ops.sum(ops.exp(-self.threshold)) def pre_finetune_function(self): - self.is_finetuning = True + 
self._is_finetuning = True + if hasattr(self, "is_finetuning"): + self.is_finetuning.assign(True) def post_epoch_function(self, epoch, total_epochs): pass def post_pre_train_function(self): - self.is_pretraining = False + self._is_pretraining = False + if hasattr(self, "is_pretraining"): + self.is_pretraining.assign(False) def post_round_function(self): pass diff --git a/src/pquant/pruning_methods/mdmm.py b/src/pquant/pruning_methods/mdmm.py index ae88fce..5837dd0 100644 --- a/src/pquant/pruning_methods/mdmm.py +++ b/src/pquant/pruning_methods/mdmm.py @@ -26,10 +26,8 @@ def __init__(self, config, layer_type, *args, **kwargs): self.config = config self.layer_type = layer_type self.constraint_layer = None - self.penalty_loss = None - self.built = False - self.is_finetuning = False - self.is_pretraining = True + self._is_finetuning = False + self._is_pretraining = True def build(self, input_shape): pruning_parameters = self.config.pruning_parameters @@ -62,7 +60,7 @@ def build(self, input_shape): "scale": self.config.pruning_parameters.scale, "damping": self.config.pruning_parameters.damping, "use_grad": self.config.pruning_parameters.use_grad, - "lr": self.config.training_parameters.lr, + "lr": self.config.pruning_parameters.constraint_lr, } constraint_type_cls = CONSTRAINT_REGISTRY.get(constraint_type) @@ -71,49 +69,57 @@ def build(self, input_shape): else: raise ValueError(f"Unknown constraint_type: {constraint_type}") - self.mask = ops.ones(input_shape) + self.mask = self.add_weight(name="mask", shape=input_shape, initializer="ones", trainable=False) + self.is_pretraining = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_pretraining else ops.zeros(shape), dtype), + name="is_pretraining", + trainable=False, + dtype="bool", + ) + self.is_finetuning = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_finetuning else ops.zeros(shape), dtype), + name="is_finetuning", 
+ trainable=False, + dtype="bool", + ) self.constraint_layer.build(input_shape) super().build(input_shape) - self.built = True def call(self, weight): - if not self.built: - self.build(weight.shape) - - if self.is_finetuning: - self.penalty_loss = 0.0 - weight = weight * self.get_hard_mask(weight) - else: - self.penalty_loss = self.constraint_layer(weight) epsilon = self.config.pruning_parameters.epsilon - self.hard_mask = ops.cast(ops.abs(weight) > epsilon, weight.dtype) - return weight + hard_mask = ops.cast(ops.abs(weight) > epsilon, weight.dtype) + not_active = ops.logical_or(self.is_pretraining, self.is_finetuning) + self.mask.assign(ops.where(not_active, ops.convert_to_tensor(self.mask), hard_mask)) + + penalty = ops.sum(self.constraint_layer(weight)) + self.add_loss(ops.where(not_active, ops.zeros_like(penalty), penalty)) + + return ops.where(self.is_finetuning, weight * hard_mask, weight) def get_hard_mask(self, weight=None): if weight is None: - return self.hard_mask + return ops.convert_to_tensor(self.mask) epsilon = self.config.pruning_parameters.epsilon return ops.cast(ops.abs(weight) > epsilon, weight.dtype) def get_layer_sparsity(self, weight): - return ops.sum(self.get_hard_mask(weight)) / ops.size(weight) # Should this be subtracted from 1.0? + return ops.sum(self.get_hard_mask(weight)) / ops.size(weight) def calculate_additional_loss(self): - if self.penalty_loss is None: - raise ValueError("Penalty loss has not been calculated. Call the layer with weights first.") - else: - penalty_loss = ops.sum(self.penalty_loss) - - return penalty_loss + # Loss is added via self.add_loss() in call() for model.fit. + # For custom training loops, accumulate model.losses from the last forward pass instead. 
+ return 0.0 def pre_epoch_function(self, epoch, total_epochs): pass def pre_finetune_function(self): - # Freeze the weights - # Set lmbda(s) to zero - self.is_finetuning = True - if hasattr(self.constraint_layer, 'module'): + self._is_finetuning = True + if hasattr(self, "is_finetuning"): + self.is_finetuning.assign(True) + if hasattr(self.constraint_layer, "module"): self.constraint_layer.module.turn_off() else: self.constraint_layer.turn_off() @@ -122,14 +128,15 @@ def post_epoch_function(self, epoch, total_epochs): pass def post_pre_train_function(self): - self.is_pretraining = False + self._is_pretraining = False + if hasattr(self, "is_pretraining"): + self.is_pretraining.assign(False) def post_round_function(self): pass def get_config(self): config = super().get_config() - config.update( { "config": self.config.get_dict(), diff --git a/src/pquant/pruning_methods/pdp.py b/src/pquant/pruning_methods/pdp.py index 48799f8..b928ee9 100644 --- a/src/pquant/pruning_methods/pdp.py +++ b/src/pquant/pruning_methods/pdp.py @@ -10,165 +10,158 @@ def __init__(self, config, layer_type, *args, **kwargs): from pquant.core.hyperparameter_optimization import PQConfig config = PQConfig.load_from_config(config) - self.init_r = ops.convert_to_tensor(config.pruning_parameters.sparsity) - self.epsilon = ops.convert_to_tensor(config.pruning_parameters.epsilon) - self.r = config.pruning_parameters.sparsity + self._init_r = float(config.pruning_parameters.sparsity) + self._epsilon = float(config.pruning_parameters.epsilon) self.temp = config.pruning_parameters.temperature - self.is_pretraining = True self.config = config - self.is_finetuning = False self.layer_type = layer_type + self._is_pretraining = True + self._is_finetuning = False def build(self, input_shape): - input_shape_concatenated = list(input_shape) + [1] - self.softmax_shape = input_shape_concatenated - self.t = ops.ones(input_shape_concatenated) * 0.5 - if self.config.pruning_parameters.structured_pruning: + 
self.softmax_shape = list(input_shape) + [1] + + structured = self.config.pruning_parameters.structured_pruning + if structured: if self.layer_type == "linear": - shape = (input_shape[0], 1) + mask_shape = (input_shape[0], 1) + elif len(input_shape) == 3: + mask_shape = (input_shape[0], 1, 1) else: - if len(input_shape) == 3: - shape = (input_shape[0], 1, 1) - else: - shape = (input_shape[0], 1, 1, 1) + mask_shape = (input_shape[0], 1, 1, 1) + else: + mask_shape = tuple(input_shape) + + self.mask = self.add_weight(shape=mask_shape, initializer="ones", name="mask", trainable=False) + import math + + self._mask_numel = math.prod(mask_shape) + self.flat_weight_size = float(self._mask_numel) + + # Dynamic state as Keras variables so they survive tf.function tracing + # and can be updated between epochs without retracing. + self.r = self.add_weight( + shape=(), + initializer=keras.initializers.Constant(self._init_r), + name="r", + trainable=False, + ) + self.is_pretraining = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_pretraining else ops.zeros(shape), dtype), + name="is_pretraining", + trainable=False, + dtype="bool", + ) + self.is_finetuning = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_finetuning else ops.zeros(shape), dtype), + name="is_finetuning", + trainable=False, + dtype="bool", + ) + + # Resolve static config branches at build time — no runtime branching needed. 
+ if structured: + self._compute_mask = self._mask_structured_channel if self.layer_type == "conv" else self._mask_structured_linear else: - shape = input_shape - self.mask = self.add_weight(shape=shape, initializer="ones", name="mask", trainable=False) - self.flat_weight_size = ops.cast(ops.size(self.mask), self.mask.dtype) + self._compute_mask = self._mask_unstructured + super().build(input_shape) + # --- Lifecycle (called outside the training graph) --- + def post_pre_train_function(self): - self.is_pretraining = False # Enables pruning + self._is_pretraining = False + if hasattr(self, "is_pretraining"): + self.is_pretraining.assign(False) - def pre_epoch_function(self, epoch, total_epochs): - if not self.is_pretraining: - self.r = ops.minimum(1.0, self.epsilon * (epoch + 1)) * self.init_r + def pre_epoch_function(self, epoch, _): + if hasattr(self, "r"): + self.r.assign(ops.minimum(1.0, self._epsilon * (epoch + 1)) * self._init_r) def post_round_function(self): pass - def get_hard_mask(self, weight=None): - if self.is_finetuning: - return self.mask - if weight is None: - return ops.cast((self.mask >= 0.5), self.mask.dtype) - if self.config.pruning_parameters.structured_pruning: - if self.layer_type == "conv": - mask = self.get_mask_structured_channel(weight) - else: - mask = self.get_mask_structured_linear(weight) - else: - mask = self.get_mask(weight) - self.mask.assign(ops.cast((mask >= 0.5), mask.dtype)) - return self.mask - def pre_finetune_function(self): - self.is_finetuning = True + self._is_finetuning = True + if hasattr(self, "is_finetuning"): + self.is_finetuning.assign(True) if hasattr(self, "mask"): - self.mask.assign(ops.cast((self.mask >= 0.5), self.mask.dtype)) - - def get_mask_structured_linear(self, weight): - """ - Structured pruning. Use the l2 norm of the neurons instead of the absolute weight values - to calculate threshold point t. Prunes whole neurons. 
- """ - if self.is_pretraining: - return self.mask + self.mask.assign(ops.cast(self.mask >= 0.5, self.mask.dtype)) + + def post_epoch_function(self, epoch, total_epochs): + pass + + # --- Mask computation (graph-compatible, no Python conditionals on dynamic state) --- + + def _mask_unstructured(self, weight): + weight_reshaped = ops.reshape(weight, self.softmax_shape) + abs_flat = ops.ravel(ops.abs(weight)) + all_vals, _ = ops.top_k(abs_flat, self._mask_numel) + ind = ops.cast((1 - self.r) * self.flat_weight_size, "int32") - 1 + lim = ops.clip(ind, 0, int(self.flat_weight_size) - 2) + Wh, Wt = all_vals[lim], all_vals[lim + 1] + t = ops.ones_like(weight_reshaped) * (0.5 * (Wh + Wt)) + soft_input = ops.concatenate((t**2, weight_reshaped**2), axis=-1) / self.temp + _, mw = ops.unstack(ops.softmax(soft_input, axis=-1), axis=-1) + return ops.reshape(mw, weight.shape) + + def _mask_structured_linear(self, weight): norm = ops.norm(weight, axis=1, ord=2, keepdims=True) norm_flat = ops.ravel(norm) - """ Do top_k for all neuron norms. 
Returns sorted array, just use the values on both - sides of the threshold (sparsity * size(norm)) to calculate t directly """ - W_all, _ = ops.top_k(norm_flat, ops.size(norm_flat)) - size = ops.cast(ops.size(W_all), self.mask.dtype) - ind = ops.cast((1 - self.r) * size, "int32") - 1 - lim = ops.clip(ind, 0, ops.cast(size - 2, "int32")) - Wh = W_all[lim] - Wt = W_all[lim + 1] - # norm = ops.expand_dims(norm, -1) + W_all, _ = ops.top_k(norm_flat, self._mask_numel) + ind = ops.cast((1 - self.r) * self.flat_weight_size, "int32") - 1 + lim = ops.clip(ind, 0, self._mask_numel - 2) + Wh, Wt = W_all[lim], W_all[lim + 1] t = ops.ones(norm.shape) * 0.5 * (Wh + Wt) soft_input = ops.concatenate((t**2, norm**2), axis=1) / self.temp - softmax_result = ops.softmax(soft_input, axis=1) - _, mw = ops.unstack(softmax_result, axis=1) - mw = ops.expand_dims(mw, -1) - self.mask.assign(mw) - return mw + _, mw = ops.unstack(ops.softmax(soft_input, axis=1), axis=1) + return ops.expand_dims(mw, -1) - def get_mask_structured_channel(self, weight): - """ - Structured pruning. Use the l2 norm of the channels instead of the absolute weight values - to calculate threshold point t. Prunes whole channels. - """ - if self.is_pretraining: - return self.mask + def _mask_structured_channel(self, weight): weight_reshaped = ops.reshape(weight, (weight.shape[0], -1)) norm = ops.norm(weight_reshaped, axis=1, ord=2) - norm_flat = ops.ravel(norm) - """ Do top_k for all channel norms. 
Returns sorted array, just use the values on both - sides of the threshold (sparsity * size(norm)) to calculate t directly """ - W_all, _ = ops.top_k(norm_flat, ops.size(norm_flat)) - size = ops.cast(ops.size(W_all), self.mask.dtype) - ind = ops.cast((1 - self.r) * size, "int32") - 1 - lim = ops.clip(ind, 0, ops.cast(size - 2, "int32")) - - Wh = W_all[lim] - Wt = W_all[lim + 1] + W_all, _ = ops.top_k(norm_flat, self._mask_numel) + ind = ops.cast((1 - self.r) * self.flat_weight_size, "int32") - 1 + lim = ops.clip(ind, 0, self._mask_numel - 2) + Wh, Wt = W_all[lim], W_all[lim + 1] norm = ops.expand_dims(norm, -1) t = ops.ones(norm.shape) * 0.5 * (Wh + Wt) soft_input = ops.concatenate((t**2, norm**2), axis=-1) / self.temp - softmax_result = ops.softmax(soft_input, axis=-1) - zw, mw = ops.unstack(softmax_result, axis=-1) - diff = len(weight.shape) - len(mw.shape) - for _ in range(diff): + zw, mw = ops.unstack(ops.softmax(soft_input, axis=-1), axis=-1) + for _ in range(len(weight.shape) - len(mw.shape)): mw = ops.expand_dims(mw, -1) - self.mask.assign(mw) return mw - def get_mask(self, weight): - if self.is_pretraining: - return self.mask - weight_reshaped = ops.reshape(weight, self.softmax_shape) - abs_weight_flat = ops.ravel(ops.abs(weight)) - """ Do top_k for all weights. 
Returns sorted array, just use the values on both - sides of the threshold (sparsity * size(weight)) to calculate t directly """ - all, _ = ops.top_k(abs_weight_flat, ops.size(abs_weight_flat)) - ind = ops.cast((1 - self.r) * self.flat_weight_size, "int32") - 1 # Index begins from 0 - lim = ops.clip(ind, 0, ops.cast(self.flat_weight_size - 2, "int32")) - Wh = all[lim] - Wt = all[lim + 1] - t = self.t * (Wh + Wt) - soft_input = ops.concatenate((t**2, weight_reshaped**2), axis=-1) / self.temp - softmax_result = ops.softmax(soft_input, axis=-1) - _, mw = ops.unstack(softmax_result, axis=-1) - mask = ops.reshape(mw, weight.shape) - self.mask.assign(mask) - return mask - def call(self, weight): - if self.is_finetuning: - mask = self.mask - else: - if self.config.pruning_parameters.structured_pruning: - if self.layer_type == "conv": - mask = self.get_mask_structured_channel(weight) - else: - mask = self.get_mask_structured_linear(weight) - else: - mask = self.get_mask(weight) + new_mask = self._compute_mask(weight) + use_current_mask = ops.logical_or(self.is_pretraining, self.is_finetuning) + mask = ops.where(use_current_mask, ops.convert_to_tensor(self.mask), new_mask) return mask * weight + def update_mask(self, weight): + """Update stored mask from current weights. 
Called once per epoch from post_epoch_functions.""" + if not self._is_pretraining and not self._is_finetuning: + self.mask.assign(self._compute_mask(weight)) + + # --- Utilities --- + + def get_hard_mask(self, weight=None): + # if weight is not None and not bool(self.is_finetuning): + # self.mask.assign(self._compute_mask(weight)) + return ops.cast(self.mask >= 0.5, self.mask.dtype) + def calculate_additional_loss(self): return 0 def get_layer_sparsity(self, weight): - masked_weight_rounded = ops.cast((self.mask >= 0.5), self.mask.dtype) - masked_weight = masked_weight_rounded * weight + hard_mask = ops.cast(self.mask >= 0.5, self.mask.dtype) + masked_weight = hard_mask * weight return ops.count_nonzero(masked_weight) / ops.size(masked_weight) - def post_epoch_function(self, epoch, total_epochs): - pass - def get_config(self): config = super().get_config() - config.update({"config": self.config.get_dict(), "layer_type": self.layer_type, "mask": self.mask}) + config.update({"config": self.config.get_dict(), "layer_type": self.layer_type}) return config diff --git a/src/pquant/pruning_methods/wanda.py b/src/pquant/pruning_methods/wanda.py index 55b202e..7057726 100644 --- a/src/pquant/pruning_methods/wanda.py +++ b/src/pquant/pruning_methods/wanda.py @@ -12,120 +12,174 @@ def __init__(self, config, layer_type, *args, **kwargs): config = PQConfig.load_from_config(config) self.config = config self.act_type = "relu" - self.t = 0 self.layer_type = layer_type - self.batches_collected = 0 - self.inputs = None - self.total = 0.0 - self.done = False + self._is_pretraining = True + self._is_finetuning = False self.sparsity = self.config.pruning_parameters.sparsity - self.is_pretraining = True - self.is_finetuning = False self.N = self.config.pruning_parameters.N self.M = self.config.pruning_parameters.M self.t_start_collecting_batch = self.config.pruning_parameters.t_start_collecting_batch def build(self, input_shape): - self.mask = ops.ones(input_shape) + # input_shape is 
the (transposed) weight shape: (out, in) or (out, in, kH, kW) + # For depthwise_conv, weight shape is (in_ch, depth_mult, kH, kW) so n_in = input_shape[0] + n_in = input_shape[0] if self.layer_type == "depthwise_conv" else input_shape[1] + self.mask = self.add_weight(shape=input_shape, initializer="ones", trainable=False) + # Accumulate per-input-channel sum of squared inputs; shape (n_in,) known at build time. + # Replaces storing full (batch, n_in, ...) inputs whose spatial/batch dims are unknown. + self.inputs_sq_sum = self.add_weight(shape=(n_in,), initializer="zeros", trainable=False) + self.batches_collected = self.add_weight(shape=(), initializer="zeros", trainable=False, dtype="int32") + self.t = self.add_weight(shape=(), initializer="zeros", trainable=False, dtype="int32") + self.done = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.zeros(shape), dtype), + trainable=False, + dtype="bool", + ) + self.is_pretraining = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_pretraining else ops.zeros(shape), dtype), + name="is_pretraining", + trainable=False, + dtype="bool", + ) + self.is_finetuning = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_finetuning else ops.zeros(shape), dtype), + name="is_finetuning", + trainable=False, + dtype="bool", + ) super().build(input_shape) - def get_mask(self, weight, metric, sparsity): - d0, d1 = metric.shape - keep_idxs = ops.argsort(metric, axis=1)[:, int(d1 * sparsity) :] + ops.arange(d0)[:, None] * d1 - keep_idxs = ops.ravel(keep_idxs) - kept_values = ops.reshape( - ops.scatter(keep_idxs[:, None], ops.take(ops.ravel(weight), keep_idxs), ops.array((ops.size(weight),))), - weight.shape, + def collect_input(self, x, weight, training): + """ + Accumulates per-input-channel sum-of-squares of layer inputs. 
After t_delta batches, + computes mask = Wanda metric (|W| * L2_norm) once and sets done=True. One-shot pruning. + """ + if not training: + return + + t_delta = self.config.pruning_parameters.t_delta + + # Only collect in training stage if pruning hasn't been done already, and if above starting epoch + should_collect = ops.logical_not( + ops.logical_or( + ops.logical_or(self.is_pretraining, self.is_finetuning), + ops.logical_or(self.done, self.t < self.t_start_collecting_batch), + ) ) - mask = ops.cast(kept_values != 0, weight.dtype) - return mask - def handle_linear(self, x, weight): - norm = ops.norm(x, ord=2, axis=0) + # Per-batch per-channel sum of squared activations (shape: n_in,) + if self.layer_type == "linear": + per_batch_sq = ops.sum(ops.square(x), axis=0) # (n_in,) + else: + # x is channels-first (batch, in_channels, ...); sum over batch + spatial + axes = (0,) + tuple(range(2, len(x.shape))) + per_batch_sq = ops.sum(ops.square(x), axis=axes) # (in_channels,) + + # Snapshot current state + sq_sum_cur = ops.convert_to_tensor(self.inputs_sq_sum) + batches_cur = ops.convert_to_tensor(self.batches_collected) + + new_sq_sum = sq_sum_cur + ops.where(should_collect, per_batch_sq, ops.zeros_like(per_batch_sq)) + new_batches = batches_cur + ops.cast(should_collect, "int32") # Adding 0 if not collecting + + # Prune once when t_delta batches have been collected + should_prune = ops.equal(new_batches, ops.cast(t_delta, "int32")) + + mask_cur = ops.convert_to_tensor(self.mask) + + def do_prune(): + norm = ops.sqrt(new_sq_sum) + return self._compute_prune_mask(norm, weight) + + new_mask = ops.cond(should_prune, do_prune, lambda: mask_cur) + self.mask.assign(new_mask) + self.done.assign(ops.logical_or(self.done, should_prune)) + + # Reset accumulators after pruning + self.inputs_sq_sum.assign(ops.where(should_prune, ops.zeros_like(new_sq_sum), new_sq_sum)) + self.batches_collected.assign(ops.where(should_prune, ops.zeros_like(new_batches), new_batches)) + + def 
_compute_prune_mask(self, norm, weight): + if self.layer_type == "linear": + return self._handle_linear(norm, weight) + if self.layer_type == "depthwise_conv": + return self._handle_depthwise_conv(norm, weight) + return self._handle_conv(norm, weight) + + def _handle_linear(self, norm, weight): + # norm.shape = (in_features,); weight.shape = (out_features, in_features) metric = ops.abs(weight) * norm if self.N is not None and self.M is not None: - # N:M pruning metric_reshaped = ops.reshape(metric, (-1, self.M)) weight_reshaped = ops.reshape(weight, (-1, self.M)) mask = self.get_mask(weight_reshaped, metric_reshaped, sparsity=self.N / self.M) - self.mask = ops.reshape(mask, weight.shape) - else: - # Unstructured pruning - metric_reshaped = ops.reshape(metric, (1, -1)) - weight_reshaped = ops.reshape(weight, (1, -1)) - mask = self.get_mask(weight_reshaped, metric_reshaped, sparsity=self.sparsity) - self.mask = ops.reshape(mask, weight.shape) - - def handle_conv(self, x, weight): - inputs_avg = ops.mean(ops.reshape(x, (x.shape[0], x.shape[1], -1)), axis=0) - norm = ops.norm(inputs_avg, ord=2, axis=-1) + return ops.reshape(mask, weight.shape) + metric_reshaped = ops.reshape(metric, (1, -1)) + weight_reshaped = ops.reshape(weight, (1, -1)) + mask = self.get_mask(weight_reshaped, metric_reshaped, sparsity=self.sparsity) + return ops.reshape(mask, weight.shape) + + def _handle_conv(self, norm, weight): + # norm.shape = (in_channels,); weight.shape = (out_channels, in_channels, ...) 
if len(weight.shape) == 3: - norm = ops.reshape(norm, [1] + list(norm.shape) + [1]) + norm_reshaped = ops.reshape(norm, [1] + list(norm.shape) + [1]) else: - norm = ops.reshape(norm, [1] + list(norm.shape) + [1, 1]) - metric = ops.abs(weight) * norm + norm_reshaped = ops.reshape(norm, [1] + list(norm.shape) + [1, 1]) + metric = ops.abs(weight) * norm_reshaped if self.N is not None and self.M is not None: - # N:M pruning metric_reshaped = ops.reshape(metric, (-1, self.M)) weight_reshaped = ops.reshape(weight, (-1, self.M)) mask = self.get_mask(weight_reshaped, metric_reshaped, sparsity=self.N / self.M) - self.mask = ops.reshape(mask, weight.shape) - else: - # Unstructured pruning - metric_reshaped = ops.reshape(metric, (metric.shape[0], -1)) - weight_reshaped = ops.reshape(weight, (weight.shape[0], -1)) - mask = self.get_mask(weight_reshaped, metric_reshaped, sparsity=self.sparsity) - self.mask = ops.reshape(mask, weight.shape) + return ops.reshape(mask, weight.shape) + metric_reshaped = ops.reshape(metric, (metric.shape[0], -1)) + weight_reshaped = ops.reshape(weight, (weight.shape[0], -1)) + mask = self.get_mask(weight_reshaped, metric_reshaped, sparsity=self.sparsity) + return ops.reshape(mask, weight.shape) + + def _handle_depthwise_conv(self, norm, weight): + # norm.shape = (in_channels,); weight.shape = (in_channels, depth_mult, kH, kW) + # Prune per-input-channel: norm[ic] scales all weights for that channel + norm_reshaped = ops.reshape(norm, list(norm.shape) + [1, 1, 1]) + metric = ops.abs(weight) * norm_reshaped + metric_reshaped = ops.reshape(metric, (metric.shape[0], -1)) + weight_reshaped = ops.reshape(weight, (weight.shape[0], -1)) + mask = self.get_mask(weight_reshaped, metric_reshaped, sparsity=self.sparsity) + return ops.reshape(mask, weight.shape) - def collect_input(self, x, weight, training): - if self.done or not training: - return - """ - Accumulates layer inputs starting at step t_start_collecting for t_delta steps, then averages it. 
- Calculates a metric based on weight absolute values and norm of inputs. - For linear layers, calculate norm over batch dimension. - For conv layers, take average over batch dimension and calculate norm over flattened kernel_size dimension. - If N and M are defined, do N:M pruning. - """ - ok_batch = True - if self.inputs is not None: - batch_size = self.inputs.shape[0] - ok_batch = x.shape[0] == batch_size - if not training or not ok_batch: - # Don't collect during validation - return - if self.t < self.t_start_collecting_batch: - return - self.batches_collected += 1 - self.total += 1 - - self.inputs = x if self.inputs is None else self.inputs + x - if self.batches_collected % (self.config.pruning_parameters.t_delta) == 0: - inputs_avg = self.inputs / self.total - self.prune(inputs_avg, weight) - self.done = True - self.inputs = None + def get_mask(self, weight, metric, sparsity): + d0, d1 = metric.shape + keep_idxs = ops.argsort(metric, axis=1)[:, int(d1 * sparsity) :] + ops.arange(d0)[:, None] * d1 + keep_idxs = ops.ravel(keep_idxs) + kept_values = ops.reshape( + ops.scatter(keep_idxs[:, None], ops.take(ops.ravel(weight), keep_idxs), ops.array((ops.size(weight),))), + weight.shape, + ) + return ops.cast(kept_values != 0, weight.dtype) - def prune(self, x, weight): - if self.layer_type == "linear": - self.handle_linear(x, weight) - else: - self.handle_conv(x, weight) + def call(self, weight): + return ops.convert_to_tensor(self.mask) * weight - def call(self, weight): # Mask is only updated every t_delta step, using collect_output - return self.mask * weight + def get_hard_mask(self, weight=None): # noqa: ARG002 + return ops.convert_to_tensor(self.mask) def post_pre_train_function(self): - self.is_pretraining = False + self._is_pretraining = False + if hasattr(self, "is_pretraining"): + self.is_pretraining.assign(False) - def pre_epoch_function(self, epoch, total_epochs): + def pre_epoch_function(self, epoch, total_epochs, **kwargs): # noqa: ARG002 pass def 
post_round_function(self): pass def pre_finetune_function(self): - self.is_finetuning = True + self._is_finetuning = True + if hasattr(self, "is_finetuning"): + self.is_finetuning.assign(True) def calculate_additional_loss(self): return 0 @@ -133,17 +187,12 @@ def calculate_additional_loss(self): def get_layer_sparsity(self, weight): pass - def get_hard_mask(self, weight=None): - return self.mask - - def post_epoch_function(self, epoch, total_epochs): - if self.is_pretraining is False: - self.t += 1 - pass + def post_epoch_function(self, epoch, total_epochs, **kwargs): # noqa: ARG002 + if not self._is_pretraining: + self.t.assign_add(1) def get_config(self): config = super().get_config() - config.update( { "config": self.config.get_dict(), diff --git a/tests/test_ap.py b/tests/test_ap.py index 0ee480c..d126898 100644 --- a/tests/test_ap.py +++ b/tests/test_ap.py @@ -41,6 +41,7 @@ def test_linear(config): mask = ops.expand_dims(mask, -1) for _ in range(config["pruning_parameters"]["t_delta"]): ap.collect_output(layer_output, training=True) + ap.post_epoch_function(0, 1) result = ap(weight) result_masked = mask * result # Multiplying by mask should not change the result at all @@ -70,6 +71,7 @@ def test_conv(config): # mask = ops.expand_dims(ops.expand_dims(ops.expand_dims(mask, -1), -1), -1) for _ in range(config["pruning_parameters"]["t_delta"]): ap.collect_output(layer_output, training=True) + ap.post_epoch_function(0, 1) result = ap(weight) result_masked = mask * result # Multiplying by mask should not change the result at all diff --git a/tests/test_keras_compression_layers.py b/tests/test_keras_compression_layers.py index 8a2fabd..2cee1b0 100644 --- a/tests/test_keras_compression_layers.py +++ b/tests/test_keras_compression_layers.py @@ -1,4 +1,3 @@ -from types import SimpleNamespace from unittest.mock import patch import keras @@ -17,7 +16,17 @@ SeparableConv2D, ) +from pquant import ( + ap_config, + autosparse_config, + cs_config, + dst_config, + mdmm_config, 
+ pdp_config, + wanda_config, +) from pquant.activations import PQActivation +from pquant.core.hyperparameter_optimization import PQConfig from pquant.layers import ( PQAvgPool1d, PQAvgPool2d, @@ -34,15 +43,6 @@ pre_finetune_functions, ) - -def _to_obj(x): - if isinstance(x, dict): - return SimpleNamespace(**{k: _to_obj(v) for k, v in x.items()}) - if isinstance(x, list): - return [_to_obj(v) for v in x] - return x - - BATCH_SIZE = 4 OUT_FEATURES = 32 IN_FEATURES = 16 @@ -57,165 +57,173 @@ def run_around_tests(): @pytest.fixture def config_pdp(): - cfg = { - "pruning_parameters": { - "disable_pruning_for_layers": [], - "enable_pruning": True, - "epsilon": 1.0, - "pruning_method": "pdp", - "sparsity": 0.75, - "temperature": 1e-5, - "threshold_decay": 0.0, - "structured_pruning": False, - }, - "quantization_parameters": { - "default_weight_integer_bits": 0.0, - "default_weight_fractional_bits": 7.0, - "default_data_integer_bits": 0.0, - "default_data_fractional_bits": 7.0, - "default_data_keep_negatives": 0.0, - "default_weight_keep_negatives": 1.0, - "quantize_input": True, - "quantize_output": False, - "enable_quantization": False, - "hgq_gamma": 0.0003, - "hgq_beta": 1e-5, - "hgq_heterogeneous": True, - "layer_specific": {}, - "use_high_granularity_quantization": False, - "use_real_tanh": False, - "use_relu_multiplier": True, - "use_symmetric_quantization": False, - "round_mode": "RND", - "overflow_mode_parameters": "SAT", - "overflow_mode_data": "SAT", - }, - "training_parameters": {"pruning_first": False}, - "fitcompress_parameters": {"enable_fitcompress": False}, - } - return _to_obj(cfg) + return PQConfig.load_from_config( + { + "pruning_parameters": { + "disable_pruning_for_layers": [], + "enable_pruning": True, + "epsilon": 1.0, + "pruning_method": "pdp", + "sparsity": 0.75, + "temperature": 1e-5, + "threshold_decay": 0.0, + "structured_pruning": False, + }, + "quantization_parameters": { + "default_weight_integer_bits": 0.0, + 
"default_weight_fractional_bits": 7.0, + "default_data_integer_bits": 0.0, + "default_data_fractional_bits": 7.0, + "default_data_keep_negatives": 0.0, + "default_weight_keep_negatives": 1.0, + "quantize_input": True, + "quantize_output": False, + "enable_quantization": False, + "hgq_gamma": 0.0003, + "hgq_beta": 1e-5, + "hgq_heterogeneous": True, + "layer_specific": {}, + "use_high_granularity_quantization": False, + "use_real_tanh": False, + "use_relu_multiplier": True, + "use_symmetric_quantization": False, + "round_mode": "RND", + "overflow_mode_parameters": "SAT", + "overflow_mode_data": "SAT", + "granularity": "per_tensor", + }, + "training_parameters": {"pruning_first": False}, + "fitcompress_parameters": {"enable_fitcompress": False}, + } + ) @pytest.fixture def config_ap(): - cfg = { - "pruning_parameters": { - "disable_pruning_for_layers": [], - "enable_pruning": True, - "pruning_method": "activation_pruning", - "threshold": 0.3, - "t_start_collecting_batch": 0, - "threshold_decay": 0.0, - "t_delta": 1, - }, - "quantization_parameters": { - "default_weight_integer_bits": 0.0, - "default_weight_fractional_bits": 7.0, - "default_data_integer_bits": 0.0, - "default_data_fractional_bits": 7.0, - "default_data_keep_negatives": 0.0, - "default_weight_keep_negatives": 1.0, - "quantize_input": True, - "quantize_output": False, - "enable_quantization": False, - "hgq_gamma": 0.0003, - "hgq_beta": 1e-5, - "hgq_heterogeneous": True, - "layer_specific": {}, - "use_high_granularity_quantization": False, - "use_real_tanh": False, - "use_relu_multiplier": True, - "use_symmetric_quantization": False, - "round_mode": "RND", - "overflow_mode_parameters": "SAT", - "overflow_mode_data": "SAT", - }, - "training_parameters": {"pruning_first": False}, - "fitcompress_parameters": {"enable_fitcompress": False}, - } - return _to_obj(cfg) + return PQConfig.load_from_config( + { + "pruning_parameters": { + "disable_pruning_for_layers": [], + "enable_pruning": True, + 
"pruning_method": "activation_pruning", + "threshold": 0.3, + "t_start_collecting_batch": 0, + "threshold_decay": 0.0, + "t_delta": 1, + }, + "quantization_parameters": { + "default_weight_integer_bits": 0.0, + "default_weight_fractional_bits": 7.0, + "default_data_integer_bits": 0.0, + "default_data_fractional_bits": 7.0, + "default_data_keep_negatives": 0.0, + "default_weight_keep_negatives": 1.0, + "quantize_input": True, + "quantize_output": False, + "enable_quantization": False, + "hgq_gamma": 0.0003, + "hgq_beta": 1e-5, + "hgq_heterogeneous": True, + "layer_specific": {}, + "use_high_granularity_quantization": False, + "use_real_tanh": False, + "use_relu_multiplier": True, + "use_symmetric_quantization": False, + "round_mode": "RND", + "overflow_mode_parameters": "SAT", + "overflow_mode_data": "SAT", + }, + "training_parameters": {"pruning_first": False}, + "fitcompress_parameters": {"enable_fitcompress": False}, + } + ) @pytest.fixture def config_wanda(): - cfg = { - "pruning_parameters": { - "calculate_pruning_budget": False, - "disable_pruning_for_layers": [], - "enable_pruning": True, - "pruning_method": "wanda", - "sparsity": 0.75, - "t_start_collecting_batch": 0, - "threshold_decay": 0.0, - "t_delta": 1, - "N": None, - "M": None, - }, - "quantization_parameters": { - "default_weight_integer_bits": 0.0, - "default_weight_fractional_bits": 7.0, - "default_data_integer_bits": 0.0, - "default_data_fractional_bits": 7.0, - "default_data_keep_negatives": 0.0, - "default_weight_keep_negatives": 1.0, - "quantize_input": True, - "quantize_output": False, - "enable_quantization": False, - "hgq_gamma": 0.0003, - "hgq_beta": 1e-5, - "hgq_heterogeneous": True, - "layer_specific": {}, - "use_high_granularity_quantization": False, - "use_real_tanh": False, - "use_relu_multiplier": True, - "use_symmetric_quantization": False, - "round_mode": "RND", - "overflow_mode_parameters": "SAT", - "overflow_mode_data": "SAT", - }, - "training_parameters": {"pruning_first": 
False}, - "fitcompress_parameters": {"enable_fitcompress": False}, - } - return _to_obj(cfg) + return PQConfig.load_from_config( + { + "pruning_parameters": { + "calculate_pruning_budget": False, + "disable_pruning_for_layers": [], + "enable_pruning": True, + "pruning_method": "wanda", + "sparsity": 0.75, + "t_start_collecting_batch": 0, + "threshold_decay": 0.0, + "t_delta": 1, + "N": None, + "M": None, + }, + "quantization_parameters": { + "default_weight_integer_bits": 0.0, + "default_weight_fractional_bits": 7.0, + "default_data_integer_bits": 0.0, + "default_data_fractional_bits": 7.0, + "default_data_keep_negatives": 0.0, + "default_weight_keep_negatives": 1.0, + "quantize_input": True, + "quantize_output": False, + "enable_quantization": False, + "hgq_gamma": 0.0003, + "hgq_beta": 1e-5, + "hgq_heterogeneous": True, + "layer_specific": {}, + "use_high_granularity_quantization": False, + "use_real_tanh": False, + "use_relu_multiplier": True, + "use_symmetric_quantization": False, + "round_mode": "RND", + "overflow_mode_parameters": "SAT", + "overflow_mode_data": "SAT", + }, + "training_parameters": {"pruning_first": False}, + "fitcompress_parameters": {"enable_fitcompress": False}, + } + ) @pytest.fixture def config_cs(): - cfg = { - "pruning_parameters": { - "disable_pruning_for_layers": [], - "enable_pruning": True, - "final_temp": 200, - "pruning_method": "cs", - "threshold_decay": 0.0, - "threshold_init": 0.1, - }, - "quantization_parameters": { - "default_weight_integer_bits": 0.0, - "default_weight_fractional_bits": 7.0, - "default_data_integer_bits": 0.0, - "default_data_fractional_bits": 7.0, - "default_data_keep_negatives": 0.0, - "default_weight_keep_negatives": 1.0, - "quantize_input": True, - "quantize_output": False, - "enable_quantization": False, - "hgq_gamma": 0.0003, - "hgq_beta": 1e-5, - "hgq_heterogeneous": True, - "layer_specific": {}, - "use_high_granularity_quantization": False, - "use_real_tanh": False, - "use_relu_multiplier": True, - 
"use_symmetric_quantization": False, - "round_mode": "RND", - "overflow_mode_parameters": "SAT", - "overflow_mode_data": "SAT", - }, - "training_parameters": {"pruning_first": False}, - "fitcompress_parameters": {"enable_fitcompress": False}, - } - return _to_obj(cfg) + return PQConfig.load_from_config( + { + "pruning_parameters": { + "disable_pruning_for_layers": [], + "enable_pruning": True, + "final_temp": 200, + "pruning_method": "cs", + "threshold_decay": 0.0, + "threshold_init": 0.1, + }, + "quantization_parameters": { + "default_weight_integer_bits": 0.0, + "default_weight_fractional_bits": 7.0, + "default_data_integer_bits": 0.0, + "default_data_fractional_bits": 7.0, + "default_data_keep_negatives": 0.0, + "default_weight_keep_negatives": 1.0, + "quantize_input": True, + "quantize_output": False, + "enable_quantization": False, + "hgq_gamma": 0.0003, + "hgq_beta": 1e-5, + "hgq_heterogeneous": True, + "layer_specific": {}, + "use_high_granularity_quantization": False, + "use_real_tanh": False, + "use_relu_multiplier": True, + "use_symmetric_quantization": False, + "round_mode": "RND", + "overflow_mode_parameters": "SAT", + "overflow_mode_data": "SAT", + }, + "training_parameters": {"pruning_first": False}, + "fitcompress_parameters": {"enable_fitcompress": False}, + } + ) + + +np.random.seed(42) @pytest.fixture(scope="function", autouse=True) @@ -376,19 +384,19 @@ def test_separable_conv2d_trigger_post_pretraining(config_pdp, conv2d_input): model = keras.Model(inputs=inputs, outputs=act2, name="test_conv2d") model = add_compression_layers(model, config_pdp, conv2d_input.shape) - assert model.layers[1].depthwise_conv.pruning_layer.is_pretraining is True - assert model.layers[1].pointwise_conv.pruning_layer.is_pretraining is True - assert model.layers[2].is_pretraining is True - assert model.layers[4].pruning_layer.is_pretraining is True - assert model.layers[5].is_pretraining is True + assert model.layers[1].depthwise_conv.pruning_layer.is_pretraining == 
True # noqa: E712 + assert model.layers[1].pointwise_conv.pruning_layer.is_pretraining == True # noqa: E712 + assert model.layers[2].is_pretraining == True # noqa: E712 + assert model.layers[4].pruning_layer.is_pretraining == True # noqa: E712 + assert model.layers[5].is_pretraining == True # noqa: E712 post_pretrain_functions(model, config_pdp) - assert model.layers[1].depthwise_conv.pruning_layer.is_pretraining is False - assert model.layers[1].pointwise_conv.pruning_layer.is_pretraining is False - assert model.layers[2].is_pretraining is False - assert model.layers[4].pruning_layer.is_pretraining is False - assert model.layers[5].is_pretraining is False + assert model.layers[1].depthwise_conv.pruning_layer.is_pretraining == False # noqa: E712 + assert model.layers[1].pointwise_conv.pruning_layer.is_pretraining == False # noqa: E712 + assert model.layers[2].is_pretraining == False # noqa: E712 + assert model.layers[4].pruning_layer.is_pretraining == False # noqa: E712 + assert model.layers[5].is_pretraining == False # noqa: E712 def test_conv1d_call(config_pdp, conv1d_input): @@ -1363,17 +1371,17 @@ def test_trigger_post_pretraining(config_pdp, conv2d_input): model = add_compression_layers(model, config_pdp, conv2d_input.shape) - assert model.layers[1].pruning_layer.is_pretraining is True - assert model.layers[2].is_pretraining is True - assert model.layers[3].pruning_layer.is_pretraining is True - assert model.layers[4].is_pretraining is True + assert model.layers[1].pruning_layer.is_pretraining == True # noqa: E712 + assert model.layers[2].is_pretraining == True # noqa: E712 + assert model.layers[3].pruning_layer.is_pretraining == True # noqa: E712 + assert model.layers[4].is_pretraining == True # noqa: E712 post_pretrain_functions(model, config_pdp) - assert model.layers[1].pruning_layer.is_pretraining is False - assert model.layers[2].is_pretraining is False - assert model.layers[3].pruning_layer.is_pretraining is False - assert model.layers[4].is_pretraining 
is False + assert model.layers[1].pruning_layer.is_pretraining == False # noqa: E712 + assert model.layers[2].is_pretraining == False # noqa: E712 + assert model.layers[3].pruning_layer.is_pretraining == False # noqa: E712 + assert model.layers[4].is_pretraining == False # noqa: E712 def test_hgq_weight_shape(config_pdp, dense_input): @@ -1709,6 +1717,15 @@ def call(self, x, *args, **kwargs): self.layer_called += 1 return x + def post_pre_train_function(self): + pass + + def get_total_bits(self, shape): + return keras.ops.ones(shape) + + def hgq_loss(self): + return 0.0 + def extra_repr(self): return f"Layer called = {self.layer_called} times." @@ -1901,3 +1918,219 @@ def test_layer_replacement_quant_called(config_pdp, conv2d_input): assert model.layers[-1].output_quantizer.layer_called == 1 model(conv2d_input, training=False) assert model.layers[-1].output_quantizer.layer_called == 2 + + +def build_model(config): + inp = keras.layers.Input(shape=((16,))) + x = PQDense( + config, + 64, + in_quant_bits=( + 1.0, + 3.0, + 3.0, + ), + )(inp) + x = keras.layers.ReLU()(x) + x = PQDense(config, 32)(x) + x = keras.layers.ReLU()(x) + x = PQDense(config, 32)(x) + x = keras.layers.ReLU()(x) + out = PQDense(config, 5, out_quant_bits=(1.0, 3.0, 3.0), quantize_output=True)(x) + return keras.Model(inputs=inp, outputs=out) + + +def test_ebops_dense_varied(config_pdp, dense_input): + config = config_pdp + config.pruning_parameters.enable_pruning = False + config.quantization_parameters.use_high_granularity_quantization = True + config.quantization_parameters.hgq_gamma = 1.0 + config.quantization_parameters.enable_quantization = True + config.quantization_parameters.overflow_mode_data = "WRAP" + config.quantization_parameters.overflow_mode_parameters = "SAT_SYM" + model = build_model(config) + model(dense_input, training=True) + post_pretrain_functions(model, config) + + +@pytest.mark.parametrize( + "config_fn", + [pdp_config, ap_config, autosparse_config, cs_config, dst_config, 
mdmm_config, wanda_config], + ids=["pdp", "ap", "autosparse", "cs", "dst", "mdmm", "wanda"], +) +def test_model_serialization(tmp_path, config_fn): + config = config_fn() + config.quantization_parameters.enable_quantization = True + channels_first = keras.backend.image_data_format() == "channels_first" + if channels_first: + input_shape = (IN_FEATURES, 32, 32) + conv1d_shape = (OUT_FEATURES, 16 * 16) + dummy = np.zeros((1, IN_FEATURES, 32, 32), dtype=np.float32) + bn_axis = 1 + else: + input_shape = (32, 32, IN_FEATURES) + conv1d_shape = (16 * 16, OUT_FEATURES) + dummy = np.zeros((1, 32, 32, IN_FEATURES), dtype=np.float32) + bn_axis = -1 + + inputs = keras.Input(shape=input_shape) + x = PQConv2d(config, OUT_FEATURES, KERNEL_SIZE, padding="same")(inputs) + x = PQBatchNormalization(config, axis=bn_axis)(x) + x = PQActivation(config, activation="relu", quantize_input=True, quantize_output=True)(x) + x = PQAvgPool2d(config, pool_size=2, strides=2, padding="same")(x) + x = keras.layers.Reshape(conv1d_shape)(x) + x = PQConv1d(config, OUT_FEATURES, KERNEL_SIZE, padding="same")(x) + x = PQActivation(config, activation="relu", quantize_input=True, quantize_output=True)(x) + x = PQAvgPool1d(config, pool_size=2, strides=2, padding="same")(x) + x = keras.layers.Flatten()(x) + x = PQDense(config, units=OUT_FEATURES)(x) + x = PQActivation(config, activation="relu", quantize_input=True, quantize_output=True)(x) + model = keras.Model(inputs, x) + + # Call the model once to trigger build of all sublayers (pruning masks, etc.) 
+ model(dummy) + + # Randomize weights; skip boolean flags which must stay at their 0/1 values + rng = np.random.default_rng(42) + for w in model.weights: + if w.name.endswith(("is_pretraining", "is_finetuning")): + continue + w.assign(rng.standard_normal(w.shape).astype(w.dtype)) + + pq_types = (PQConv2d, PQConv1d, PQDense) + + def assert_pq_state_survives_roundtrip(m, label): + """Save m, reload it, and verify all state round-trips correctly.""" + save_path = tmp_path / f"{label}.keras" + m.save(save_path) + reloaded = keras.models.load_model(save_path) + + orig_pq = [layer for layer in m.layers if isinstance(layer, pq_types)] + loaded_pq = [layer for layer in reloaded.layers if isinstance(layer, pq_types)] + assert len(loaded_pq) == len(orig_pq) + + for orig_l, loaded_l in zip(orig_pq, loaded_pq): + for attr in ("final_compression_done", "is_pretraining", "is_finetuning"): + np.testing.assert_equal( + getattr(loaded_l, attr), + getattr(orig_l, attr), + err_msg=f"[{label}] {orig_l.name}.{attr} mismatch", + ) + + assert len(m.weights) == len(reloaded.weights) + for orig_w, loaded_w in zip(m.weights, reloaded.weights): + np.testing.assert_array_equal( + np.array(orig_w), + np.array(loaded_w), + err_msg=f"[{label}] Weight mismatch: {orig_w.name}", + ) + + # Stage 1: initial state (is_pretraining=True, is_finetuning=False) + assert_pq_state_survives_roundtrip(model, "initial") + + # Stage 2: after post_pretrain_functions (is_pretraining=False, is_finetuning=False) + post_pretrain_functions(model, config) + assert_pq_state_survives_roundtrip(model, "post_pretrain") + + # Stage 3: after pre_finetune_functions (is_pretraining=False, is_finetuning=True) + pre_finetune_functions(model) + assert_pq_state_survives_roundtrip(model, "pre_finetune") + + # Stage 4: after apply_final_compression (final_compression_done=True on all layers) + apply_final_compression(model) + assert_pq_state_survives_roundtrip(model, "final_compression") + + +@pytest.mark.parametrize( + 
"config_fn", + [pdp_config, ap_config, autosparse_config, cs_config, dst_config, mdmm_config, wanda_config], + ids=["pdp", "ap", "autosparse", "cs", "dst", "mdmm", "wanda"], +) +def test_checkpoint_save_load(tmp_path, config_fn): + """Verify that save_weights/load_weights preserves non-trainable pruning state (e.g. mask).""" + config = config_fn() + config.quantization_parameters.enable_quantization = True + channels_first = keras.backend.image_data_format() == "channels_first" + if channels_first: + input_shape = (IN_FEATURES, 32, 32) + dummy = np.zeros((1, IN_FEATURES, 32, 32), dtype=np.float32) + else: + input_shape = (32, 32, IN_FEATURES) + dummy = np.zeros((1, 32, 32, IN_FEATURES), dtype=np.float32) + + inputs = keras.Input(shape=input_shape) + x = PQConv2d(config, OUT_FEATURES, KERNEL_SIZE, padding="same")(inputs) + x = keras.layers.Flatten()(x) + x = PQDense(config, units=OUT_FEATURES)(x) + model = keras.Model(inputs, x) + model(dummy) + + # Randomize all weights including non-trainable pruning state (mask, etc.) 
+ rng = np.random.default_rng(0) + for w in model.weights: + if w.name.endswith(("is_pretraining", "is_finetuning")): + continue + w.assign(rng.standard_normal(w.shape).astype(w.dtype)) + + original_weights = [np.array(w) for w in model.weights] + + path = str(tmp_path / "ckpt.weights.h5") + model.save_weights(path) + + # Overwrite all weights with zeros + for w in model.weights: + w.assign(np.zeros(w.shape, dtype=w.dtype)) + + model.load_weights(path) + + for orig, w in zip(original_weights, model.weights): + np.testing.assert_array_equal(orig, np.array(w), err_msg=f"Checkpoint weight mismatch: {w.name}") + + +@pytest.mark.parametrize( + "config_fn", + [pdp_config, ap_config, autosparse_config, cs_config, dst_config, mdmm_config, wanda_config], + ids=["pdp", "ap", "autosparse", "cs", "dst", "mdmm", "wanda"], +) +def test_model_fit(config_fn): + from pquant.core.keras.train import PQuantCallback + + config = config_fn() + config.quantization_parameters.enable_quantization = True + config.training_parameters.pretraining_epochs = 1 + config.training_parameters.epochs = 1 + config.training_parameters.fine_tuning_epochs = 1 + config.training_parameters.rounds = 1 + config.training_parameters.save_weights_epoch = 0 + + channels_first = keras.backend.image_data_format() == "channels_first" + if channels_first: + input_shape = (IN_FEATURES, 32, 32) + conv1d_shape = (OUT_FEATURES, 16 * 16) + dummy_x = np.zeros((BATCH_SIZE, IN_FEATURES, 32, 32), dtype=np.float32) + bn_axis = 1 + else: + input_shape = (32, 32, IN_FEATURES) + conv1d_shape = (16 * 16, OUT_FEATURES) + dummy_x = np.zeros((BATCH_SIZE, 32, 32, IN_FEATURES), dtype=np.float32) + bn_axis = -1 + + inputs = keras.Input(shape=input_shape) + x = PQConv2d(config, OUT_FEATURES, KERNEL_SIZE, padding="same")(inputs) + x = PQBatchNormalization(config, axis=bn_axis)(x) + x = PQActivation(config, activation="relu", quantize_input=True, quantize_output=True)(x) + x = PQAvgPool2d(config, pool_size=2, strides=2, padding="same")(x) 
+ x = keras.layers.Reshape(conv1d_shape)(x) + x = PQConv1d(config, OUT_FEATURES, KERNEL_SIZE, padding="same")(x) + x = PQActivation(config, activation="relu", quantize_input=True, quantize_output=True)(x) + x = PQAvgPool1d(config, pool_size=2, strides=2, padding="same")(x) + x = keras.layers.Flatten()(x) + x = PQDense(config, units=OUT_FEATURES)(x) + x = PQActivation(config, activation="relu", quantize_input=True, quantize_output=True)(x) + model = keras.Model(inputs, x) + + dummy_y = np.zeros((BATCH_SIZE, model.output_shape[-1]), dtype=np.float32) + + model.compile(optimizer="adam", loss="mse", jit_compile=False) + callback = PQuantCallback(config, log_ebops=False, log_keep_ratio=False) + model.fit(dummy_x, dummy_y, epochs=callback.total_epochs, callbacks=[callback], verbose=0) diff --git a/tests/test_pdp.py b/tests/test_pdp.py index 0cebf16..4a6a081 100644 --- a/tests/test_pdp.py +++ b/tests/test_pdp.py @@ -75,15 +75,19 @@ def test_linear_structured(config): sparsity = config["pruning_parameters"]["sparsity"] config["pruning_parameters"]["structured_pruning"] = True - inp = ops.linspace(-1, 1, num=OUT_FEATURES) + # PDP structured linear prunes rows (dim 0 = out_features). + # Weight shape is (OUT_FEATURES, IN_FEATURES) matching the transposed convention + # used by the Keras layer before passing to the pruning layer. 
+ row_scales = ops.linspace(-1, 1, num=OUT_FEATURES) threshold_point = int(OUT_FEATURES * sparsity) - 1 - threshold_value = sorted(ops.abs(inp))[threshold_point] - inp = shuffle(inp) - mask = ops.cast((ops.abs(inp) > threshold_value), inp.dtype) - inp = ops.tile(inp, (IN_FEATURES, 1)) - mask = ops.tile(mask, (IN_FEATURES, 1)) - threshold_point = int(OUT_FEATURES * sparsity) - # In the matrix of shape NxM, prune over the M dimension + threshold_value = sorted(ops.abs(row_scales))[threshold_point] + row_scales = shuffle(row_scales) + mask_1d = ops.cast((ops.abs(row_scales) > threshold_value), row_scales.dtype) + + # Each row i has uniform value row_scales[i], giving distinct per-row norms. + inp = ops.tile(ops.reshape(row_scales, (OUT_FEATURES, 1)), (1, IN_FEATURES)) + mask = ops.tile(ops.reshape(mask_1d, (OUT_FEATURES, 1)), (1, IN_FEATURES)) + pdp = PDP(config, "linear") pdp.post_pre_train_function() pdp.build(inp.shape) @@ -99,17 +103,21 @@ def test_conv_structured(config): config["pruning_parameters"]["structured_pruning"] = True sparsity = config["pruning_parameters"]["sparsity"] - inp = ops.ones(shape=(IN_FEATURES, OUT_FEATURES, KERNEL_SIZE, KERNEL_SIZE)) - channel_dim = ops.linspace(-1, 1, num=OUT_FEATURES) + # PDP structured conv prunes rows (dim 0 = out_channels). + # Weight shape is (OUT_FEATURES, IN_FEATURES, kH, kW) matching the transposed + # convention used by the Keras layer before passing to the pruning layer. 
+ channel_scales = ops.linspace(-1, 1, num=OUT_FEATURES) threshold_point = int(OUT_FEATURES * sparsity) - 1 - threshold_value = sorted(ops.abs(channel_dim))[threshold_point] - channel_dim = shuffle(channel_dim) - mask = ops.cast((ops.abs(channel_dim) > threshold_value), inp.dtype) - mask = ops.expand_dims(ops.expand_dims(ops.expand_dims(mask, axis=0), axis=-1), axis=-1) - mask = ops.tile(mask, (IN_FEATURES, 1, KERNEL_SIZE, KERNEL_SIZE)) - - mult = ops.expand_dims(ops.expand_dims(ops.expand_dims(channel_dim, axis=0), axis=-1), axis=-1) - inp *= mult + threshold_value = sorted(ops.abs(channel_scales))[threshold_point] + channel_scales = shuffle(channel_scales) + mask_1d = ops.cast((ops.abs(channel_scales) > threshold_value), channel_scales.dtype) + + # Each output channel c has all spatial elements equal to channel_scales[c]. + mult = ops.reshape(channel_scales, (OUT_FEATURES, 1, 1, 1)) + inp = ops.ones(shape=(OUT_FEATURES, IN_FEATURES, KERNEL_SIZE, KERNEL_SIZE)) * mult + mask = ops.ones(shape=(OUT_FEATURES, IN_FEATURES, KERNEL_SIZE, KERNEL_SIZE)) * ops.reshape( + mask_1d, (OUT_FEATURES, 1, 1, 1) + ) pdp = PDP(config, "conv") pdp.post_pre_train_function() diff --git a/tests/test_torch_compression_layers.py b/tests/test_torch_compression_layers.py index ccf90a3..cb4d1ff 100644 --- a/tests/test_torch_compression_layers.py +++ b/tests/test_torch_compression_layers.py @@ -1,5 +1,3 @@ -from types import SimpleNamespace - import keras import numpy as np import pytest @@ -18,6 +16,7 @@ from pquant import post_training_prune from pquant.activations import PQActivation +from pquant.core.hyperparameter_optimization import PQConfig from pquant.layers import ( PQAvgPool1d, PQAvgPool2d, @@ -34,15 +33,6 @@ pre_finetune_functions, ) - -def _to_obj(x): - if isinstance(x, dict): - return SimpleNamespace(**{k: _to_obj(v) for k, v in x.items()}) - if isinstance(x, list): - return [_to_obj(v) for v in x] - return x - - BATCH_SIZE = 4 OUT_FEATURES = 32 IN_FEATURES = 16 @@ -84,11 +74,12 
@@ def config_pdp(): "round_mode": "RND", "overflow_mode_parameters": "SAT", "overflow_mode_data": "SAT", + "granularity": "per_tensor", }, "training_parameters": {"pruning_first": False}, "fitcompress_parameters": {"enable_fitcompress": False}, } - return _to_obj(cfg) + return PQConfig.load_from_config(cfg) @pytest.fixture @@ -124,11 +115,12 @@ def config_ap(): "round_mode": "RND", "overflow_mode_parameters": "SAT", "overflow_mode_data": "SAT", + "granularity": "per_tensor", }, "training_parameters": {"pruning_first": False}, "fitcompress_parameters": {"enable_fitcompress": False}, } - return _to_obj(cfg) + return PQConfig.load_from_config(cfg) @pytest.fixture @@ -167,11 +159,12 @@ def config_wanda(): "round_mode": "RND", "overflow_mode_parameters": "SAT", "overflow_mode_data": "SAT", + "granularity": "per_tensor", }, "training_parameters": {"pruning_first": False}, "fitcompress_parameters": {"enable_fitcompress": False}, } - return _to_obj(cfg) + return PQConfig.load_from_config(cfg) @pytest.fixture @@ -206,11 +199,12 @@ def config_cs(): "round_mode": "RND", "overflow_mode_parameters": "SAT", "overflow_mode_data": "SAT", + "granularity": "per_tensor", }, "training_parameters": {"pruning_first": False}, "fitcompress_parameters": {"enable_fitcompress": False}, } - return _to_obj(cfg) + return PQConfig.load_from_config(cfg) @pytest.fixture @@ -586,17 +580,20 @@ def test_trigger_post_pretraining(config_pdp, dense_input): model = add_compression_layers(model, config_pdp, dense_input.shape) - assert model.submodule.pruning_layer.is_pretraining is True - assert model.activation.is_pretraining is True - assert model.submodule2.pruning_layer.is_pretraining is True - assert model.activation2.is_pretraining is True + def _to_bool(val): + return val.numpy() if hasattr(val, "numpy") else bool(val) + + assert _to_bool(model.submodule.pruning_layer.is_pretraining) + assert _to_bool(model.activation.is_pretraining) + assert _to_bool(model.submodule2.pruning_layer.is_pretraining) 
+ assert _to_bool(model.activation2.is_pretraining) post_pretrain_functions(model, config_pdp) - assert model.submodule.pruning_layer.is_pretraining is False - assert model.activation.is_pretraining is False - assert model.submodule2.pruning_layer.is_pretraining is False - assert model.activation2.is_pretraining is False + assert not _to_bool(model.submodule.pruning_layer.is_pretraining) + assert not _to_bool(model.activation.is_pretraining) + assert not _to_bool(model.submodule2.pruning_layer.is_pretraining) + assert not _to_bool(model.activation2.is_pretraining) def test_hgq_weight_shape(config_pdp, dense_input): @@ -1777,15 +1774,19 @@ def test_hgq_loss_calc_no_qoutput(config_pdp, conv2d_input): for m in model.modules(): if isinstance(m, (PQWeightBiasBase)): m.ebops = dummy_ebops - m.input_quantizer.hgq_loss = dummy_hgq_loss + if hasattr(m, "input_quantizer"): + m.input_quantizer.hgq_loss = dummy_hgq_loss m.weight_quantizer.hgq_loss = dummy_hgq_loss m.bias_quantizer.hgq_loss = dummy_hgq_loss - m.output_quantizer.hgq_loss = dummy_hgq_loss # Won't be called + if hasattr(m, "output_quantizer"): + m.output_quantizer.hgq_loss = dummy_hgq_loss # Won't be called expected_loss += 3.0 elif isinstance(m, (PQAvgPool1d, PQAvgPool2d, PQActivation)): m.ebops = dummy_ebops - m.input_quantizer.hgq_loss = dummy_hgq_loss - m.output_quantizer.hgq_loss = dummy_hgq_loss # Won't be called + if hasattr(m, "input_quantizer"): + m.input_quantizer.hgq_loss = dummy_hgq_loss + if hasattr(m, "output_quantizer"): + m.output_quantizer.hgq_loss = dummy_hgq_loss # Won't be called expected_loss += 1.0 elif isinstance(m, (PQBatchNorm2d)): m.ebops = dummy_ebops @@ -1812,15 +1813,20 @@ def test_hgq_loss_calc_no_bias_no_qoutput(config_pdp, conv2d_input): for m in model.modules(): if isinstance(m, (PQWeightBiasBase)): m.ebops = dummy_ebops - m.input_quantizer.hgq_loss = dummy_hgq_loss + if hasattr(m, "input_quantizer"): + m.input_quantizer.hgq_loss = dummy_hgq_loss m.weight_quantizer.hgq_loss = 
dummy_hgq_loss - m.bias_quantizer.hgq_loss = dummy_hgq_loss # Won't be called - m.output_quantizer.hgq_loss = dummy_hgq_loss # Won't be called + if hasattr(m, "bias_quantizer"): + m.bias_quantizer.hgq_loss = dummy_hgq_loss # Won't be called + if hasattr(m, "output_quantizer"): + m.output_quantizer.hgq_loss = dummy_hgq_loss # Won't be called expected_loss += 2.0 elif isinstance(m, (PQAvgPool1d, PQAvgPool2d, PQActivation)): m.ebops = dummy_ebops - m.input_quantizer.hgq_loss = dummy_hgq_loss - m.output_quantizer.hgq_loss = dummy_hgq_loss # Won't be called + if hasattr(m, "input_quantizer"): + m.input_quantizer.hgq_loss = dummy_hgq_loss + if hasattr(m, "output_quantizer"): + m.output_quantizer.hgq_loss = dummy_hgq_loss # Won't be called expected_loss += 1.0 elif isinstance(m, (PQBatchNorm2d)): m.ebops = dummy_ebops @@ -1883,19 +1889,24 @@ def test_hgq_loss_calc_no_qinput(config_pdp, conv2d_input): for m in model.modules(): if isinstance(m, (PQWeightBiasBase)): m.ebops = dummy_ebops - m.input_quantizer.hgq_loss = dummy_hgq_loss # Won't be called + if hasattr(m, "input_quantizer"): + m.input_quantizer.hgq_loss = dummy_hgq_loss # Won't be called m.weight_quantizer.hgq_loss = dummy_hgq_loss m.bias_quantizer.hgq_loss = dummy_hgq_loss - m.output_quantizer.hgq_loss = dummy_hgq_loss + if hasattr(m, "output_quantizer"): + m.output_quantizer.hgq_loss = dummy_hgq_loss expected_loss += 3.0 elif isinstance(m, (PQAvgPool1d, PQAvgPool2d, PQActivation)): m.ebops = dummy_ebops - m.input_quantizer.hgq_loss = dummy_hgq_loss # Won't be called - m.output_quantizer.hgq_loss = dummy_hgq_loss + if hasattr(m, "input_quantizer"): + m.input_quantizer.hgq_loss = dummy_hgq_loss # Won't be called + if hasattr(m, "output_quantizer"): + m.output_quantizer.hgq_loss = dummy_hgq_loss expected_loss += 1.0 elif isinstance(m, (PQBatchNorm2d)): m.ebops = dummy_ebops - m.input_quantizer.hgq_loss = dummy_hgq_loss # Won't be called + if hasattr(m, "input_quantizer"): + m.input_quantizer.hgq_loss = 
dummy_hgq_loss # Won't be called m.weight_quantizer.hgq_loss = dummy_hgq_loss m.bias_quantizer.hgq_loss = dummy_hgq_loss expected_loss += 2.0