diff --git a/src/pquant/configs/config_mdmm.yaml b/src/pquant/configs/config_mdmm.yaml index 33e3587..abbae3c 100644 --- a/src/pquant/configs/config_mdmm.yaml +++ b/src/pquant/configs/config_mdmm.yaml @@ -14,6 +14,7 @@ pruning_parameters: damping: 1.0 use_grad: false l0_mode: "coarse" # 'coarse' or 'smooth' + constraint_lr: 1.0e-3 quantization_parameters: enable_quantization: true diff --git a/src/pquant/core/keras/activations.py b/src/pquant/core/keras/activations.py index ddc3a27..dcf5f6b 100644 --- a/src/pquant/core/keras/activations.py +++ b/src/pquant/core/keras/activations.py @@ -23,7 +23,7 @@ def hard_tanh(x): activation_registry = {"relu": relu, "tanh": tanh, "hard_tanh": hard_tanh} -@keras.saving.register_keras_serializable(package="PQuant") +@keras.saving.register_keras_serializable(package="PQuantML") class PQActivation(keras.layers.Layer): def __init__( self, @@ -121,6 +121,10 @@ def set_output_quantization_bits(self, i, f): def post_pre_train_function(self): self.is_pretraining = False + if self.quantize_input: + self.input_quantizer.post_pre_train_function() + if self.quantize_output: + self.output_quantizer.post_pre_train_function() def ebops(self): if self.quantize_input and self.quantize_output: diff --git a/src/pquant/core/keras/layers.py b/src/pquant/core/keras/layers.py index 6a30fc5..aabf69f 100644 --- a/src/pquant/core/keras/layers.py +++ b/src/pquant/core/keras/layers.py @@ -17,7 +17,10 @@ SeparableConv2D, ) from keras.src.layers.input_spec import InputSpec -from keras.src.ops.operation_utils import compute_pooling_output_shape +from keras.src.ops.operation_utils import ( + compute_conv_output_shape, + compute_pooling_output_shape, +) from pquant.core.hyperparameter_optimization import PQConfig from pquant.core.keras.activations import PQActivation @@ -27,7 +30,7 @@ T = TypeVar("T") -@keras.saving.register_keras_serializable(package="PQuant") +@keras.saving.register_keras_serializable(package="PQuantML") class PQWeightBiasBase(keras.layers.Layer): def __init__( self, @@ -73,7 +76,8 @@ def __init__( self.i_output = config.quantization_parameters.default_data_integer_bits self.f_output = config.quantization_parameters.default_data_fractional_bits - self.pruning_layer = get_pruning_layer(config=config, layer_type=layer_type) + self.layer_type = layer_type + self.pruning_layer = get_pruning_layer(config=config, layer_type=self.layer_type) self.pruning_method = config.pruning_parameters.pruning_method self.quantize_input = quantize_input self.quantize_output = quantize_output @@ -97,7 +101,8 @@ def __init__( self.parallelization_factor = -1 self.hgq_beta = config.quantization_parameters.hgq_beta self.input_shape = None - self.is_pretraining = True + self._is_pretraining = True + self._is_finetuning = False self.config = config self.weight_quantizer = Quantizer( @@ -167,28 +172,72 @@ def build(self, input_shape): self.input_shape = (1,) + tuple(input_shape[1:]) self.n_parallel = ops.prod(input_shape[1:-1]) self.parallelization_factor = self.parallelization_factor if self.parallelization_factor > 0 else self.n_parallel + self.is_pretraining = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_pretraining else ops.zeros(shape), dtype), + name="is_pretraining", + trainable=False, + dtype="float32", + ) + self.is_finetuning = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_finetuning else ops.zeros(shape), dtype), + name="is_finetuning", + trainable=False, + dtype="float32", + ) super().build(input_shape=input_shape) def apply_final_compression(self): pass + def save_own_variables(self, store): + if not self.built: + return + all_vars = self._trainable_variables + self._non_trainable_variables + for i, v in enumerate(all_vars): + store[str(i)] = v + + def load_own_variables(self, store): + all_vars = self._trainable_variables + self._non_trainable_variables + if len(store.keys()) != len(all_vars): + raise ValueError( + f"Layer '{self.name}' expected {len(all_vars)} variables, " + f"but received {len(store.keys())} variables during loading. " + f"Expected: {[v.name for v in all_vars]}" + ) + for i, v in enumerate(all_vars): + v.assign(store[str(i)]) + def post_pre_train_function(self): - self.is_pretraining = False + self._is_pretraining = False + if hasattr(self, "is_pretraining"): + self.is_pretraining.assign(0.0) if self.pruning_layer is not None: self.pruning_layer.post_pre_train_function() + self.input_quantizer.post_pre_train_function() + self.weight_quantizer.post_pre_train_function() + self.bias_quantizer.post_pre_train_function() + self.output_quantizer.post_pre_train_function() + + def pre_finetune_function(self): + self._is_finetuning = True + if hasattr(self, "is_finetuning"): + self.is_finetuning.assign(1.0) def save_weights(self): - self.init_weight = self.weight.value + self.init_weight = ops.copy(self._kernel) def rewind_weights(self): - self.weight.assign(self.init_weight) + self._kernel.assign(self.init_weight) def ebops(self): return 0.0 def hgq_loss(self): - if self.pruning_layer.is_pretraining or not self.use_hgq: + if not self.use_hgq: return ops.convert_to_tensor(0.0) + loss = self.hgq_beta * self.ebops() loss += self.weight_quantizer.hgq_loss() if self._bias is not None: @@ -197,7 +246,7 @@ def hgq_loss(self): loss += self.input_quantizer.hgq_loss() if self.quantize_output: loss += self.output_quantizer.hgq_loss() - return loss + return ops.where(ops.cast(self.is_pretraining, "bool"), ops.zeros_like(loss), loss) def handle_transpose(self, x, transpose, do_transpose=False): if do_transpose: @@ -211,17 +260,17 @@ def prune(self, weight): weight = self.handle_transpose(weight, self.weight_transpose_back, True) return weight - def pre_forward(self, x, training=None): + def pre_forward(self, x, training): if self.quantize_input and self.enable_quantization: x = self.input_quantizer(x, training=training) - if self.pruning_method == "wanda": + if self.pruning_method == "wanda" and self.enable_pruning: self.collect_input(x, self._kernel, training) return x - def post_forward(self, x, training=None): + def post_forward(self, x, training): if self.quantize_output and self.enable_quantization: x = self.output_quantizer(x, training=training) - if self.pruning_method == "activation_pruning": + if self.pruning_method == "activation_pruning" and self.enable_pruning: self.collect_output(x, training) return x @@ -236,42 +285,41 @@ def collect_output(self, x, training): @classmethod def from_config(cls, config): - # Deserialize all sublayers first - input_quantizer = keras.saving.deserialize_keras_object(config.pop("input_quantizer")) - weight_quantizer = keras.saving.deserialize_keras_object(config.pop("weight_quantizer")) - bias_quantizer = keras.saving.deserialize_keras_object(config.pop("bias_quantizer")) - output_quantizer = keras.saving.deserialize_keras_object(config.pop("output_quantizer")) - + # Quantizer objects are recreated by __init__ from the parent config; + # their variable values are restored from the h5 weights file by attribute name. + config.pop("input_quantizer", None) + config.pop("weight_quantizer", None) + config.pop("bias_quantizer", None) + config.pop("output_quantizer", None) + final_compression_done = config.pop("final_compression_done", False) instance = cls(**config) - instance.input_quantizer = input_quantizer - instance.weight_quantizer = weight_quantizer - instance.bias_quantizer = bias_quantizer - - if True: - instance.output_quantizer = output_quantizer + instance.final_compression_done = final_compression_done return instance def get_config(self): config = super().get_config() + config.update( { - "config": self.config, + "config": self.config.get_dict(), "input_quantizer": keras.saving.serialize_keras_object(self.input_quantizer), "weight_quantizer": keras.saving.serialize_keras_object(self.weight_quantizer), "bias_quantizer": keras.saving.serialize_keras_object(self.bias_quantizer), + "output_quantizer": keras.saving.serialize_keras_object(self.output_quantizer), "quantize_input": self.quantize_input, "quantize_output": self.quantize_output, "in_quant_bits": self.in_quant_bits, "weight_quant_bits": self.weight_quant_bits, "bias_quant_bits": self.bias_quant_bits, "out_quant_bits": self.out_quant_bits, + "enable_pruning": self.enable_pruning, + "final_compression_done": self.final_compression_done, } ) - config.update({"output_quantizer": keras.saving.serialize_keras_object(self.output_quantizer)}) return config -@keras.saving.register_keras_serializable(package="PQuant") +@keras.saving.register_keras_serializable(package="PQuantML") class PQDepthwiseConv2d(PQWeightBiasBase, keras.layers.DepthwiseConv2D): def __init__( self, @@ -313,14 +361,14 @@ def __init__( activation=None, use_bias=use_bias, depthwise_initializer=depthwise_initializer, - bias_initializer=bias_regularizer, + bias_initializer=bias_initializer, depthwise_regularizer=depthwise_regularizer, bias_regularizer=bias_regularizer, activity_regularizer=activity_regularizer, depthwise_constraint=depthwise_constraint, bias_constraint=bias_constraint, config=config, - layer_type="conv", + layer_type="depthwise_conv", quantize_input=quantize_input, quantize_output=quantize_output, in_quant_bits=in_quant_bits, @@ -381,7 +429,19 @@ def build(self, input_shape): if self.use_bias: self.bias_quantizer.build(self._bias.shape) self.output_quantizer.build(self.compute_output_shape(input_shape)) + else: + if not self.input_quantizer.built: + self.input_quantizer.build(input_shape) + if not self.weight_quantizer.built: + self.weight_quantizer.build(self._kernel.shape) + if self.use_bias and not self.bias_quantizer.built: + self.bias_quantizer.build(self._bias.shape) + if self.quantize_output and not self.output_quantizer.built: + self.output_quantizer.build(self.compute_output_shape(input_shape)) self.input_shape = (1,) + input_shape[1:] + if self.enable_pruning and self.pruning_layer is not None and not self.pruning_layer.built: + pruning_shape = tuple(self._kernel.shape[i] for i in self.weight_transpose) + self.pruning_layer.build(pruning_shape) @property def kernel(self): @@ -464,7 +524,7 @@ def call(self, x, training=None): def apply_final_compression(self): self._kernel.assign(self.kernel) if self._bias is not None: - self._bias.assign = self.bias + self._bias.assign(self.bias) self.final_compression_done = True def extra_repr(self) -> str: @@ -480,8 +540,14 @@ def extra_repr(self) -> str: ) +def _normalize_tuple(value, n): + if isinstance(value, int): + return (value,) * n + return tuple(value) + + @keras.saving.register_keras_serializable(package="PQuant") -class PQConv2d(PQWeightBiasBase, keras.layers.Conv2D): +class PQConv2d(PQWeightBiasBase): def __init__( self, config, @@ -511,22 +577,6 @@ def __init__( **kwargs, ): super().__init__( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - groups=groups, - activation=None, - use_bias=use_bias, - kernel_initializer=kernel_initializer, - bias_initializer=bias_initializer, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - kernel_constraint=kernel_constraint, - bias_constraint=bias_constraint, config=config, layer_type="conv", quantize_input=quantize_input, @@ -536,17 +586,38 @@ def __init__( bias_quant_bits=bias_quant_bits, out_quant_bits=out_quant_bits, enable_pruning=enable_pruning, + activity_regularizer=activity_regularizer, **kwargs, ) - + self.filters = filters + self.kernel_size = _normalize_tuple(kernel_size, 2) + self.strides = _normalize_tuple(strides, 2) + self.padding = padding.lower() + self.data_format = keras.backend.image_data_format() if data_format is None else data_format + self.dilation_rate = _normalize_tuple(dilation_rate, 2) + self.groups = groups + self.use_bias = use_bias + self.kernel_initializer = initializers.get(kernel_initializer) + self.bias_initializer = initializers.get(bias_initializer) + self.kernel_regularizer = regularizers.get(kernel_regularizer) + self.bias_regularizer = regularizers.get(bias_regularizer) + self.kernel_constraint = constraints.get(kernel_constraint) + self.bias_constraint = constraints.get(bias_constraint) self.weight_transpose = (3, 2, 0, 1) self.weight_transpose_back = (2, 3, 1, 0) self.data_transpose = (0, 3, 1, 2) self.do_transpose_data = self.data_format == "channels_last" - self.use_biase = use_bias def build(self, input_shape): - super().build(input_shape) + in_channels = input_shape[-1] if self.data_format == "channels_last" else input_shape[1] + kernel_shape = self.kernel_size + (in_channels // self.groups, self.filters) + self._kernel = self.add_weight( + name="kernel", + shape=kernel_shape, + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + ) if self.use_bias: self._bias = self.add_weight( name="bias", @@ -559,12 +630,25 @@ def build(self, input_shape): ) else: self._bias = None + super().build(input_shape) if self.use_hgq: self.input_quantizer.build(input_shape) self.weight_quantizer.build(self._kernel.shape) if self.use_bias: self.bias_quantizer.build(self._bias.shape) self.output_quantizer.build(self.compute_output_shape(input_shape)) + else: + if not self.input_quantizer.built: + self.input_quantizer.build(input_shape) + if not self.weight_quantizer.built: + self.weight_quantizer.build(self._kernel.shape) + if self.use_bias and not self.bias_quantizer.built: + self.bias_quantizer.build(self._bias.shape) + if self.quantize_output and not self.output_quantizer.built: + self.output_quantizer.build(self.compute_output_shape(input_shape)) + if self.enable_pruning and self.pruning_layer is not None and not self.pruning_layer.built: + pruning_shape = tuple(self._kernel.shape[i] for i in self.weight_transpose) + self.pruning_layer.build(pruning_shape) @property def kernel(self): @@ -632,16 +716,65 @@ def ebops(self, include_mask=False): ebops += ops.mean(bw_bias) * size return ebops + def compute_output_shape(self, input_shape): + return compute_conv_output_shape( + input_shape, + self.filters, + self.kernel_size, + strides=self.strides, + padding=self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) + + def apply_final_compression(self): + self._kernel.assign(self.kernel) + if self._bias is not None: + self._bias.assign(self.bias) + self.final_compression_done = True + def call(self, x, training=None): x = self.pre_forward(x, training) - x = super().call(x) + x = ops.conv( + x, + self.kernel, + strides=self.strides, + padding=self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) + if self.use_bias: + bias_shape = (1, 1, 1, self.filters) if self.data_format == "channels_last" else (1, self.filters, 1, 1) + x = x + ops.reshape(self.bias, bias_shape) x = self.post_forward(x, training) if self.use_hgq and self.enable_quantization: self.add_loss(self.hgq_loss()) return x + def get_config(self): + config = super().get_config() + config.update( + { + "filters": self.filters, + "kernel_size": self.kernel_size, + "strides": self.strides, + "padding": self.padding, + "data_format": self.data_format, + "dilation_rate": self.dilation_rate, + "groups": self.groups, + "use_bias": self.use_bias, + "kernel_initializer": initializers.serialize(self.kernel_initializer), + "bias_initializer": initializers.serialize(self.bias_initializer), + "kernel_regularizer": regularizers.serialize(self.kernel_regularizer), + "bias_regularizer": regularizers.serialize(self.bias_regularizer), + "kernel_constraint": constraints.serialize(self.kernel_constraint), + "bias_constraint": constraints.serialize(self.bias_constraint), + } + ) + return config -@keras.saving.register_keras_serializable(package="PQuant") + +@keras.saving.register_keras_serializable(package="PQuantML") class PQSeparableConv2d(Layer): def __init__( self, @@ -667,7 +800,7 @@ def __init__( quantize_output=False, **kwargs, ): - super().__init__() + super().__init__(**kwargs) self.weight_transpose = (3, 2, 0, 1) self.weight_transpose_back = (2, 3, 1, 0) self.data_transpose = (0, 3, 1, 2) @@ -720,9 +853,28 @@ def call(self, x, training=None): x = self.pointwise_conv(x, training=training) return x + def get_config(self): + config = super().get_config() + config.update( + { + "config": self.depthwise_conv.config.model_dump(), + "filters": self.pointwise_conv.filters, + "kernel_size": self.depthwise_conv.kernel_size, + "strides": self.depthwise_conv.strides, + "padding": self.depthwise_conv.padding, + "data_format": self.depthwise_conv.data_format, + "dilation_rate": self.depthwise_conv.dilation_rate, + "depth_multiplier": self.depthwise_conv.depth_multiplier, + "use_bias": self.pointwise_conv.use_bias, + "quantize_input": self.depthwise_conv.quantize_input, + "quantize_output": self.pointwise_conv.quantize_output, + } + ) + return config + @keras.saving.register_keras_serializable(package="PQuant") -class PQConv1d(PQWeightBiasBase, keras.layers.Conv1D): +class PQConv1d(PQWeightBiasBase): def __init__( self, config, @@ -751,24 +903,7 @@ def __init__( bias_constraint=None, **kwargs, ): - super().__init__( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - groups=groups, - activation=None, - use_bias=use_bias, - kernel_initializer=kernel_initializer, - bias_initializer=bias_initializer, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - kernel_constraint=kernel_regularizer, - bias_constraint=bias_constraint, config=config, layer_type="conv", quantize_input=quantize_input, @@ -778,17 +913,38 @@ def __init__( bias_quant_bits=bias_quant_bits, out_quant_bits=out_quant_bits, enable_pruning=enable_pruning, + activity_regularizer=activity_regularizer, **kwargs, ) - + self.filters = filters + self.kernel_size = _normalize_tuple(kernel_size, 1) + self.strides = _normalize_tuple(strides, 1) + self.padding = padding.lower() + self.data_format = keras.backend.image_data_format() if data_format is None else data_format + self.dilation_rate = _normalize_tuple(dilation_rate, 1) + self.groups = groups + self.use_bias = use_bias + self.kernel_initializer = initializers.get(kernel_initializer) + self.bias_initializer = initializers.get(bias_initializer) + self.kernel_regularizer = regularizers.get(kernel_regularizer) + self.bias_regularizer = regularizers.get(bias_regularizer) + self.kernel_constraint = constraints.get(kernel_constraint) + self.bias_constraint = constraints.get(bias_constraint) self.weight_transpose = (2, 1, 0) self.weight_transpose_back = (2, 1, 0) self.data_transpose = (0, 2, 1) self.do_transpose_data = self.data_format == "channels_last" - self.use_bias = use_bias def build(self, input_shape): - super().build(input_shape) + in_channels = input_shape[-1] if self.data_format == "channels_last" else input_shape[1] + kernel_shape = self.kernel_size + (in_channels // self.groups, self.filters) + self._kernel = self.add_weight( + name="kernel", + shape=kernel_shape, + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + ) if self.use_bias: self._bias = self.add_weight( name="bias", @@ -801,12 +957,25 @@ def build(self, input_shape): ) else: self._bias = None + super().build(input_shape) if self.use_hgq: self.input_quantizer.build(input_shape) self.weight_quantizer.build(self._kernel.shape) if self.use_bias: self.bias_quantizer.build(self._bias.shape) self.output_quantizer.build(self.compute_output_shape(input_shape)) + else: + if not self.input_quantizer.built: + self.input_quantizer.build(input_shape) + if not self.weight_quantizer.built: + self.weight_quantizer.build(self._kernel.shape) + if self.use_bias and not self.bias_quantizer.built: + self.bias_quantizer.build(self._bias.shape) + if self.quantize_output and not self.output_quantizer.built: + self.output_quantizer.build(self.compute_output_shape(input_shape)) + if self.enable_pruning and self.pruning_layer is not None and not self.pruning_layer.built: + pruning_shape = tuple(self._kernel.shape[i] for i in self.weight_transpose) + self.pruning_layer.build(pruning_shape) @property def kernel(self): @@ -873,16 +1042,65 @@ def ebops(self, include_mask=False): ebops += ops.mean(bw_bias) * size return ebops + def compute_output_shape(self, input_shape): + return compute_conv_output_shape( + input_shape, + self.filters, + self.kernel_size, + strides=self.strides, + padding=self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) + + def apply_final_compression(self): + self._kernel.assign(self.kernel) + if self._bias is not None: + self._bias.assign(self.bias) + self.final_compression_done = True + def call(self, x, training=None): x = self.pre_forward(x, training) - x = super().call(x) + x = ops.conv( + x, + self.kernel, + strides=self.strides, + padding=self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) + if self.use_bias: + bias_shape = (1, 1, self.filters) if self.data_format == "channels_last" else (1, self.filters, 1) + x = x + ops.reshape(self.bias, bias_shape) x = self.post_forward(x, training) if self.use_hgq and self.enable_quantization: self.add_loss(self.hgq_loss()) return x + def get_config(self): + config = super().get_config() + config.update( + { + "filters": self.filters, + "kernel_size": self.kernel_size, + "strides": self.strides, + "padding": self.padding, + "data_format": self.data_format, + "dilation_rate": self.dilation_rate, + "groups": self.groups, + "use_bias": self.use_bias, + "kernel_initializer": initializers.serialize(self.kernel_initializer), + "bias_initializer": initializers.serialize(self.bias_initializer), + "kernel_regularizer": regularizers.serialize(self.kernel_regularizer), + "bias_regularizer": regularizers.serialize(self.bias_regularizer), + "kernel_constraint": constraints.serialize(self.kernel_constraint), + "bias_constraint": constraints.serialize(self.bias_constraint), + } + ) + return config -@keras.saving.register_keras_serializable(package="PQuant") + +@keras.saving.register_keras_serializable(package="PQuantML") class PQDense(PQWeightBiasBase): def __init__( self, @@ -929,6 +1147,7 @@ def __init__( self.kernel_constraint = constraints.get(kernel_constraint) self.bias_constraint = constraints.get(bias_constraint) self.input_spec = InputSpec(min_ndim=2) + self._ebops = self.add_variable(shape=(), initializer="zeros", trainable=False) def build(self, input_shape): input_dim = input_shape[-1] @@ -950,6 +1169,18 @@ def build(self, input_shape): else: self._bias = None super().build(input_shape) + if not self.input_quantizer.built: + self.input_quantizer.build(input_shape) + if not self.weight_quantizer.built: + self.weight_quantizer.build(self._kernel.shape) + if self.use_bias and not self.bias_quantizer.built: + self.bias_quantizer.build(self._bias.shape) + if self.quantize_output and not self.output_quantizer.built: + output_shape = input_shape[:-1] + (self.units,) + self.output_quantizer.build(output_shape) + if self.enable_pruning and self.pruning_layer is not None and not self.pruning_layer.built: + pruning_shape = tuple(self._kernel.shape[i] for i in self.weight_transpose) + self.pruning_layer.build(pruning_shape) @property def kernel(self): @@ -986,40 +1217,43 @@ def ebops(self, include_mask=False): step_size_mask = ops.cast((ops.abs(self._kernel) > quantization_step_size), self._kernel.dtype) bw_ker = bw_ker * step_size_mask ebops = ops.sum(ops.matmul(bw_inp, bw_ker)) - ebops = ebops * self.parallelization_factor / self.n_parallel if self.use_bias: bw_bias = self.bias_quantizer.get_total_bits(ops.shape(self._bias)) - size = ops.cast(ops.prod(self.input_shape), self.dtype) + size = ops.cast(ops.prod(self.input_shape[:-1]) * self.units, self.dtype) ebops += ops.mean(bw_bias) * size + ebops = ebops * self.parallelization_factor / self.n_parallel return ebops def apply_final_compression(self): self._kernel.assign(self.kernel) if self._bias is not None: - self._bias.assign = self.bias + self._bias.assign(self.bias) self.final_compression_done = True + def compute_output_shape(self, input_shape): + output_shape = list(input_shape) + output_shape[-1] = self.units + return tuple(output_shape) + def call(self, x, training=None): + self.training = training x = self.pre_forward(x, training) x = ops.matmul(x, self.kernel) bias = self.bias - if bias is not None: + if self.use_bias: x = ops.add(x, bias) x = self.post_forward(x, training) + if self.use_hgq: + self.add_loss(self.hgq_loss()) return x def get_config(self): config = super().get_config() - config.update( - { - "config": self.config.model_dump(), - "units": self.units, - "use_bias": self.use_bias, - } - ) + config.update({"units": self.units, "use_bias": self.use_bias}) return config +@keras.saving.register_keras_serializable(package="PQuant") class PQBatchNormalization(keras.layers.BatchNormalization): def __init__( self, @@ -1039,8 +1273,11 @@ def __init__( gamma_constraint=None, synchronized=False, quantize_input=True, + quantize_parameters=True, **kwargs, ): + if isinstance(config, dict): + config = PQConfig.load_from_config(config) super().__init__( axis, momentum, @@ -1068,6 +1305,7 @@ def __init__( self.use_hgq = config.quantization_parameters.use_high_granularity_quantization self.hgq_beta = config.quantization_parameters.hgq_beta self.quantize_input = quantize_input + self.quantize_parameters = quantize_parameters self.granularity = config.quantization_parameters.granularity self.config = config self.f_weight = self.f_bias = ops.convert_to_tensor(config.quantization_parameters.default_weight_fractional_bits) @@ -1075,10 +1313,17 @@ def __init__( self.i_input = ops.convert_to_tensor(config.quantization_parameters.default_data_integer_bits) self.f_input = ops.convert_to_tensor(config.quantization_parameters.default_data_fractional_bits) self.final_compression_done = False - self.is_pretraining = True + self._is_pretraining = True def build(self, input_shape): super().build(input_shape) + self.is_pretraining = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape), dtype), + name="is_pretraining", + trainable=False, + dtype="float32", + ) self.input_quantizer = Quantizer( k=1.0, i=self.i_input, @@ -1116,18 +1361,15 @@ def build(self, input_shape): shape = [1] * len(input_shape) shape[self.axis] = input_shape[self.axis] self._shape = tuple(shape) - self.input_shape = (1,) + input_shape[1:] + self.input_shape = (1,) + tuple(input_shape[1:]) def apply_final_compression(self): self.final_compression_done = True - gamma, beta = self.gamma, self.beta - if self.enable_quantization: - if gamma is not None: - gamma = self.weight_quantizer(gamma) - self.gamma.assign(gamma) - if beta is not None: - beta = self.bias_quantizer(beta) - self.beta.assign(beta) + if self.enable_quantization and self.quantize_parameters: + if self.gamma is not None: + self.gamma.assign(self.weight_quantizer(self.gamma)) + if self.beta is not None: + self.beta.assign(self.bias_quantizer(self.beta)) def ebops(self): bw_inp = self.input_quantizer.get_total_bits(self.input_shape) @@ -1138,14 +1380,14 @@ def ebops(self): return ebops def hgq_loss(self): - if self.is_pretraining or not self.use_hgq: + if not self.use_hgq: return ops.convert_to_tensor(0.0) loss = self.hgq_beta * self.ebops() loss += self.weight_quantizer.hgq_loss() loss += self.bias_quantizer.hgq_loss() if self.quantize_input: loss += self.input_quantizer.hgq_loss() - return loss + return ops.where(ops.cast(self.is_pretraining, "bool"), ops.zeros_like(loss), loss) def call(self, inputs, training=None, mask=None): # Check if the mask has one less dimension than the inputs. @@ -1178,7 +1420,7 @@ def call(self, inputs, training=None, mask=None): if self.scale: gamma = self.gamma - if self.enable_quantization and not self.final_compression_done: + if self.enable_quantization and self.quantize_parameters and not self.final_compression_done: gamma = self.weight_quantizer(self.gamma) gamma = ops.cast(gamma, inputs.dtype) else: @@ -1186,7 +1428,7 @@ def call(self, inputs, training=None, mask=None): if self.center: beta = self.beta - if self.enable_quantization and not self.final_compression_done: + if self.enable_quantization and self.quantize_parameters and not self.final_compression_done: beta = self.bias_quantizer(self.beta) beta = ops.cast(beta, inputs.dtype) else: @@ -1214,9 +1456,31 @@ def get_bias_quantization_bits(self): return self.bias_quantizer.get_quantization_bits() def post_pre_train_function(self): - self.is_pretraining = False + self._is_pretraining = False + if hasattr(self, "is_pretraining"): + self.is_pretraining.assign(0.0) + @classmethod + def from_config(cls, config): + final_compression_done = config.pop("final_compression_done", False) + instance = cls(**config) + instance.final_compression_done = final_compression_done + return instance + def get_config(self): + config = super().get_config() + config.update( + { + "config": self.config.get_dict(), + "quantize_input": self.quantize_input, + "quantize_parameters": self.quantize_parameters, + "final_compression_done": self.final_compression_done, + } + ) + return config + + +@keras.saving.register_keras_serializable(package="PQuantML") class PQAvgPoolBase(keras.layers.Layer): def __init__( self, @@ -1228,8 +1492,13 @@ def __init__( **kwargs, ): + if isinstance(config, dict): + config = PQConfig.load_from_config(config) super().__init__(**kwargs) + self.in_quant_bits = in_quant_bits + self.out_quant_bits = out_quant_bits + if in_quant_bits is not None: self.k_input, self.i_input, self.f_input = in_quant_bits else: @@ -1245,7 +1514,6 @@ def __init__( self.f_output = config.quantization_parameters.default_data_fractional_bits self.overflow_mode_data = config.quantization_parameters.overflow_mode_data self.config = config - self.is_pretraining = True self.round_mode = config.quantization_parameters.round_mode self.data_k = config.quantization_parameters.default_data_keep_negatives self.use_hgq = config.quantization_parameters.use_high_granularity_quantization @@ -1253,14 +1521,26 @@ def __init__( self.hgq_gamma = config.quantization_parameters.hgq_gamma self.hgq_beta = config.quantization_parameters.hgq_beta self.hgq_heterogeneous = config.quantization_parameters.hgq_heterogeneous - self.saved_inputs = [] + self._is_pretraining = True self.quantize_input = quantize_input self.quantize_output = quantize_output + # BasePooling.__init__ sets built=True to skip the standard Keras build + # call, but we need build() to run so quantizers are created. + self.built = False def post_pre_train_function(self): - self.is_pretraining = False + self._is_pretraining = False + if hasattr(self, "is_pretraining"): + self.is_pretraining.assign(0.0) def build(self, input_shape): + self.is_pretraining = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape), dtype), + name="is_pretraining", + trainable=False, + dtype="float32", + ) self.input_quantizer = Quantizer( k=1.0, i=self.i_input, @@ -1283,10 +1563,9 @@ def build(self, input_shape): hgq_gamma=self.hgq_gamma, place="datalane", ) - if self.use_hgq: - self.input_quantizer.build(input_shape) - self.output_quantizer.build(self.compute_output_shape(input_shape)) - self.input_shape = (1,) + input_shape[1:] + self.input_quantizer.build(input_shape) + self.output_quantizer.build(self.compute_output_shape(input_shape)) + self.input_shape = (1,) + tuple(input_shape[1:]) def get_input_quantization_bits(self): return self.input_quantizer.get_quantization_bits() @@ -1304,8 +1583,6 @@ def compute_output_shape(self, input_shape): ) def pre_pooling(self, x, training): - if not hasattr(self, "input_quantizer"): - self.build(x.shape) if self.quantize_input and self.enable_quantization: x = self.input_quantizer(x, training=training) return x @@ -1320,33 +1597,30 @@ def ebops(self): return ops.sum(bw_inp) def hgq_loss(self): - if self.is_pretraining or not self.use_hgq: + if not self.use_hgq: return ops.convert_to_tensor(0.0) loss = self.hgq_beta * self.ebops() if self.quantize_input: loss += self.input_quantizer.hgq_loss() if self.quantize_output: loss += self.output_quantizer.hgq_loss() - return loss + return ops.where(ops.cast(self.is_pretraining, "bool"), ops.zeros_like(loss), loss) def get_config(self): config = super().get_config() config.update( { - "i_input": self.i_input, - "f_input": self.f_input, - "i_output": self.i_output, - "f_output": self.f_output, - "is_pretraining": self.is_pretraining, - "overflow": self.overflow_mode_data, - "hgq_gamma": self.hgq_gamma, - "hgq_heterogeneous": self.hgq_heterogeneous, - "pooling": self.pooling, + "config": self.config.get_dict(), + "quantize_input": self.quantize_input, + "quantize_output": self.quantize_output, + "in_quant_bits": self.in_quant_bits, + "out_quant_bits": self.out_quant_bits, } ) return config +@keras.saving.register_keras_serializable(package="PQuant") class PQAvgPool1d(PQAvgPoolBase, keras.layers.AveragePooling1D): def __init__( self, @@ -1384,7 +1658,11 @@ def call(self, x, training=None): self.add_loss(self.hgq_loss()) return x + def get_config(self): + return super().get_config() + +@keras.saving.register_keras_serializable(package="PQuant") class PQAvgPool2d(PQAvgPoolBase, keras.layers.AveragePooling2D): def __init__( self, @@ -1421,6 +1699,9 @@ def call(self, x, training=None): self.add_loss(self.hgq_loss()) return x + def get_config(self): + return super().get_config() + def call_post_round_functions(model, rewind, rounds, r): last_round = r == rounds - 1 @@ -1433,15 +1714,20 @@ def call_post_round_functions(model, rewind, rounds, r): def apply_final_compression(model): - x = model.layers[0].output - for layer in model.layers[1:]: + for layer in model.layers: if isinstance(layer, (PQWeightBiasBase, PQSeparableConv2d, PQBatchNormalization, PQDepthwiseConv2d)): layer.apply_final_compression() - x = layer(x) - else: - x = layer(x) - replaced_model = keras.Model(inputs=model.inputs, outputs=x) - return replaced_model + if hasattr(layer, "input_quantizer"): + layer.input_quantizer.apply_final_compression() + if hasattr(layer, "output_quantizer"): + layer.output_quantizer.apply_final_compression() + return model + + +def _update_pruning_mask(layer): + if layer.enable_pruning and hasattr(layer.pruning_layer, "update_mask"): + kernel = layer.handle_transpose(layer._kernel, layer.weight_transpose, True) + layer.pruning_layer.update_mask(kernel) def post_epoch_functions(model, epoch, total_epochs, **kwargs): @@ -1455,10 +1741,15 @@ def post_epoch_functions(model, epoch, total_epochs, **kwargs): PQDense, ), ): - layer.pruning_layer.post_epoch_function(epoch, total_epochs, **kwargs) + if layer.enable_pruning: + layer.pruning_layer.post_epoch_function(epoch, total_epochs, **kwargs) + _update_pruning_mask(layer) elif isinstance(layer, PQSeparableConv2d): - layer.depthwise_conv.pruning_layer.post_epoch_function(epoch, total_epochs, **kwargs) - layer.pointwise_conv.pruning_layer.post_epoch_function(epoch, total_epochs, **kwargs) + if layer.enable_pruning: + layer.depthwise_conv.pruning_layer.post_epoch_function(epoch, total_epochs, **kwargs) + _update_pruning_mask(layer.depthwise_conv) + layer.pointwise_conv.pruning_layer.post_epoch_function(epoch, total_epochs, **kwargs) + _update_pruning_mask(layer.pointwise_conv) def pre_epoch_functions(model, epoch, total_epochs): @@ -1472,10 +1763,12 @@ def pre_epoch_functions(model, epoch, total_epochs): PQDense, ), ): - layer.pruning_layer.pre_epoch_function(epoch, total_epochs) + if layer.enable_pruning: + layer.pruning_layer.pre_epoch_function(epoch, total_epochs) elif isinstance(layer, PQSeparableConv2d): - layer.depthwise_conv.pruning_layer.pre_epoch_function(epoch, total_epochs) - layer.pointwise_conv.pruning_layer.pre_epoch_function(epoch, total_epochs) + if layer.enable_pruning: + layer.depthwise_conv.pruning_layer.pre_epoch_function(epoch, total_epochs) + layer.pointwise_conv.pruning_layer.pre_epoch_function(epoch, total_epochs) def post_round_functions(model): @@ -1540,9 +1833,12 @@ def pre_finetune_functions(model): PQDense, ), ): + layer.pre_finetune_function() layer.pruning_layer.pre_finetune_function() elif isinstance(layer, PQSeparableConv2d): + layer.depthwise_conv.pre_finetune_function() layer.depthwise_conv.pruning_layer.pre_finetune_function() + layer.pointwise_conv.pre_finetune_function() layer.pointwise_conv.pruning_layer.pre_finetune_function() @@ -1557,10 +1853,10 @@ def post_pretrain_functions(model, config): PQDense, ), ): - layer.pruning_layer.post_pre_train_function() + layer.post_pre_train_function() elif isinstance(layer, PQSeparableConv2d): - layer.depthwise_conv.pruning_layer.post_pre_train_function() - layer.pointwise_conv.pruning_layer.post_pre_train_function() + layer.depthwise_conv.post_pre_train_function() + layer.pointwise_conv.post_pre_train_function() elif isinstance(layer, (PQActivation, PQAvgPoolBase, PQBatchNormalization)): layer.post_pre_train_function() if config.pruning_parameters.pruning_method == "pdp" or ( @@ -1782,7 +2078,6 @@ def add_compression_layers(model, config, input_shape=None): depth_multiplier=layer.depth_multiplier, data_format=layer.data_format, dilation_rate=layer.dilation_rate, - activation=layer.activation, use_bias=layer.use_bias, bias_initializer=layer.bias_initializer, depthwise_initializer=layer.depthwise_initializer, @@ -1816,7 +2111,6 @@ def add_compression_layers(model, config, input_shape=None): data_format=layer.data_format, dilation_rate=layer.dilation_rate, groups=layer.groups, - activation=layer.activation, use_bias=layer.use_bias, kernel_initializer=layer.kernel_initializer, bias_initializer=layer.bias_initializer, @@ -1871,13 +2165,13 @@ def add_compression_layers(model, config, input_shape=None): new_layer.pointwise_conv.set_enable_pruning(enable_pruning_pointwise) pruning_layer_input = layer.depthwise_kernel - transpose_shape = new_layer.weight_transpose - pruning_layer_input = ops.transpose(pruning_layer_input, transpose_shape) + pruning_layer_input = ops.transpose(pruning_layer_input, new_layer.depthwise_conv.weight_transpose) new_layer.depthwise_conv.pruning_layer.build(pruning_layer_input.shape) pointwise_pruning_layer_input = layer.pointwise_kernel - transpose_shape = new_layer.weight_transpose - pointwise_pruning_layer_input = ops.transpose(pointwise_pruning_layer_input, transpose_shape) + pointwise_pruning_layer_input = ops.transpose( + pointwise_pruning_layer_input, new_layer.pointwise_conv.weight_transpose + ) new_layer.pointwise_conv.pruning_layer.build(pointwise_pruning_layer_input.shape) new_layer.depthwise_conv.build(x.shape) y = new_layer.depthwise_conv(x).shape @@ -1916,7 +2210,6 @@ def add_compression_layers(model, config, input_shape=None): new_layer = PQDense( config=config, units=layer.units, - activation=layer.activation, use_bias=layer.use_bias, kernel_initializer=layer.kernel_initializer, bias_initializer=layer.bias_initializer, @@ -2018,8 +2311,9 @@ def set_quantization_bits_activations(config, layer, new_layer): if isinstance(layer, ReLU): f_input += 1 f_output += 1 # Unsigned, add 1 bit to default value only - if layer.name in config.quantization_parameters.layer_specific: - layer_config = config.quantization_parameters.layer_specific[layer.name] + layer_specific = config.quantization_parameters.layer_specific + if layer.name in layer_specific: + layer_config = layer_specific[layer.name] if hasattr(layer, "activation") and layer.activation.__name__ in layer_config: if "input" in layer_config[layer.activation.__name__]: if "integer_bits" in layer_config[layer.activation.__name__]["input"]: @@ -2144,6 +2438,10 @@ def set_quantization_bits_weight_layers(config, layer, new_layer): new_layer.f_weight = f_bits_w new_layer.i_bias = i_bits_b new_layer.f_bias = f_bits_b + new_layer.weight_quantizer.i_init = float(i_bits_w) + new_layer.weight_quantizer.f_init = float(f_bits_w) + new_layer.bias_quantizer.i_init = float(i_bits_b) + new_layer.bias_quantizer.f_init = float(f_bits_b) def get_enable_pruning(layer, config): @@ -2152,7 +2450,7 @@ def get_enable_pruning(layer, config): enable_pruning_depthwise = enable_pruning_pointwise = True if layer.name + "_depthwise" in config.pruning_parameters.disable_pruning_for_layers: enable_pruning_depthwise = False - if layer.name + "pointwise" in config.pruning_parameters.disable_pruning_for_layers: + if layer.name + "_pointwise" in config.pruning_parameters.disable_pruning_for_layers: enable_pruning_pointwise = False return enable_pruning_depthwise, enable_pruning_pointwise else: @@ -2223,7 +2521,7 @@ def populate_config_with_all_layers(model, config): elif isinstance( layer, (Activation, ReLU, AveragePooling1D, AveragePooling2D, AveragePooling3D, PQActivation, PQAvgPoolBase) ): - custom_scheme.layer_specific[layer.name] = { + custom_scheme["layer_specific"][layer.name] = { "input": {"quantize": True, "integer_bits": 0.0, "fractional_bits": 7.0}, "output": {"quantize": True, "integer_bits": 0.0, "fractional_bits": 7.0}, } @@ -2248,7 +2546,7 @@ def post_training_prune(model, config, calibration_data): model = add_compression_layers(model, config, inputs.shape) post_pretrain_functions(model, config) model(inputs, training=True) # True so pruning works - return apply_final_compression(model, config) + return apply_final_compression(model) def get_ebops(model, **kwargs): diff --git a/src/pquant/core/keras/quantizer.py b/src/pquant/core/keras/quantizer.py index d4398a0..38b96af 100644 --- a/src/pquant/core/keras/quantizer.py +++ b/src/pquant/core/keras/quantizer.py @@ -6,7 +6,7 @@ from pquant.core.quantizer_functions import create_quantizer -@keras.saving.register_keras_serializable(package="PQuant") +@keras.saving.register_keras_serializable(package="PQuantML") class Quantizer(keras.layers.Layer): # HGQ quantizer wrapper def __init__( @@ -35,7 +35,7 @@ def __init__( self.quantizer = create_quantizer( self.k_init, self.i_init, self.f_init, self.overflow, self.round_mode, self.use_hgq, self.is_data, place ) - self.is_pretraining = False + self.is_pretraining = True self.hgq_gamma = hgq_gamma if isinstance(granularity, Enum): self.granularity = granularity.value @@ -63,15 +63,37 @@ def compute_dynamic_bits(self, x): return int_bits, frac_bits def build(self, input_shape): - if self.granularity == "per_tensor": + if self.use_hgq: + shape = tuple(input_shape) if not self.is_data else (1,) + tuple(input_shape[1:]) + self.k = self.add_weight(shape=shape, initializer=keras.initializers.Constant(self.k_init), trainable=False) + self.i = self.add_weight(shape=shape, initializer=keras.initializers.Constant(self.i_init), trainable=False) + self.f = self.add_weight(shape=shape, initializer=keras.initializers.Constant(self.f_init), trainable=False) + self.b = self.add_weight( + shape=shape, + initializer=keras.initializers.Constant(self.k_init + self.i_init + self.f_init), + trainable=False, + ) + if not self.quantizer.built: + self.quantizer.build(shape) + self.set_quantization_bits(self.i_init, self.f_init) + elif self.granularity == "per_tensor": self.k = self.add_weight(shape=(), initializer=keras.initializers.Constant(self.k_init), trainable=False) self.i = self.add_weight(shape=(), initializer=keras.initializers.Constant(self.i_init), trainable=False) self.f = self.add_weight(shape=(), initializer=keras.initializers.Constant(self.f_init), trainable=False) + self.b = self.add_weight( + shape=(), initializer=keras.initializers.Constant(self.k_init + self.f_init + self.f_init), trainable=False + ) else: i, _ = self.compute_dynamic_bits(keras.ops.ones(input_shape)) self.k = self.add_weight(shape=i.shape, initializer=keras.initializers.Constant(self.k_init), trainable=False) self.i = self.add_weight(shape=i.shape, initializer=keras.initializers.Constant(self.i_init), trainable=False) self.f = self.add_weight(shape=i.shape, initializer=keras.initializers.Constant(self.f_init), trainable=False) + self.b = self.add_weight( + shape=i.shape, + initializer=keras.initializers.Constant(self.k_init + self.f_init + self.f_init), + trainable=False, + ) + super().build(input_shape) def get_total_bits(self, shape): @@ -84,8 +106,7 @@ def get_total_bits(self, shape): def get_quantization_bits(self): if self.use_hgq: return self.quantizer.quantizer.k, self.quantizer.quantizer.i, self.quantizer.quantizer.f - else: - return self.k, self.i, self.f + return self.k, self.i, self.f def set_quantization_bits(self, i, f): if self.use_hgq: @@ -94,8 +115,17 @@ def set_quantization_bits(self, i, f): self.i = i self.f = f - def post_pretrain(self): - self.is_pretraining = True + def apply_final_compression(self): + if self.use_hgq and not self.quantizer.built or not self.built: + return + k, i, f = self.get_quantization_bits() + self.i.assign(i) + self.f.assign(f) + self.b.assign(k + i + f) + self.final_compression_done = True + + def post_pre_train_function(self): + self.is_pretraining = False def call(self, x, training=None): if self.use_hgq: @@ -113,10 +143,7 @@ def call(self, x, training=None): def hgq_loss(self): if self.is_pretraining or not self.use_hgq: return 0.0 - loss = 0 - for layer_loss in self.quantizer.quantizer.losses: - loss += layer_loss - return loss + return sum(self.quantizer.losses) @classmethod def from_config(cls, config): diff --git a/src/pquant/core/keras/train.py b/src/pquant/core/keras/train.py index 59de0dc..2aabb3a 100644 --- a/src/pquant/core/keras/train.py +++ b/src/pquant/core/keras/train.py @@ -1,7 +1,10 @@ import keras from pquant.core.keras.layers import ( + apply_final_compression, call_post_round_functions, + get_ebops, + get_layer_keep_ratio, post_epoch_functions, post_pretrain_functions, pre_epoch_functions, @@ -10,6 +13,111 @@ ) +class PQuantCallback(keras.callbacks.Callback): + """ + Keras callback equivalent of train_model(). + + Call model.fit(epochs=callback.total_epochs, callbacks=[callback], ...). + Phase boundaries: + [0, pretraining_epochs) → pretraining + [pretraining_epochs, pretraining_epochs + rounds*epochs) → main rounds + [pretraining_epochs + rounds*epochs, total_epochs) → fine-tuning + """ + + def __init__( + self, + config, + log_ebops=True, + log_keep_ratio=True, + apply_final_compression=True, + pretraining_epochs=None, + epochs=None, + fine_tuning_epochs=None, + ): + super().__init__() + tc = config.training_parameters + self.config = config + self.pretraining_epochs = pretraining_epochs if pretraining_epochs is not None else tc.pretraining_epochs + self.rounds = tc.rounds + self.epochs_per_round = epochs if epochs is not None else tc.epochs + self.fine_tuning_epochs = fine_tuning_epochs if fine_tuning_epochs is not None else tc.fine_tuning_epochs + self.rewind = tc.rewind + self.save_weights_epoch = tc.save_weights_epoch + self.log_ebops = log_ebops + self.log_keep_ratio = log_keep_ratio + self.apply_final_compression = apply_final_compression + + self._main_end = self.pretraining_epochs + self.rounds * self.epochs_per_round + self._stage = "pretrain" if self.pretraining_epochs > 0 else "train" + + @property + def total_epochs(self): + return self._main_end + self.fine_tuning_epochs + + def on_train_begin(self, logs=None): + # post_pretrain_functions is always called; if there are no pretraining + # epochs the transition happens immediately. + if self.pretraining_epochs == 0: + post_pretrain_functions(self.model, self.config) + # pre_finetune_functions is also always called; if there are no + # main or fine-tuning epochs the transition also happens now. + if self.epochs_per_round == 0 and self.fine_tuning_epochs == 0: + pre_finetune_functions(self.model) + + def on_epoch_begin(self, epoch, logs=None): + if epoch < self.pretraining_epochs: + pre_epoch_functions(self.model, epoch, self.pretraining_epochs) + elif epoch < self._main_end: + rel = epoch - self.pretraining_epochs + r, e = divmod(rel, self.epochs_per_round) + if r == 0 and e == self.save_weights_epoch: + save_weights_functions(self.model) + pre_epoch_functions(self.model, e, self.epochs_per_round) + else: + e = epoch - self._main_end + if e == 0: + pre_finetune_functions(self.model) + pre_epoch_functions(self.model, e, self.fine_tuning_epochs) + + @property + def stage(self): + """Current training stage: 'pretrain', 'train', or 'finetune'.""" + return self._stage + + def on_epoch_end(self, epoch, logs=None): + if epoch < self.pretraining_epochs: + self._stage = "pretrain" + e = epoch + post_epoch_functions(self.model, e, self.pretraining_epochs) + if e == self.pretraining_epochs - 1: + post_pretrain_functions(self.model, self.config) + if self.epochs_per_round == 0 and self.fine_tuning_epochs == 0: + pre_finetune_functions(self.model) + elif epoch < self._main_end: + self._stage = "train" + rel = epoch - self.pretraining_epochs + r, e = divmod(rel, self.epochs_per_round) + post_epoch_functions(self.model, e, self.epochs_per_round) + if e == self.epochs_per_round - 1: + call_post_round_functions(self.model, self.rewind, self.rounds, r) + if r == self.rounds - 1 and self.fine_tuning_epochs == 0: + pre_finetune_functions(self.model) + else: + self._stage = "finetune" + e = epoch - self._main_end + post_epoch_functions(self.model, e, self.fine_tuning_epochs) + if logs is not None: + logs["stage"] = self._stage + if self.log_ebops: + logs["ebops"] = get_ebops(self.model) + if self.log_keep_ratio: + logs["remaining_weights"] = get_layer_keep_ratio(self.model) + + def on_train_end(self, logs=None): # noqa: ARG002 + if self.apply_final_compression: + apply_final_compression(self.model) + + def train_model(model, config, train_func, valid_func, **kwargs): """ Generic training loop, user provides training and validation functions diff --git a/src/pquant/core/torch/layers.py b/src/pquant/core/torch/layers.py index 114a912..87ced7c 100644 --- a/src/pquant/core/torch/layers.py +++ b/src/pquant/core/torch/layers.py @@ -1127,19 +1127,19 @@ def add_layer_specific_quantization_to_model(name, layer, config): if name in config.quantization_parameters.layer_specific: layer_config = config.quantization_parameters.layer_specific[name] if "weight" in layer_config: - weight_k_bits = layer_config["weight"]["keep_negatives"] - weight_int_bits = layer_config["weight"]["integer_bits"] - weight_fractional_bits = layer_config["weight"]["fractional_bits"] - layer.k_weight = torch.tensor(weight_k_bits) - layer.i_weight = torch.tensor(weight_int_bits) - layer.f_weight = torch.tensor(weight_fractional_bits) + if "keep_negatives" in layer_config["weight"]: + layer.k_weight = torch.tensor(layer_config["weight"]["keep_negatives"]) + if "integer_bits" in layer_config["weight"]: + layer.i_weight = torch.tensor(layer_config["weight"]["integer_bits"]) + if "fractional_bits" in layer_config["weight"]: + layer.f_weight = torch.tensor(layer_config["weight"]["fractional_bits"]) if "bias" in layer_config: - bias_k_bits = layer_config["bias"]["keep_negatives"] - bias_int_bits = layer_config["bias"]["integer_bits"] - bias_fractional_bits = layer_config["bias"]["fractional_bits"] - layer.k_bias = torch.tensor(bias_k_bits) - layer.i_bias = torch.tensor(bias_int_bits) - layer.f_bias = torch.tensor(bias_fractional_bits) + if "keep_negatives" in layer_config["bias"]: + layer.k_bias = torch.tensor(layer_config["bias"]["keep_negatives"]) + if "integer_bits" in layer_config["bias"]: + layer.i_bias = torch.tensor(layer_config["bias"]["integer_bits"]) + if "fractional_bits" in layer_config["bias"]: + layer.f_bias = torch.tensor(layer_config["bias"]["fractional_bits"]) if "input" in layer_config: if "keep_negatives" in layer_config["input"]: input_keep_negatives = torch.tensor(layer_config["input"]["keep_negatives"]) @@ -1576,7 +1576,7 @@ def get_layer_keep_ratio(model): def is_training_stage(layer): - return False if layer.pruning_layer.is_finetuning or layer.pruning_layer.is_pretraining else True + return False if layer.pruning_layer._is_finetuning or layer.pruning_layer._is_pretraining else True def get_model_losses(model, losses): diff --git a/src/pquant/core/torch/quantizer.py b/src/pquant/core/torch/quantizer.py index 078de57..7fd57f8 100644 --- a/src/pquant/core/torch/quantizer.py +++ b/src/pquant/core/torch/quantizer.py @@ -29,8 +29,11 @@ def __init__( self.is_data = is_data self.i_init = i self.f_init = f + self.i = torch.nn.Parameter(torch.tensor(i), requires_grad=False) + self.f = torch.nn.Parameter(torch.tensor(f), requires_grad=False) + self.b = torch.nn.Parameter(torch.tensor(i + k + f), requires_grad=False) self.quantizer = create_quantizer(self.k, i, f, self.overflow, self.round_mode, self.use_hgq, self.is_data, place) - self.is_pretraining = False + self.is_pretraining = True self.hgq_gamma = hgq_gamma if isinstance(granularity, Enum): self.granularity = granularity.value @@ -90,7 +93,9 @@ def forward(self, x): return x else: if self.granularity == 'per_tensor': + self.initialize_quantization_parameters(self.i_init, self.f_init) _, i, f = self.get_quantization_bits() + return self.quantizer(x, k=self.k, i=i, f=f, training=self.training) else: i, f = self.compute_dynamic_bits(x) self.initialize_quantization_parameters(i, f) diff --git a/src/pquant/data_models/pruning_model.py b/src/pquant/data_models/pruning_model.py index 8b44f49..21e6607 100644 --- a/src/pquant/data_models/pruning_model.py +++ b/src/pquant/data_models/pruning_model.py @@ -88,3 +88,4 @@ class MDMMPruningModel(BasePruningModel): use_grad: bool = Field(default=False) l0_mode: Literal["coarse", "smooth"] = Field(default="coarse") scale_mode: Literal["mean", "sum"] = Field(default="mean") + constraint_lr: float = Field(default=1.0e-3) diff --git a/src/pquant/data_models/training_model.py b/src/pquant/data_models/training_model.py index f03319e..f841d70 100644 --- a/src/pquant/data_models/training_model.py +++ b/src/pquant/data_models/training_model.py @@ -11,4 +11,4 @@ class BaseTrainingModel(BaseModel): rewind: str = Field(default="never") rounds: int = Field(default=1) save_weights_epoch: int = Field(default=-1) - pruning_first: bool = Field(default=False) \ No newline at end of file + pruning_first: bool = Field(default=False) diff --git a/src/pquant/pruning_methods/activation_pruning.py b/src/pquant/pruning_methods/activation_pruning.py index 0940f88..1840129 100644 --- a/src/pquant/pruning_methods/activation_pruning.py +++ b/src/pquant/pruning_methods/activation_pruning.py @@ -12,86 +12,120 @@ def __init__(self, config, layer_type, *args, **kwargs): super().__init__(*args, **kwargs) self.config = config self.act_type = "relu" - self.t = 0 - self.batches_collected = 0 self.layer_type = layer_type - self.activations = None - self.total = 0.0 - self.is_pretraining = True - self.is_finetuning = False + self._is_pretraining = True + self._is_finetuning = False self.threshold = ops.convert_to_tensor(config.pruning_parameters.threshold) self.t_start_collecting_batch = self.config.pruning_parameters.t_start_collecting_batch def build(self, input_shape): self.shape = (input_shape[0], 1) - if self.layer_type == "conv": + if self.layer_type in ("conv", "depthwise_conv"): if len(input_shape) == 3: self.shape = (input_shape[0], 1, 1) else: self.shape = (input_shape[0], 1, 1, 1) + n_channels = input_shape[0] self.mask = self.add_weight(shape=self.shape, initializer="ones", trainable=False) - self.mask_placeholder = ops.ones(self.shape) + self.mask_placeholder = self.add_weight(shape=self.shape, initializer="ones", trainable=False) + self.activations = self.add_weight(shape=(n_channels,), initializer="zeros", trainable=False) + self.batches_collected = self.add_weight(shape=(), initializer="zeros", trainable=False, dtype="int32") + self.t = self.add_weight(shape=(), initializer="zeros", trainable=False, dtype="int32") + self.is_pretraining = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_pretraining else ops.zeros(shape), dtype), + name="is_pretraining", + trainable=False, + dtype="bool", + ) + self.is_finetuning = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_finetuning else ops.zeros(shape), dtype), + name="is_finetuning", + trainable=False, + dtype="bool", + ) + super().build(input_shape) def collect_output(self, output, training): """ - Accumulates values for how often the outputs of the neurons and channels of - linear/convolution layer are over 0. Every t_delta steps, uses these values to update - the mask to prune those channels and neurons that are active less than a given threshold + Accumulates per-channel activity fractions. Every t_delta batches, updates + mask_placeholder. The actual mask used in call() is updated from mask_placeholder + in post_epoch_function (outside the compiled graph, no step-to-step dependency). """ - if not training or self.is_pretraining or self.is_finetuning: - # Don't collect during validation + if not training: return - if self.activations is None: - # Initialize activations dynamically - self.activations = ops.zeros(shape=output.shape[1:], dtype=output.dtype) - if self.t < self.t_start_collecting_batch: - return - self.batches_collected += 1 - self.total += output.shape[0] - gt_zero = ops.cast((output > 0), output.dtype) - gt_zero = ops.sum(gt_zero, axis=0) # Sum over batch, take average during mask update - self.activations += gt_zero - if self.batches_collected % self.config.pruning_parameters.t_delta == 0: - pct_active = self.activations / self.total - self.t = 0 - self.total = 0 - self.batches_collected = 0 - if self.layer_type == "linear": - self.mask_placeholder = ops.expand_dims(ops.cast((pct_active > self.threshold), pct_active.dtype), 1) - else: - pct_active = ops.reshape(pct_active, (pct_active.shape[0], -1)) - pct_active_avg = ops.mean(pct_active, axis=-1) - pct_active_above_threshold = ops.cast((pct_active_avg > self.threshold), pct_active_avg.dtype) - if len(output.shape) == 3: - self.mask_placeholder = ops.reshape( - pct_active_above_threshold, list(pct_active_above_threshold.shape) + [1, 1] - ) - else: - self.mask_placeholder = ops.reshape( - pct_active_above_threshold, list(pct_active_above_threshold.shape) + [1, 1, 1] - ) - self.activations *= 0.0 - - def call(self, weight): # Mask is only updated every t_delta step, using collect_output - if self.is_pretraining: - return weight + + t_delta = self.config.pruning_parameters.t_delta + # Collect only when training and if above the starting point of collecting + should_collect = ops.logical_not( + ops.logical_or( + ops.logical_or(self.is_pretraining, self.is_finetuning), + self.t < self.t_start_collecting_batch, + ) + ) + + # Per-channel mean activity fraction + gt_zero = ops.cast(output > 0, output.dtype) + if self.layer_type == "linear": + per_channel = ops.mean(gt_zero, axis=0) else: - return self.mask * weight + # output is channels-first (batch, channels, ...); average over batch + spatial + axes = (0,) + tuple(range(2, len(output.shape))) + per_channel = ops.mean(gt_zero, axis=axes) + + # Snapshot current state + activations_cur = ops.convert_to_tensor(self.activations) + batches_cur = ops.convert_to_tensor(self.batches_collected) + mask_ph_cur = ops.convert_to_tensor(self.mask_placeholder) + + # Accumulate (gated by should_collect) + new_activations = activations_cur + ops.where(should_collect, per_channel, ops.zeros_like(per_channel)) + new_batches = batches_cur + ops.cast(should_collect, "int32") + + # Update mask_placeholder every t_delta batches + should_update = ops.logical_and( + should_collect, + ops.equal(new_batches % t_delta, 0), + ) + + safe_batches = ops.cast(ops.maximum(new_batches, 1), new_activations.dtype) + pct_active = new_activations / safe_batches + new_mask_ph = self._compute_mask(pct_active) - def get_hard_mask(self, weight=None): - return self.mask + self.mask_placeholder.assign(ops.where(should_update, new_mask_ph, mask_ph_cur)) + + # Reset accumulators after mask update, else keep accumulated values + self.activations.assign(ops.where(should_update, ops.zeros_like(new_activations), new_activations)) + self.batches_collected.assign(ops.where(should_update, ops.zeros_like(new_batches), new_batches)) + self.t.assign(ops.where(should_update, ops.zeros_like(self.t), self.t)) + + def _compute_mask(self, pct_active): + binary = ops.cast(pct_active > self.threshold, pct_active.dtype) + return ops.reshape(binary, self.shape) + + def call(self, weight): + stored_mask = ops.convert_to_tensor(self.mask) + return ops.where(self.is_pretraining, weight, stored_mask * weight) + + def get_hard_mask(self, weight=None): # noqa: ARG002 + return ops.convert_to_tensor(self.mask) def post_pre_train_function(self): - self.is_pretraining = False + self._is_pretraining = False + if hasattr(self, "is_pretraining"): + self.is_pretraining.assign(False) - def pre_epoch_function(self, epoch, total_epochs): + def pre_epoch_function(self, epoch, total_epochs, **kwargs): # noqa: ARG002 pass def post_round_function(self): pass def pre_finetune_function(self): - self.is_finetuning = True + self._is_finetuning = True + if hasattr(self, "is_finetuning"): + self.is_finetuning.assign(True) def calculate_additional_loss(self): return 0 @@ -99,14 +133,12 @@ def calculate_additional_loss(self): def get_layer_sparsity(self, weight): pass - def post_epoch_function(self, epoch, total_epochs): - if self.is_pretraining is False: - self.t += 1 - self.mask.assign(self.mask_placeholder) - pass + def post_epoch_function(self, epoch, total_epochs, **kwargs): # noqa: ARG002 + if not self._is_pretraining: + self.t.assign_add(1) + self.mask.assign(ops.convert_to_tensor(self.mask_placeholder)) def get_config(self): config = super().get_config() - config.update({"config": self.config.get_dict(), "layer_type": self.layer_type}) return config diff --git a/src/pquant/pruning_methods/autosparse.py b/src/pquant/pruning_methods/autosparse.py index 552b2c0..07c1252 100644 --- a/src/pquant/pruning_methods/autosparse.py +++ b/src/pquant/pruning_methods/autosparse.py @@ -63,10 +63,11 @@ def __init__(self, config, layer_type, *args, **kwargs): self.g = ops.sigmoid self.config = config self.layer_type = layer_type + self._alpha_init = float(config.pruning_parameters.alpha) global BACKWARD_SPARSITY BACKWARD_SPARSITY = config.pruning_parameters.backward_sparsity - self.is_pretraining = True - self.is_finetuning = False + self._is_pretraining = True + self._is_finetuning = False def build(self, input_shape): self.threshold_size = get_threshold_size(self.config, input_shape) @@ -76,30 +77,57 @@ def build(self, input_shape): initializer=Constant(self.config.pruning_parameters.threshold_init), trainable=True, ) - self.alpha = ops.convert_to_tensor(self.config.pruning_parameters.alpha, dtype="float32") + self.mask = self.add_weight( + name="mask", + shape=input_shape, + initializer="ones", + trainable=False, + ) + self.alpha = self.add_weight( + name="alpha", + shape=(), + initializer=Constant(self._alpha_init), + trainable=False, + ) + self.is_pretraining = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_pretraining else ops.zeros(shape), dtype), + name="is_pretraining", + trainable=False, + dtype="bool", + ) + self.is_finetuning = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_finetuning else ops.zeros(shape), dtype), + name="is_finetuning", + trainable=False, + dtype="bool", + ) super().build(input_shape) def call(self, weight): - """ - sign(W) * ReLu(X), where X = |W| - sigmoid(threshold), with gradient: - 1 if W > 0 else alpha. Alpha is decayed after each epoch. - """ - if self.is_pretraining: - return weight - if self.is_finetuning: - return self.mask * weight - else: - mask = self.get_mask(weight) - self.mask = ops.reshape(mask, weight.shape) - return ops.sign(weight) * ops.reshape(mask, weight.shape) - - def get_hard_mask(self, weight=None): - return self.mask + weight_reshaped = ops.reshape(weight, (weight.shape[0], -1)) + w_t = ops.abs(weight_reshaped) - self.g(self.threshold) + + new_binary_mask = ops.cast(ops.reshape(w_t > 0, weight.shape), weight.dtype) + is_training = ops.logical_not(ops.logical_or(self.is_pretraining, self.is_finetuning)) + self.mask.assign(ops.where(is_training, new_binary_mask, ops.convert_to_tensor(self.mask))) + + sparse_weight = ops.sign(weight) * ops.reshape(autosparse_prune(w_t, self.alpha), weight.shape) + + return ops.where( + self.is_pretraining, + weight, + ops.where(self.is_finetuning, ops.convert_to_tensor(self.mask) * weight, sparse_weight), + ) + + def get_hard_mask(self, weight=None): # noqa: ARG002 + return ops.convert_to_tensor(self.mask) def get_mask(self, weight): weight_reshaped = ops.reshape(weight, (weight.shape[0], -1)) w_t = ops.abs(weight_reshaped) - self.g(self.threshold) - return autosparse_prune(w_t, self.alpha) + return ops.cast(ops.reshape(w_t > 0, weight.shape), weight.dtype) def get_layer_sparsity(self, weight): masked_weight = self.get_mask(weight) @@ -109,26 +137,30 @@ def get_layer_sparsity(self, weight): def pre_epoch_function(self, epoch, total_epochs): pass - def calculate_additional_loss(*args, **kwargs): + def calculate_additional_loss(self): return 0 def pre_finetune_function(self): - self.is_finetuning = True + self._is_finetuning = True + if hasattr(self, "is_finetuning"): + self.is_finetuning.assign(True) def post_round_function(self): pass def post_pre_train_function(self): - self.is_pretraining = False + self._is_pretraining = False + if hasattr(self, "is_pretraining"): + self.is_pretraining.assign(False) def post_epoch_function(self, epoch, total_epochs): - self.alpha *= cosine_sigmoid_decay(epoch, total_epochs) - if epoch == self.config.pruning_parameters.alpha_reset_epoch: - self.alpha *= 0.0 + decay = cosine_sigmoid_decay(epoch, total_epochs) + self.alpha.assign(self._alpha_init * decay) + if epoch >= self.config.pruning_parameters.alpha_reset_epoch: + self.alpha.assign(ops.zeros_like(self.alpha)) def get_config(self): config = super().get_config() - config.update( { "config": self.config.get_dict(), diff --git a/src/pquant/pruning_methods/constraint_functions.py b/src/pquant/pruning_methods/constraint_functions.py index 0753996..431cdd2 100644 --- a/src/pquant/pruning_methods/constraint_functions.py +++ b/src/pquant/pruning_methods/constraint_functions.py @@ -50,7 +50,7 @@ def __init__(self, lmbda_init=1.0, scale=1.0, damping=1.0, **kwargs): trainable=False, ) - def call(self, weight): + def call(self, weight, training=None): """Calculates the penalty from a given infeasibility measure.""" raw_infeasibility = self.get_infeasibility(weight) infeasibility = self.pipe_infeasibility(raw_infeasibility) @@ -61,8 +61,9 @@ def call(self, weight): else: lmbda_step = self.lr_ * self.scale * self.prev_infs ascent_lmbda = self.lmbda + lmbda_step - self.lmbda.assign_add(lmbda_step) - self.prev_infs.assign(infeasibility) + if training: + self.lmbda.assign_add(lmbda_step) + self.prev_infs.assign(infeasibility) l_term = ascent_lmbda * infeasibility damp_term = self.damping * ops.square(infeasibility) / 2 diff --git a/src/pquant/pruning_methods/cs.py b/src/pquant/pruning_methods/cs.py index 6d5c69a..97e64ff 100644 --- a/src/pquant/pruning_methods/cs.py +++ b/src/pquant/pruning_methods/cs.py @@ -13,9 +13,9 @@ def __init__(self, config, layer_type, *args, **kwargs): config = PQConfig.load_from_config(config) self.config = config self.final_temp = config.pruning_parameters.final_temp - self.is_finetuning = False + self._is_finetuning = False self.layer_type = layer_type - self.is_pretraining = True + self._is_pretraining = True def build(self, input_shape): self.s_init = ops.convert_to_tensor(self.config.pruning_parameters.threshold_init * ops.ones(input_shape)) @@ -23,37 +23,55 @@ def build(self, input_shape): self.scaling = 1.0 / ops.sigmoid(self.s_init) self.beta = self.add_weight(name="beta", shape=(), initializer=Constant(1.0), trainable=False) self.mask = self.add_weight(name="mask", shape=input_shape, initializer=Constant(1.0), trainable=False) + self.is_pretraining = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_pretraining else ops.zeros(shape), dtype), + name="is_pretraining", + trainable=False, + dtype="bool", + ) + self.is_finetuning = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_finetuning else ops.zeros(shape), dtype), + name="is_finetuning", + trainable=False, + dtype="bool", + ) super().build(input_shape) def call(self, weight): - if self.is_pretraining: - return weight - mask = self.get_mask() - self.mask.assign(mask) - return mask * weight + stored_mask = ops.convert_to_tensor(self.mask) + new_mask = self.get_mask() + use_current_mask = ops.logical_or(self.is_pretraining, self.is_finetuning) + updated_mask = ops.where(use_current_mask, stored_mask, new_mask) + self.mask.assign(updated_mask) + return updated_mask * weight def pre_finetune_function(self): - self.is_finetuning = True + self._is_finetuning = True + if hasattr(self, "is_finetuning"): + self.is_finetuning.assign(True) + if hasattr(self, "mask"): + self.mask.assign(self.get_hard_mask()) def get_mask(self): - if self.is_finetuning: - mask = self.get_hard_mask() - return mask - else: - mask = ops.sigmoid(self.beta * self.s) - mask = mask * self.scaling - return mask + return ops.sigmoid(self.beta * self.s) * self.scaling def post_pre_train_function(self): - self.is_pretraining = False + self._is_pretraining = False + if hasattr(self, "is_pretraining"): + self.is_pretraining.assign(False) - def pre_epoch_function(self, epoch, total_epochs): + def pre_epoch_function(self, epoch, total_epochs): # noqa: ARG002 pass - def post_epoch_function(self, epoch, total_epochs): - self.beta.assign(self.beta * self.final_temp ** (1 / (total_epochs - 1))) + def post_epoch_function(self, epoch, total_epochs): # noqa: ARG002 + if total_epochs <= 1: + self.beta.assign(self.beta * self.final_temp) + else: + self.beta.assign(self.beta * self.final_temp ** (1 / (total_epochs - 1))) - def get_hard_mask(self, weight=None): + def get_hard_mask(self, weight=None): # noqa: ARG002 if self.config.pruning_parameters.enable_pruning: return ops.cast((self.s > 0), self.s.dtype) return ops.convert_to_tensor(1.0) @@ -73,7 +91,6 @@ def get_layer_sparsity(self, weight): def get_config(self): config = super().get_config() - config.update( { "config": self.config.get_dict(), diff --git a/src/pquant/pruning_methods/dst.py b/src/pquant/pruning_methods/dst.py index e8aac5f..e2d5ca9 100644 --- a/src/pquant/pruning_methods/dst.py +++ b/src/pquant/pruning_methods/dst.py @@ -38,14 +38,28 @@ def __init__(self, config, layer_type, *args, **kwargs): config = PQConfig.load_from_config(config) self.config = config - self.is_pretraining = True self.layer_type = layer_type - self.is_finetuning = False + self._is_pretraining = True + self._is_finetuning = False def build(self, input_shape): self.threshold_size = get_threshold_size(self.config, input_shape) self.threshold = self.add_weight(shape=self.threshold_size, initializer="zeros", trainable=True) self.mask = self.add_weight(shape=input_shape, initializer="ones", trainable=False) + self.is_pretraining = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_pretraining else ops.zeros(shape), dtype), + name="is_pretraining", + trainable=False, + dtype="bool", + ) + self.is_finetuning = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_finetuning else ops.zeros(shape), dtype), + name="is_finetuning", + trainable=False, + dtype="bool", + ) def call(self, weight): """ @@ -54,19 +68,27 @@ def call(self, weight): 0.4 if 0.4 < |W| <= 1 0 if |W| > 1 """ - if self.is_pretraining: - return weight - if self.is_finetuning: - return weight * self.mask - mask = self.get_mask(weight) - ratio = 1.0 - ops.sum(mask) / ops.cast(ops.size(mask), mask.dtype) - flag = ratio >= self.config.pruning_parameters.max_pruning_pct - self.threshold.assign(ops.where(flag, ops.ones(self.threshold.shape), self.threshold)) - mask = self.get_mask(weight) - self.mask.assign(mask) - masked_weight = weight * mask + use_current_mask = ops.logical_or(self.is_pretraining, self.is_finetuning) + + def use_existing(): + return weight * ops.convert_to_tensor(self.mask) + + def compute_new(): + mask = self.get_mask(weight) + ratio = 1.0 - ops.sum(mask) / ops.cast(ops.size(mask), mask.dtype) + flag = ratio >= self.config.pruning_parameters.max_pruning_pct + + def reset_and_recalculate(): + self.threshold.assign(ops.zeros(self.threshold.shape)) + return self.get_mask(weight) + + mask = ops.cond(flag, reset_and_recalculate, lambda: mask) + self.mask.assign(mask) + return weight * mask + + result = ops.cond(use_current_mask, use_existing, compute_new) self.add_loss(self.calculate_additional_loss()) - return masked_weight + return result def get_hard_mask(self, weight=None): return self.mask @@ -86,19 +108,22 @@ def get_layer_sparsity(self, weight): return ops.sum(self.get_mask(weight)) / ops.size(weight) def calculate_additional_loss(self): - if self.is_finetuning: - return 0.0 - loss = self.config.pruning_parameters.alpha * ops.sum(ops.exp(-self.threshold)) - return loss + if self._is_pretraining or self._is_finetuning: + return ops.cast(0.0, self.threshold.dtype) + return self.config.pruning_parameters.alpha * ops.sum(ops.exp(-self.threshold)) def pre_finetune_function(self): - self.is_finetuning = True + self._is_finetuning = True + if hasattr(self, "is_finetuning"): + self.is_finetuning.assign(True) def post_epoch_function(self, epoch, total_epochs): pass def post_pre_train_function(self): - self.is_pretraining = False + self._is_pretraining = False + if hasattr(self, "is_pretraining"): + self.is_pretraining.assign(False) def post_round_function(self): pass diff --git a/src/pquant/pruning_methods/mdmm.py b/src/pquant/pruning_methods/mdmm.py index ae88fce..5837dd0 100644 --- a/src/pquant/pruning_methods/mdmm.py +++ b/src/pquant/pruning_methods/mdmm.py @@ -26,10 +26,8 @@ def __init__(self, config, layer_type, *args, **kwargs): self.config = config self.layer_type = layer_type self.constraint_layer = None - self.penalty_loss = None - self.built = False - self.is_finetuning = False - self.is_pretraining = True + self._is_finetuning = False + self._is_pretraining = True def build(self, input_shape): pruning_parameters = self.config.pruning_parameters @@ -62,7 +60,7 @@ def build(self, input_shape): "scale": self.config.pruning_parameters.scale, "damping": self.config.pruning_parameters.damping, "use_grad": self.config.pruning_parameters.use_grad, - "lr": self.config.training_parameters.lr, + "lr": self.config.pruning_parameters.constraint_lr, } constraint_type_cls = CONSTRAINT_REGISTRY.get(constraint_type) @@ -71,49 +69,57 @@ def build(self, input_shape): else: raise ValueError(f"Unknown constraint_type: {constraint_type}") - self.mask = ops.ones(input_shape) + self.mask = self.add_weight(name="mask", shape=input_shape, initializer="ones", trainable=False) + self.is_pretraining = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_pretraining else ops.zeros(shape), dtype), + name="is_pretraining", + trainable=False, + dtype="bool", + ) + self.is_finetuning = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_finetuning else ops.zeros(shape), dtype), + name="is_finetuning", + trainable=False, + dtype="bool", + ) self.constraint_layer.build(input_shape) super().build(input_shape) - self.built = True def call(self, weight): - if not self.built: - self.build(weight.shape) - - if self.is_finetuning: - self.penalty_loss = 0.0 - weight = weight * self.get_hard_mask(weight) - else: - self.penalty_loss = self.constraint_layer(weight) epsilon = self.config.pruning_parameters.epsilon - self.hard_mask = ops.cast(ops.abs(weight) > epsilon, weight.dtype) - return weight + hard_mask = ops.cast(ops.abs(weight) > epsilon, weight.dtype) + not_active = ops.logical_or(self.is_pretraining, self.is_finetuning) + self.mask.assign(ops.where(not_active, ops.convert_to_tensor(self.mask), hard_mask)) + + penalty = ops.sum(self.constraint_layer(weight)) + self.add_loss(ops.where(not_active, ops.zeros_like(penalty), penalty)) + + return ops.where(self.is_finetuning, weight * hard_mask, weight) def get_hard_mask(self, weight=None): if weight is None: - return self.hard_mask + return ops.convert_to_tensor(self.mask) epsilon = self.config.pruning_parameters.epsilon return ops.cast(ops.abs(weight) > epsilon, weight.dtype) def get_layer_sparsity(self, weight): - return ops.sum(self.get_hard_mask(weight)) / ops.size(weight) # Should this be subtracted from 1.0? + return ops.sum(self.get_hard_mask(weight)) / ops.size(weight) def calculate_additional_loss(self): - if self.penalty_loss is None: - raise ValueError("Penalty loss has not been calculated. Call the layer with weights first.") - else: - penalty_loss = ops.sum(self.penalty_loss) - - return penalty_loss + # Loss is added via self.add_loss() in call() for model.fit. + # For custom training loops, accumulate model.losses from the last forward pass instead. + return 0.0 def pre_epoch_function(self, epoch, total_epochs): pass def pre_finetune_function(self): - # Freeze the weights - # Set lmbda(s) to zero - self.is_finetuning = True - if hasattr(self.constraint_layer, 'module'): + self._is_finetuning = True + if hasattr(self, "is_finetuning"): + self.is_finetuning.assign(True) + if hasattr(self.constraint_layer, "module"): self.constraint_layer.module.turn_off() else: self.constraint_layer.turn_off() @@ -122,14 +128,15 @@ def post_epoch_function(self, epoch, total_epochs): pass def post_pre_train_function(self): - self.is_pretraining = False + self._is_pretraining = False + if hasattr(self, "is_pretraining"): + self.is_pretraining.assign(False) def post_round_function(self): pass def get_config(self): config = super().get_config() - config.update( { "config": self.config.get_dict(), diff --git a/src/pquant/pruning_methods/pdp.py b/src/pquant/pruning_methods/pdp.py index dfe7b92..b928ee9 100644 --- a/src/pquant/pruning_methods/pdp.py +++ b/src/pquant/pruning_methods/pdp.py @@ -10,163 +10,158 @@ def __init__(self, config, layer_type, *args, **kwargs): from pquant.core.hyperparameter_optimization import PQConfig config = PQConfig.load_from_config(config) - self.init_r = ops.convert_to_tensor(config.pruning_parameters.sparsity) - self.epsilon = ops.convert_to_tensor(config.pruning_parameters.epsilon) - self.r = config.pruning_parameters.sparsity + self._init_r = float(config.pruning_parameters.sparsity) + self._epsilon = float(config.pruning_parameters.epsilon) self.temp = config.pruning_parameters.temperature - self.is_pretraining = True self.config = config - self.is_finetuning = False self.layer_type = layer_type + self._is_pretraining = True + self._is_finetuning = False def build(self, input_shape): - input_shape_concatenated = list(input_shape) + [1] - self.softmax_shape = input_shape_concatenated - self.t = ops.ones(input_shape_concatenated) * 0.5 - if self.config.pruning_parameters.structured_pruning: + self.softmax_shape = list(input_shape) + [1] + + structured = self.config.pruning_parameters.structured_pruning + if structured: if self.layer_type == "linear": - shape = (input_shape[0], 1) + mask_shape = (input_shape[0], 1) + elif len(input_shape) == 3: + mask_shape = (input_shape[0], 1, 1) else: - if len(input_shape) == 3: - shape = (input_shape[0], 1, 1) - else: - shape = (input_shape[0], 1, 1, 1) - self.mask = self.add_weight(shape=shape, initializer="ones", name="mask", trainable=False) - self.flat_weight_size = ops.cast(ops.size(self.mask), self.mask.dtype) + mask_shape = (input_shape[0], 1, 1, 1) + else: + mask_shape = tuple(input_shape) + + self.mask = self.add_weight(shape=mask_shape, initializer="ones", name="mask", trainable=False) + import math + + self._mask_numel = math.prod(mask_shape) + self.flat_weight_size = float(self._mask_numel) + + # Dynamic state as Keras variables so they survive tf.function tracing + # and can be updated between epochs without retracing. + self.r = self.add_weight( + shape=(), + initializer=keras.initializers.Constant(self._init_r), + name="r", + trainable=False, + ) + self.is_pretraining = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_pretraining else ops.zeros(shape), dtype), + name="is_pretraining", + trainable=False, + dtype="bool", + ) + self.is_finetuning = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_finetuning else ops.zeros(shape), dtype), + name="is_finetuning", + trainable=False, + dtype="bool", + ) + + # Resolve static config branches at build time — no runtime branching needed. + if structured: + self._compute_mask = self._mask_structured_channel if self.layer_type == "conv" else self._mask_structured_linear + else: + self._compute_mask = self._mask_unstructured + super().build(input_shape) + # --- Lifecycle (called outside the training graph) --- + def post_pre_train_function(self): - self.is_pretraining = False # Enables pruning + self._is_pretraining = False + if hasattr(self, "is_pretraining"): + self.is_pretraining.assign(False) - def pre_epoch_function(self, epoch, total_epochs): - if not self.is_pretraining: - self.r = ops.minimum(1.0, self.epsilon * (epoch + 1)) * self.init_r + def pre_epoch_function(self, epoch, _): + if hasattr(self, "r"): + self.r.assign(ops.minimum(1.0, self._epsilon * (epoch + 1)) * self._init_r) def post_round_function(self): pass - def get_hard_mask(self, weight=None): - if self.is_finetuning: - return self.mask - if weight is None: - return ops.cast((self.mask >= 0.5), self.mask.dtype) - if self.config.pruning_parameters.structured_pruning: - if self.layer_type == "conv": - mask = self.get_mask_structured_channel(weight) - else: - mask = self.get_mask_structured_linear(weight) - else: - mask = self.get_mask(weight) - self.mask.assign(ops.cast((mask >= 0.5), mask.dtype)) - return self.mask - def pre_finetune_function(self): - self.is_finetuning = True + self._is_finetuning = True + if hasattr(self, "is_finetuning"): + self.is_finetuning.assign(True) if hasattr(self, "mask"): - self.mask.assign(ops.cast((self.mask >= 0.5), self.mask.dtype)) - - def get_mask_structured_linear(self, weight): - """ - Structured pruning. Use the l2 norm of the neurons instead of the absolute weight values - to calculate threshold point t. Prunes whole neurons. - """ - if self.is_pretraining: - return self.mask + self.mask.assign(ops.cast(self.mask >= 0.5, self.mask.dtype)) + + def post_epoch_function(self, epoch, total_epochs): + pass + + # --- Mask computation (graph-compatible, no Python conditionals on dynamic state) --- + + def _mask_unstructured(self, weight): + weight_reshaped = ops.reshape(weight, self.softmax_shape) + abs_flat = ops.ravel(ops.abs(weight)) + all_vals, _ = ops.top_k(abs_flat, self._mask_numel) + ind = ops.cast((1 - self.r) * self.flat_weight_size, "int32") - 1 + lim = ops.clip(ind, 0, int(self.flat_weight_size) - 2) + Wh, Wt = all_vals[lim], all_vals[lim + 1] + t = ops.ones_like(weight_reshaped) * (0.5 * (Wh + Wt)) + soft_input = ops.concatenate((t**2, weight_reshaped**2), axis=-1) / self.temp + _, mw = ops.unstack(ops.softmax(soft_input, axis=-1), axis=-1) + return ops.reshape(mw, weight.shape) + + def _mask_structured_linear(self, weight): norm = ops.norm(weight, axis=1, ord=2, keepdims=True) norm_flat = ops.ravel(norm) - """ Do top_k for all neuron norms. Returns sorted array, just use the values on both - sides of the threshold (sparsity * size(norm)) to calculate t directly """ - W_all, _ = ops.top_k(norm_flat, ops.size(norm_flat)) - size = ops.cast(ops.size(W_all), self.mask.dtype) - ind = ops.cast((1 - self.r) * size, "int32") - 1 - lim = ops.clip(ind, 0, ops.cast(size - 2, "int32")) - Wh = W_all[lim] - Wt = W_all[lim + 1] - # norm = ops.expand_dims(norm, -1) + W_all, _ = ops.top_k(norm_flat, self._mask_numel) + ind = ops.cast((1 - self.r) * self.flat_weight_size, "int32") - 1 + lim = ops.clip(ind, 0, self._mask_numel - 2) + Wh, Wt = W_all[lim], W_all[lim + 1] t = ops.ones(norm.shape) * 0.5 * (Wh + Wt) soft_input = ops.concatenate((t**2, norm**2), axis=1) / self.temp - softmax_result = ops.softmax(soft_input, axis=1) - _, mw = ops.unstack(softmax_result, axis=1) - mw = ops.expand_dims(mw, -1) - self.mask.assign(mw) - return mw + _, mw = ops.unstack(ops.softmax(soft_input, axis=1), axis=1) + return ops.expand_dims(mw, -1) - def get_mask_structured_channel(self, weight): - """ - Structured pruning. Use the l2 norm of the channels instead of the absolute weight values - to calculate threshold point t. Prunes whole channels. - """ - if self.is_pretraining: - return self.mask + def _mask_structured_channel(self, weight): weight_reshaped = ops.reshape(weight, (weight.shape[0], -1)) norm = ops.norm(weight_reshaped, axis=1, ord=2) - norm_flat = ops.ravel(norm) - """ Do top_k for all channel norms. Returns sorted array, just use the values on both - sides of the threshold (sparsity * size(norm)) to calculate t directly """ - W_all, _ = ops.top_k(norm_flat, ops.size(norm_flat)) - size = ops.cast(ops.size(W_all), self.mask.dtype) - ind = ops.cast((1 - self.r) * size, "int32") - 1 - lim = ops.clip(ind, 0, ops.cast(size - 2, "int32")) - - Wh = W_all[lim] - Wt = W_all[lim + 1] + W_all, _ = ops.top_k(norm_flat, self._mask_numel) + ind = ops.cast((1 - self.r) * self.flat_weight_size, "int32") - 1 + lim = ops.clip(ind, 0, self._mask_numel - 2) + Wh, Wt = W_all[lim], W_all[lim + 1] norm = ops.expand_dims(norm, -1) t = ops.ones(norm.shape) * 0.5 * (Wh + Wt) soft_input = ops.concatenate((t**2, norm**2), axis=-1) / self.temp - softmax_result = ops.softmax(soft_input, axis=-1) - zw, mw = ops.unstack(softmax_result, axis=-1) - diff = len(weight.shape) - len(mw.shape) - for _ in range(diff): + zw, mw = ops.unstack(ops.softmax(soft_input, axis=-1), axis=-1) + for _ in range(len(weight.shape) - len(mw.shape)): mw = ops.expand_dims(mw, -1) - self.mask.assign(mw) return mw - def get_mask(self, weight): - if self.is_pretraining: - return self.mask - weight_reshaped = ops.reshape(weight, self.softmax_shape) - abs_weight_flat = ops.ravel(ops.abs(weight)) - """ Do top_k for all weights. Returns sorted array, just use the values on both - sides of the threshold (sparsity * size(weight)) to calculate t directly """ - all, _ = ops.top_k(abs_weight_flat, ops.size(abs_weight_flat)) - ind = ops.cast((1 - self.r) * self.flat_weight_size, "int32") - 1 # Index begins from 0 - lim = ops.clip(ind, 0, ops.cast(self.flat_weight_size - 2, "int32")) - Wh = all[lim] - Wt = all[lim + 1] - t = self.t * (Wh + Wt) - soft_input = ops.concatenate((t**2, weight_reshaped**2), axis=-1) / self.temp - softmax_result = ops.softmax(soft_input, axis=-1) - _, mw = ops.unstack(softmax_result, axis=-1) - mask = ops.reshape(mw, weight.shape) - self.mask.assign(mask) - return mask - def call(self, weight): - if self.is_finetuning: - mask = self.mask - else: - if self.config.pruning_parameters.structured_pruning: - if self.layer_type == "conv": - mask = self.get_mask_structured_channel(weight) - else: - mask = self.get_mask_structured_linear(weight) - else: - mask = self.get_mask(weight) + new_mask = self._compute_mask(weight) + use_current_mask = ops.logical_or(self.is_pretraining, self.is_finetuning) + mask = ops.where(use_current_mask, ops.convert_to_tensor(self.mask), new_mask) return mask * weight + def update_mask(self, weight): + """Update stored mask from current weights. Called once per epoch from post_epoch_functions.""" + if not self._is_pretraining and not self._is_finetuning: + self.mask.assign(self._compute_mask(weight)) + + # --- Utilities --- + + def get_hard_mask(self, weight=None): + # if weight is not None and not bool(self.is_finetuning): + # self.mask.assign(self._compute_mask(weight)) + return ops.cast(self.mask >= 0.5, self.mask.dtype) + def calculate_additional_loss(self): return 0 def get_layer_sparsity(self, weight): - masked_weight_rounded = ops.cast((self.mask >= 0.5), self.mask.dtype) - masked_weight = masked_weight_rounded * weight + hard_mask = ops.cast(self.mask >= 0.5, self.mask.dtype) + masked_weight = hard_mask * weight return ops.count_nonzero(masked_weight) / ops.size(masked_weight) - def post_epoch_function(self, epoch, total_epochs): - pass - def get_config(self): config = super().get_config() - config.update({"config": self.config.get_dict(), "layer_type": self.layer_type, "mask": self.mask}) + config.update({"config": self.config.get_dict(), "layer_type": self.layer_type}) return config diff --git a/src/pquant/pruning_methods/wanda.py b/src/pquant/pruning_methods/wanda.py index 55b202e..7057726 100644 --- a/src/pquant/pruning_methods/wanda.py +++ b/src/pquant/pruning_methods/wanda.py @@ -12,120 +12,174 @@ def __init__(self, config, layer_type, *args, **kwargs): config = PQConfig.load_from_config(config) self.config = config self.act_type = "relu" - self.t = 0 self.layer_type = layer_type - self.batches_collected = 0 - self.inputs = None - self.total = 0.0 - self.done = False + self._is_pretraining = True + self._is_finetuning = False self.sparsity = self.config.pruning_parameters.sparsity - self.is_pretraining = True - self.is_finetuning = False self.N = self.config.pruning_parameters.N self.M = self.config.pruning_parameters.M self.t_start_collecting_batch = self.config.pruning_parameters.t_start_collecting_batch def build(self, input_shape): - self.mask = ops.ones(input_shape) + # input_shape is the (transposed) weight shape: (out, in) or (out, in, kH, kW) + # For depthwise_conv, weight shape is (in_ch, depth_mult, kH, kW) so n_in = input_shape[0] + n_in = input_shape[0] if self.layer_type == "depthwise_conv" else input_shape[1] + self.mask = self.add_weight(shape=input_shape, initializer="ones", trainable=False) + # Accumulate per-input-channel sum of squared inputs; shape (n_in,) known at build time. + # Replaces storing full (batch, n_in, ...) inputs whose spatial/batch dims are unknown. + self.inputs_sq_sum = self.add_weight(shape=(n_in,), initializer="zeros", trainable=False) + self.batches_collected = self.add_weight(shape=(), initializer="zeros", trainable=False, dtype="int32") + self.t = self.add_weight(shape=(), initializer="zeros", trainable=False, dtype="int32") + self.done = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.zeros(shape), dtype), + trainable=False, + dtype="bool", + ) + self.is_pretraining = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_pretraining else ops.zeros(shape), dtype), + name="is_pretraining", + trainable=False, + dtype="bool", + ) + self.is_finetuning = self.add_weight( + shape=(), + initializer=lambda shape, dtype: ops.cast(ops.ones(shape) if self._is_finetuning else ops.zeros(shape), dtype), + name="is_finetuning", + trainable=False, + dtype="bool", + ) super().build(input_shape) - def get_mask(self, weight, metric, sparsity): - d0, d1 = metric.shape - keep_idxs = ops.argsort(metric, axis=1)[:, int(d1 * sparsity) :] + ops.arange(d0)[:, None] * d1 - keep_idxs = ops.ravel(keep_idxs) - kept_values = ops.reshape( - ops.scatter(keep_idxs[:, None], ops.take(ops.ravel(weight), keep_idxs), ops.array((ops.size(weight),))), - weight.shape, + def collect_input(self, x, weight, training): + """ + Accumulates per-input-channel sum-of-squares of layer inputs. After t_delta batches, + computes mask = Wanda metric (|W| * L2_norm) once and sets done=True. One-shot pruning. + """ + if not training: + return + + t_delta = self.config.pruning_parameters.t_delta + + # Only collect in training stage if pruning hasn't been done already, and if above starting epoch + should_collect = ops.logical_not( + ops.logical_or( + ops.logical_or(self.is_pretraining, self.is_finetuning), + ops.logical_or(self.done, self.t < self.t_start_collecting_batch), + ) ) - mask = ops.cast(kept_values != 0, weight.dtype) - return mask - def handle_linear(self, x, weight): - norm = ops.norm(x, ord=2, axis=0) + # Per-batch per-channel sum of squared activations (shape: n_in,) + if self.layer_type == "linear": + per_batch_sq = ops.sum(ops.square(x), axis=0) # (n_in,) + else: + # x is channels-first (batch, in_channels, ...); sum over batch + spatial + axes = (0,) + tuple(range(2, len(x.shape))) + per_batch_sq = ops.sum(ops.square(x), axis=axes) # (in_channels,) + + # Snapshot current state + sq_sum_cur = ops.convert_to_tensor(self.inputs_sq_sum) + batches_cur = ops.convert_to_tensor(self.batches_collected) + + new_sq_sum = sq_sum_cur + ops.where(should_collect, per_batch_sq, ops.zeros_like(per_batch_sq)) + new_batches = batches_cur + ops.cast(should_collect, "int32") # Adding 0 if not collecting + + # Prune once when t_delta batches have been collected + should_prune = ops.equal(new_batches, ops.cast(t_delta, "int32")) + + mask_cur = ops.convert_to_tensor(self.mask) + + def do_prune(): + norm = ops.sqrt(new_sq_sum) + return self._compute_prune_mask(norm, weight) + + new_mask = ops.cond(should_prune, do_prune, lambda: mask_cur) + self.mask.assign(new_mask) + self.done.assign(ops.logical_or(self.done, should_prune)) + + # Reset accumulators after pruning + self.inputs_sq_sum.assign(ops.where(should_prune, ops.zeros_like(new_sq_sum), new_sq_sum)) + self.batches_collected.assign(ops.where(should_prune, ops.zeros_like(new_batches), new_batches)) + + def _compute_prune_mask(self, norm, weight): + if self.layer_type == "linear": + return self._handle_linear(norm, weight) + if self.layer_type == "depthwise_conv": + return self._handle_depthwise_conv(norm, weight) + return self._handle_conv(norm, weight) + + def _handle_linear(self, norm, weight): + # norm.shape = (in_features,); weight.shape = (out_features, in_features) metric = ops.abs(weight) * norm if self.N is not None and self.M is not None: - # N:M pruning metric_reshaped = ops.reshape(metric, (-1, self.M)) weight_reshaped = ops.reshape(weight, (-1, self.M)) mask = self.get_mask(weight_reshaped, metric_reshaped, sparsity=self.N / self.M) - self.mask = ops.reshape(mask, weight.shape) - else: - # Unstructured pruning - metric_reshaped = ops.reshape(metric, (1, -1)) - weight_reshaped = ops.reshape(weight, (1, -1)) - mask = self.get_mask(weight_reshaped, metric_reshaped, sparsity=self.sparsity) - self.mask = ops.reshape(mask, weight.shape) - - def handle_conv(self, x, weight): - inputs_avg = ops.mean(ops.reshape(x, (x.shape[0], x.shape[1], -1)), axis=0) - norm = ops.norm(inputs_avg, ord=2, axis=-1) + return ops.reshape(mask, weight.shape) + metric_reshaped = ops.reshape(metric, (1, -1)) + weight_reshaped = ops.reshape(weight, (1, -1)) + mask = self.get_mask(weight_reshaped, metric_reshaped, sparsity=self.sparsity) + return ops.reshape(mask, weight.shape) + + def _handle_conv(self, norm, weight): + # norm.shape = (in_channels,); weight.shape = (out_channels, in_channels, ...) if len(weight.shape) == 3: - norm = ops.reshape(norm, [1] + list(norm.shape) + [1]) + norm_reshaped = ops.reshape(norm, [1] + list(norm.shape) + [1]) else: - norm = ops.reshape(norm, [1] + list(norm.shape) + [1, 1]) - metric = ops.abs(weight) * norm + norm_reshaped = ops.reshape(norm, [1] + list(norm.shape) + [1, 1]) + metric = ops.abs(weight) * norm_reshaped if self.N is not None and self.M is not None: - # N:M pruning metric_reshaped = ops.reshape(metric, (-1, self.M)) weight_reshaped = ops.reshape(weight, (-1, self.M)) mask = self.get_mask(weight_reshaped, metric_reshaped, sparsity=self.N / self.M) - self.mask = ops.reshape(mask, weight.shape) - else: - # Unstructured pruning - metric_reshaped = ops.reshape(metric, (metric.shape[0], -1)) - weight_reshaped = ops.reshape(weight, (weight.shape[0], -1)) - mask = self.get_mask(weight_reshaped, metric_reshaped, sparsity=self.sparsity) - self.mask = ops.reshape(mask, weight.shape) + return ops.reshape(mask, weight.shape) + metric_reshaped = ops.reshape(metric, (metric.shape[0], -1)) + weight_reshaped = ops.reshape(weight, (weight.shape[0], -1)) + mask = self.get_mask(weight_reshaped, metric_reshaped, sparsity=self.sparsity) + return ops.reshape(mask, weight.shape) + + def _handle_depthwise_conv(self, norm, weight): + # norm.shape = (in_channels,); weight.shape = (in_channels, depth_mult, kH, kW) + # Prune per-input-channel: norm[ic] scales all weights for that channel + norm_reshaped = ops.reshape(norm, list(norm.shape) + [1, 1, 1]) + metric = ops.abs(weight) * norm_reshaped + metric_reshaped = ops.reshape(metric, (metric.shape[0], -1)) + weight_reshaped = ops.reshape(weight, (weight.shape[0], -1)) + mask = self.get_mask(weight_reshaped, metric_reshaped, sparsity=self.sparsity) + return ops.reshape(mask, weight.shape) - def collect_input(self, x, weight, training): - if self.done or not training: - return - """ - Accumulates layer inputs starting at step t_start_collecting for t_delta steps, then averages it. - Calculates a metric based on weight absolute values and norm of inputs. - For linear layers, calculate norm over batch dimension. - For conv layers, take average over batch dimension and calculate norm over flattened kernel_size dimension. - If N and M are defined, do N:M pruning. - """ - ok_batch = True - if self.inputs is not None: - batch_size = self.inputs.shape[0] - ok_batch = x.shape[0] == batch_size - if not training or not ok_batch: - # Don't collect during validation - return - if self.t < self.t_start_collecting_batch: - return - self.batches_collected += 1 - self.total += 1 - - self.inputs = x if self.inputs is None else self.inputs + x - if self.batches_collected % (self.config.pruning_parameters.t_delta) == 0: - inputs_avg = self.inputs / self.total - self.prune(inputs_avg, weight) - self.done = True - self.inputs = None + def get_mask(self, weight, metric, sparsity): + d0, d1 = metric.shape + keep_idxs = ops.argsort(metric, axis=1)[:, int(d1 * sparsity) :] + ops.arange(d0)[:, None] * d1 + keep_idxs = ops.ravel(keep_idxs) + kept_values = ops.reshape( + ops.scatter(keep_idxs[:, None], ops.take(ops.ravel(weight), keep_idxs), ops.array((ops.size(weight),))), + weight.shape, + ) + return ops.cast(kept_values != 0, weight.dtype) - def prune(self, x, weight): - if self.layer_type == "linear": - self.handle_linear(x, weight) - else: - self.handle_conv(x, weight) + def call(self, weight): + return ops.convert_to_tensor(self.mask) * weight - def call(self, weight): # Mask is only updated every t_delta step, using collect_output - return self.mask * weight + def get_hard_mask(self, weight=None): # noqa: ARG002 + return ops.convert_to_tensor(self.mask) def post_pre_train_function(self): - self.is_pretraining = False + self._is_pretraining = False + if hasattr(self, "is_pretraining"): + self.is_pretraining.assign(False) - def pre_epoch_function(self, epoch, total_epochs): + def pre_epoch_function(self, epoch, total_epochs, **kwargs): # noqa: ARG002 pass def post_round_function(self): pass def pre_finetune_function(self): - self.is_finetuning = True + self._is_finetuning = True + if hasattr(self, "is_finetuning"): + self.is_finetuning.assign(True) def calculate_additional_loss(self): return 0 @@ -133,17 +187,12 @@ def calculate_additional_loss(self): def get_layer_sparsity(self, weight): pass - def get_hard_mask(self, weight=None): - return self.mask - - def post_epoch_function(self, epoch, total_epochs): - if self.is_pretraining is False: - self.t += 1 - pass + def post_epoch_function(self, epoch, total_epochs, **kwargs): # noqa: ARG002 + if not self._is_pretraining: + self.t.assign_add(1) def get_config(self): config = super().get_config() - config.update( { "config": self.config.get_dict(), diff --git a/tests/test_ap.py b/tests/test_ap.py index 0ee480c..d126898 100644 --- a/tests/test_ap.py +++ b/tests/test_ap.py @@ -41,6 +41,7 @@ def test_linear(config): mask = ops.expand_dims(mask, -1) for _ in range(config["pruning_parameters"]["t_delta"]): ap.collect_output(layer_output, training=True) + ap.post_epoch_function(0, 1) result = ap(weight) result_masked = mask * result # Multiplying by mask should not change the result at all @@ -70,6 +71,7 @@ def test_conv(config): # mask = ops.expand_dims(ops.expand_dims(ops.expand_dims(mask, -1), -1), -1) for _ in range(config["pruning_parameters"]["t_delta"]): ap.collect_output(layer_output, training=True) + ap.post_epoch_function(0, 1) result = ap(weight) result_masked = mask * result # Multiplying by mask should not change the result at all diff --git a/tests/test_keras_compression_layers.py b/tests/test_keras_compression_layers.py index 8a2fabd..2cee1b0 100644 --- a/tests/test_keras_compression_layers.py +++ b/tests/test_keras_compression_layers.py @@ -1,4 +1,3 @@ -from types import SimpleNamespace from unittest.mock import patch import keras @@ -17,7 +16,17 @@ SeparableConv2D, ) +from pquant import ( + ap_config, + autosparse_config, + cs_config, + dst_config, + mdmm_config, + pdp_config, + wanda_config, +) from pquant.activations import PQActivation +from pquant.core.hyperparameter_optimization import PQConfig from pquant.layers import ( PQAvgPool1d, PQAvgPool2d, @@ -34,15 +43,6 @@ pre_finetune_functions, ) - -def _to_obj(x): - if isinstance(x, dict): - return SimpleNamespace(**{k: _to_obj(v) for k, v in x.items()}) - if isinstance(x, list): - return [_to_obj(v) for v in x] - return x - - BATCH_SIZE = 4 OUT_FEATURES = 32 IN_FEATURES = 16 @@ -57,165 +57,173 @@ def run_around_tests(): @pytest.fixture def config_pdp(): - cfg = { - "pruning_parameters": { - "disable_pruning_for_layers": [], - "enable_pruning": True, - "epsilon": 1.0, - "pruning_method": "pdp", - "sparsity": 0.75, - "temperature": 1e-5, - "threshold_decay": 0.0, - "structured_pruning": False, - }, - "quantization_parameters": { - "default_weight_integer_bits": 0.0, - "default_weight_fractional_bits": 7.0, - "default_data_integer_bits": 0.0, - "default_data_fractional_bits": 7.0, - "default_data_keep_negatives": 0.0, - "default_weight_keep_negatives": 1.0, - "quantize_input": True, - "quantize_output": False, - "enable_quantization": False, - "hgq_gamma": 0.0003, - "hgq_beta": 1e-5, - "hgq_heterogeneous": True, - "layer_specific": {}, - "use_high_granularity_quantization": False, - "use_real_tanh": False, - "use_relu_multiplier": True, - "use_symmetric_quantization": False, - "round_mode": "RND", - "overflow_mode_parameters": "SAT", - "overflow_mode_data": "SAT", - }, - "training_parameters": {"pruning_first": False}, - "fitcompress_parameters": {"enable_fitcompress": False}, - } - return _to_obj(cfg) + return PQConfig.load_from_config( + { + "pruning_parameters": { + "disable_pruning_for_layers": [], + "enable_pruning": True, + "epsilon": 1.0, + "pruning_method": "pdp", + "sparsity": 0.75, + "temperature": 1e-5, + "threshold_decay": 0.0, + "structured_pruning": False, + }, + "quantization_parameters": { + "default_weight_integer_bits": 0.0, + "default_weight_fractional_bits": 7.0, + "default_data_integer_bits": 0.0, + "default_data_fractional_bits": 7.0, + "default_data_keep_negatives": 0.0, + "default_weight_keep_negatives": 1.0, + "quantize_input": True, + "quantize_output": False, + "enable_quantization": False, + "hgq_gamma": 0.0003, + "hgq_beta": 1e-5, + "hgq_heterogeneous": True, + "layer_specific": {}, + "use_high_granularity_quantization": False, + "use_real_tanh": False, + "use_relu_multiplier": True, + "use_symmetric_quantization": False, + "round_mode": "RND", + "overflow_mode_parameters": "SAT", + "overflow_mode_data": "SAT", + "granularity": "per_tensor", + }, + "training_parameters": {"pruning_first": False}, + "fitcompress_parameters": {"enable_fitcompress": False}, + } + ) @pytest.fixture def config_ap(): - cfg = { - "pruning_parameters": { - "disable_pruning_for_layers": [], - "enable_pruning": True, - "pruning_method": "activation_pruning", - "threshold": 0.3, - "t_start_collecting_batch": 0, - "threshold_decay": 0.0, - "t_delta": 1, - }, - "quantization_parameters": { - "default_weight_integer_bits": 0.0, - "default_weight_fractional_bits": 7.0, - "default_data_integer_bits": 0.0, - "default_data_fractional_bits": 7.0, - "default_data_keep_negatives": 0.0, - "default_weight_keep_negatives": 1.0, - "quantize_input": True, - "quantize_output": False, - "enable_quantization": False, - "hgq_gamma": 0.0003, - "hgq_beta": 1e-5, - "hgq_heterogeneous": True, - "layer_specific": {}, - "use_high_granularity_quantization": False, - "use_real_tanh": False, - "use_relu_multiplier": True, - "use_symmetric_quantization": False, - "round_mode": "RND", - "overflow_mode_parameters": "SAT", - "overflow_mode_data": "SAT", - }, - "training_parameters": {"pruning_first": False}, - "fitcompress_parameters": {"enable_fitcompress": False}, - } - return _to_obj(cfg) + return PQConfig.load_from_config( + { + "pruning_parameters": { + "disable_pruning_for_layers": [], + "enable_pruning": True, + "pruning_method": "activation_pruning", + "threshold": 0.3, + "t_start_collecting_batch": 0, + "threshold_decay": 0.0, + "t_delta": 1, + }, + "quantization_parameters": { + "default_weight_integer_bits": 0.0, + "default_weight_fractional_bits": 7.0, + "default_data_integer_bits": 0.0, + "default_data_fractional_bits": 7.0, + "default_data_keep_negatives": 0.0, + "default_weight_keep_negatives": 1.0, + "quantize_input": True, + "quantize_output": False, + "enable_quantization": False, + "hgq_gamma": 0.0003, + "hgq_beta": 1e-5, + "hgq_heterogeneous": True, + "layer_specific": {}, + "use_high_granularity_quantization": False, + "use_real_tanh": False, + "use_relu_multiplier": True, + "use_symmetric_quantization": False, + "round_mode": "RND", + "overflow_mode_parameters": "SAT", + "overflow_mode_data": "SAT", + }, + "training_parameters": {"pruning_first": False}, + "fitcompress_parameters": {"enable_fitcompress": False}, + } + ) @pytest.fixture def config_wanda(): - cfg = { - "pruning_parameters": { - "calculate_pruning_budget": False, - "disable_pruning_for_layers": [], - "enable_pruning": True, - "pruning_method": "wanda", - "sparsity": 0.75, - "t_start_collecting_batch": 0, - "threshold_decay": 0.0, - "t_delta": 1, - "N": None, - "M": None, - }, - "quantization_parameters": { - "default_weight_integer_bits": 0.0, - "default_weight_fractional_bits": 7.0, - "default_data_integer_bits": 0.0, - "default_data_fractional_bits": 7.0, - "default_data_keep_negatives": 0.0, - "default_weight_keep_negatives": 1.0, - "quantize_input": True, - "quantize_output": False, - "enable_quantization": False, - "hgq_gamma": 0.0003, - "hgq_beta": 1e-5, - "hgq_heterogeneous": True, - "layer_specific": {}, - "use_high_granularity_quantization": False, - "use_real_tanh": False, - "use_relu_multiplier": True, - "use_symmetric_quantization": False, - "round_mode": "RND", - "overflow_mode_parameters": "SAT", - "overflow_mode_data": "SAT", - }, - "training_parameters": {"pruning_first": False}, - "fitcompress_parameters": {"enable_fitcompress": False}, - } - return _to_obj(cfg) + return PQConfig.load_from_config( + { + "pruning_parameters": { + "calculate_pruning_budget": False, + "disable_pruning_for_layers": [], + "enable_pruning": True, + "pruning_method": "wanda", + "sparsity": 0.75, + "t_start_collecting_batch": 0, + "threshold_decay": 0.0, + "t_delta": 1, + "N": None, + "M": None, + }, + "quantization_parameters": { + "default_weight_integer_bits": 0.0, + "default_weight_fractional_bits": 7.0, + "default_data_integer_bits": 0.0, + "default_data_fractional_bits": 7.0, + "default_data_keep_negatives": 0.0, + "default_weight_keep_negatives": 1.0, + "quantize_input": True, + "quantize_output": False, + "enable_quantization": False, + "hgq_gamma": 0.0003, + "hgq_beta": 1e-5, + "hgq_heterogeneous": True, + "layer_specific": {}, + "use_high_granularity_quantization": False, + "use_real_tanh": False, + "use_relu_multiplier": True, + "use_symmetric_quantization": False, + "round_mode": "RND", + "overflow_mode_parameters": "SAT", + "overflow_mode_data": "SAT", + }, + "training_parameters": {"pruning_first": False}, + "fitcompress_parameters": {"enable_fitcompress": False}, + } + ) @pytest.fixture def config_cs(): - cfg = { - "pruning_parameters": { - "disable_pruning_for_layers": [], - "enable_pruning": True, - "final_temp": 200, - "pruning_method": "cs", - "threshold_decay": 0.0, - "threshold_init": 0.1, - }, - "quantization_parameters": { - "default_weight_integer_bits": 0.0, - "default_weight_fractional_bits": 7.0, - "default_data_integer_bits": 0.0, - "default_data_fractional_bits": 7.0, - "default_data_keep_negatives": 0.0, - "default_weight_keep_negatives": 1.0, - "quantize_input": True, - "quantize_output": False, - "enable_quantization": False, - "hgq_gamma": 0.0003, - "hgq_beta": 1e-5, - "hgq_heterogeneous": True, - "layer_specific": {}, - "use_high_granularity_quantization": False, - "use_real_tanh": False, - "use_relu_multiplier": True, - "use_symmetric_quantization": False, - "round_mode": "RND", - "overflow_mode_parameters": "SAT", - "overflow_mode_data": "SAT", - }, - "training_parameters": {"pruning_first": False}, - "fitcompress_parameters": {"enable_fitcompress": False}, - } - return _to_obj(cfg) + return PQConfig.load_from_config( + { + "pruning_parameters": { + "disable_pruning_for_layers": [], + "enable_pruning": True, + "final_temp": 200, + "pruning_method": "cs", + "threshold_decay": 0.0, + "threshold_init": 0.1, + }, + "quantization_parameters": { + "default_weight_integer_bits": 0.0, + "default_weight_fractional_bits": 7.0, + "default_data_integer_bits": 0.0, + "default_data_fractional_bits": 7.0, + "default_data_keep_negatives": 0.0, + "default_weight_keep_negatives": 1.0, + "quantize_input": True, + "quantize_output": False, + "enable_quantization": False, + "hgq_gamma": 0.0003, + "hgq_beta": 1e-5, + "hgq_heterogeneous": True, + "layer_specific": {}, + "use_high_granularity_quantization": False, + "use_real_tanh": False, + "use_relu_multiplier": True, + "use_symmetric_quantization": False, + "round_mode": "RND", + "overflow_mode_parameters": "SAT", + "overflow_mode_data": "SAT", + }, + "training_parameters": {"pruning_first": False}, + "fitcompress_parameters": {"enable_fitcompress": False}, + } + ) + + +np.random.seed(42) @pytest.fixture(scope="function", autouse=True) @@ -376,19 +384,19 @@ def test_separable_conv2d_trigger_post_pretraining(config_pdp, conv2d_input): model = keras.Model(inputs=inputs, outputs=act2, name="test_conv2d") model = add_compression_layers(model, config_pdp, conv2d_input.shape) - assert model.layers[1].depthwise_conv.pruning_layer.is_pretraining is True - assert model.layers[1].pointwise_conv.pruning_layer.is_pretraining is True - assert model.layers[2].is_pretraining is True - assert model.layers[4].pruning_layer.is_pretraining is True - assert model.layers[5].is_pretraining is True + assert model.layers[1].depthwise_conv.pruning_layer.is_pretraining == True # noqa: E712 + assert model.layers[1].pointwise_conv.pruning_layer.is_pretraining == True # noqa: E712 + assert model.layers[2].is_pretraining == True # noqa: E712 + assert model.layers[4].pruning_layer.is_pretraining == True # noqa: E712 + assert model.layers[5].is_pretraining == True # noqa: E712 post_pretrain_functions(model, config_pdp) - assert model.layers[1].depthwise_conv.pruning_layer.is_pretraining is False - assert model.layers[1].pointwise_conv.pruning_layer.is_pretraining is False - assert model.layers[2].is_pretraining is False - assert model.layers[4].pruning_layer.is_pretraining is False - assert model.layers[5].is_pretraining is False + assert model.layers[1].depthwise_conv.pruning_layer.is_pretraining == False # noqa: E712 + assert model.layers[1].pointwise_conv.pruning_layer.is_pretraining == False # noqa: E712 + assert model.layers[2].is_pretraining == False # noqa: E712 + assert model.layers[4].pruning_layer.is_pretraining == False # noqa: E712 + assert model.layers[5].is_pretraining == False # noqa: E712 def test_conv1d_call(config_pdp, conv1d_input): @@ -1363,17 +1371,17 @@ def test_trigger_post_pretraining(config_pdp, conv2d_input): model = add_compression_layers(model, config_pdp, conv2d_input.shape) - assert model.layers[1].pruning_layer.is_pretraining is True - assert model.layers[2].is_pretraining is True - assert model.layers[3].pruning_layer.is_pretraining is True - assert model.layers[4].is_pretraining is True + assert model.layers[1].pruning_layer.is_pretraining == True # noqa: E712 + assert model.layers[2].is_pretraining == True # noqa: E712 + assert model.layers[3].pruning_layer.is_pretraining == True # noqa: E712 + assert model.layers[4].is_pretraining == True # noqa: E712 post_pretrain_functions(model, config_pdp) - assert model.layers[1].pruning_layer.is_pretraining is False - assert model.layers[2].is_pretraining is False - assert model.layers[3].pruning_layer.is_pretraining is False - assert model.layers[4].is_pretraining is False + assert model.layers[1].pruning_layer.is_pretraining == False # noqa: E712 + assert model.layers[2].is_pretraining == False # noqa: E712 + assert model.layers[3].pruning_layer.is_pretraining == False # noqa: E712 + assert model.layers[4].is_pretraining == False # noqa: E712 def test_hgq_weight_shape(config_pdp, dense_input): @@ -1709,6 +1717,15 @@ def call(self, x, *args, **kwargs): self.layer_called += 1 return x + def post_pre_train_function(self): + pass + + def get_total_bits(self, shape): + return keras.ops.ones(shape) + + def hgq_loss(self): + return 0.0 + def extra_repr(self): return f"Layer called = {self.layer_called} times." @@ -1901,3 +1918,219 @@ def test_layer_replacement_quant_called(config_pdp, conv2d_input): assert model.layers[-1].output_quantizer.layer_called == 1 model(conv2d_input, training=False) assert model.layers[-1].output_quantizer.layer_called == 2 + + +def build_model(config): + inp = keras.layers.Input(shape=((16,))) + x = PQDense( + config, + 64, + in_quant_bits=( + 1.0, + 3.0, + 3.0, + ), + )(inp) + x = keras.layers.ReLU()(x) + x = PQDense(config, 32)(x) + x = keras.layers.ReLU()(x) + x = PQDense(config, 32)(x) + x = keras.layers.ReLU()(x) + out = PQDense(config, 5, out_quant_bits=(1.0, 3.0, 3.0), quantize_output=True)(x) + return keras.Model(inputs=inp, outputs=out) + + +def test_ebops_dense_varied(config_pdp, dense_input): + config = config_pdp + config.pruning_parameters.enable_pruning = False + config.quantization_parameters.use_high_granularity_quantization = True + config.quantization_parameters.hgq_gamma = 1.0 + config.quantization_parameters.enable_quantization = True + config.quantization_parameters.overflow_mode_data = "WRAP" + config.quantization_parameters.overflow_mode_parameters = "SAT_SYM" + model = build_model(config) + model(dense_input, training=True) + post_pretrain_functions(model, config) + + +@pytest.mark.parametrize( + "config_fn", + [pdp_config, ap_config, autosparse_config, cs_config, dst_config, mdmm_config, wanda_config], + ids=["pdp", "ap", "autosparse", "cs", "dst", "mdmm", "wanda"], +) +def test_model_serialization(tmp_path, config_fn): + config = config_fn() + config.quantization_parameters.enable_quantization = True + channels_first = keras.backend.image_data_format() == "channels_first" + if channels_first: + input_shape = (IN_FEATURES, 32, 32) + conv1d_shape = (OUT_FEATURES, 16 * 16) + dummy = np.zeros((1, IN_FEATURES, 32, 32), dtype=np.float32) + bn_axis = 1 + else: + input_shape = (32, 32, IN_FEATURES) + conv1d_shape = (16 * 16, OUT_FEATURES) + dummy = np.zeros((1, 32, 32, IN_FEATURES), dtype=np.float32) + bn_axis = -1 + + inputs = keras.Input(shape=input_shape) + x = PQConv2d(config, OUT_FEATURES, KERNEL_SIZE, padding="same")(inputs) + x = PQBatchNormalization(config, axis=bn_axis)(x) + x = PQActivation(config, activation="relu", quantize_input=True, quantize_output=True)(x) + x = PQAvgPool2d(config, pool_size=2, strides=2, padding="same")(x) + x = keras.layers.Reshape(conv1d_shape)(x) + x = PQConv1d(config, OUT_FEATURES, KERNEL_SIZE, padding="same")(x) + x = PQActivation(config, activation="relu", quantize_input=True, quantize_output=True)(x) + x = PQAvgPool1d(config, pool_size=2, strides=2, padding="same")(x) + x = keras.layers.Flatten()(x) + x = PQDense(config, units=OUT_FEATURES)(x) + x = PQActivation(config, activation="relu", quantize_input=True, quantize_output=True)(x) + model = keras.Model(inputs, x) + + # Call the model once to trigger build of all sublayers (pruning masks, etc.) + model(dummy) + + # Randomize weights; skip boolean flags which must stay at their 0/1 values + rng = np.random.default_rng(42) + for w in model.weights: + if w.name.endswith(("is_pretraining", "is_finetuning")): + continue + w.assign(rng.standard_normal(w.shape).astype(w.dtype)) + + pq_types = (PQConv2d, PQConv1d, PQDense) + + def assert_pq_state_survives_roundtrip(m, label): + """Save m, reload it, and verify all state round-trips correctly.""" + save_path = tmp_path / f"{label}.keras" + m.save(save_path) + reloaded = keras.models.load_model(save_path) + + orig_pq = [layer for layer in m.layers if isinstance(layer, pq_types)] + loaded_pq = [layer for layer in reloaded.layers if isinstance(layer, pq_types)] + assert len(loaded_pq) == len(orig_pq) + + for orig_l, loaded_l in zip(orig_pq, loaded_pq): + for attr in ("final_compression_done", "is_pretraining", "is_finetuning"): + np.testing.assert_equal( + getattr(loaded_l, attr), + getattr(orig_l, attr), + err_msg=f"[{label}] {orig_l.name}.{attr} mismatch", + ) + + assert len(m.weights) == len(reloaded.weights) + for orig_w, loaded_w in zip(m.weights, reloaded.weights): + np.testing.assert_array_equal( + np.array(orig_w), + np.array(loaded_w), + err_msg=f"[{label}] Weight mismatch: {orig_w.name}", + ) + + # Stage 1: initial state (is_pretraining=True, is_finetuning=False) + assert_pq_state_survives_roundtrip(model, "initial") + + # Stage 2: after post_pretrain_functions (is_pretraining=False, is_finetuning=False) + post_pretrain_functions(model, config) + assert_pq_state_survives_roundtrip(model, "post_pretrain") + + # Stage 3: after pre_finetune_functions (is_pretraining=False, is_finetuning=True) + pre_finetune_functions(model) + assert_pq_state_survives_roundtrip(model, "pre_finetune") + + # Stage 4: after apply_final_compression (final_compression_done=True on all layers) + apply_final_compression(model) + assert_pq_state_survives_roundtrip(model, "final_compression") + + +@pytest.mark.parametrize( + "config_fn", + [pdp_config, ap_config, autosparse_config, cs_config, dst_config, mdmm_config, wanda_config], + ids=["pdp", "ap", "autosparse", "cs", "dst", "mdmm", "wanda"], +) +def test_checkpoint_save_load(tmp_path, config_fn): + """Verify that save_weights/load_weights preserves non-trainable pruning state (e.g. mask).""" + config = config_fn() + config.quantization_parameters.enable_quantization = True + channels_first = keras.backend.image_data_format() == "channels_first" + if channels_first: + input_shape = (IN_FEATURES, 32, 32) + dummy = np.zeros((1, IN_FEATURES, 32, 32), dtype=np.float32) + else: + input_shape = (32, 32, IN_FEATURES) + dummy = np.zeros((1, 32, 32, IN_FEATURES), dtype=np.float32) + + inputs = keras.Input(shape=input_shape) + x = PQConv2d(config, OUT_FEATURES, KERNEL_SIZE, padding="same")(inputs) + x = keras.layers.Flatten()(x) + x = PQDense(config, units=OUT_FEATURES)(x) + model = keras.Model(inputs, x) + model(dummy) + + # Randomize all weights including non-trainable pruning state (mask, etc.) + rng = np.random.default_rng(0) + for w in model.weights: + if w.name.endswith(("is_pretraining", "is_finetuning")): + continue + w.assign(rng.standard_normal(w.shape).astype(w.dtype)) + + original_weights = [np.array(w) for w in model.weights] + + path = str(tmp_path / "ckpt.weights.h5") + model.save_weights(path) + + # Overwrite all weights with zeros + for w in model.weights: + w.assign(np.zeros(w.shape, dtype=w.dtype)) + + model.load_weights(path) + + for orig, w in zip(original_weights, model.weights): + np.testing.assert_array_equal(orig, np.array(w), err_msg=f"Checkpoint weight mismatch: {w.name}") + + +@pytest.mark.parametrize( + "config_fn", + [pdp_config, ap_config, autosparse_config, cs_config, dst_config, mdmm_config, wanda_config], + ids=["pdp", "ap", "autosparse", "cs", "dst", "mdmm", "wanda"], +) +def test_model_fit(config_fn): + from pquant.core.keras.train import PQuantCallback + + config = config_fn() + config.quantization_parameters.enable_quantization = True + config.training_parameters.pretraining_epochs = 1 + config.training_parameters.epochs = 1 + config.training_parameters.fine_tuning_epochs = 1 + config.training_parameters.rounds = 1 + config.training_parameters.save_weights_epoch = 0 + + channels_first = keras.backend.image_data_format() == "channels_first" + if channels_first: + input_shape = (IN_FEATURES, 32, 32) + conv1d_shape = (OUT_FEATURES, 16 * 16) + dummy_x = np.zeros((BATCH_SIZE, IN_FEATURES, 32, 32), dtype=np.float32) + bn_axis = 1 + else: + input_shape = (32, 32, IN_FEATURES) + conv1d_shape = (16 * 16, OUT_FEATURES) + dummy_x = np.zeros((BATCH_SIZE, 32, 32, IN_FEATURES), dtype=np.float32) + bn_axis = -1 + + inputs = keras.Input(shape=input_shape) + x = PQConv2d(config, OUT_FEATURES, KERNEL_SIZE, padding="same")(inputs) + x = PQBatchNormalization(config, axis=bn_axis)(x) + x = PQActivation(config, activation="relu", quantize_input=True, quantize_output=True)(x) + x = PQAvgPool2d(config, pool_size=2, strides=2, padding="same")(x) + x = keras.layers.Reshape(conv1d_shape)(x) + x = PQConv1d(config, OUT_FEATURES, KERNEL_SIZE, padding="same")(x) + x = PQActivation(config, activation="relu", quantize_input=True, quantize_output=True)(x) + x = PQAvgPool1d(config, pool_size=2, strides=2, padding="same")(x) + x = keras.layers.Flatten()(x) + x = PQDense(config, units=OUT_FEATURES)(x) + x = PQActivation(config, activation="relu", quantize_input=True, quantize_output=True)(x) + model = keras.Model(inputs, x) + + dummy_y = np.zeros((BATCH_SIZE, model.output_shape[-1]), dtype=np.float32) + + model.compile(optimizer="adam", loss="mse", jit_compile=False) + callback = PQuantCallback(config, log_ebops=False, log_keep_ratio=False) + model.fit(dummy_x, dummy_y, epochs=callback.total_epochs, callbacks=[callback], verbose=0) diff --git a/tests/test_pdp.py b/tests/test_pdp.py index 0cebf16..4a6a081 100644 --- a/tests/test_pdp.py +++ b/tests/test_pdp.py @@ -75,15 +75,19 @@ def test_linear_structured(config): sparsity = config["pruning_parameters"]["sparsity"] config["pruning_parameters"]["structured_pruning"] = True - inp = ops.linspace(-1, 1, num=OUT_FEATURES) + # PDP structured linear prunes rows (dim 0 = out_features). + # Weight shape is (OUT_FEATURES, IN_FEATURES) matching the transposed convention + # used by the Keras layer before passing to the pruning layer. + row_scales = ops.linspace(-1, 1, num=OUT_FEATURES) threshold_point = int(OUT_FEATURES * sparsity) - 1 - threshold_value = sorted(ops.abs(inp))[threshold_point] - inp = shuffle(inp) - mask = ops.cast((ops.abs(inp) > threshold_value), inp.dtype) - inp = ops.tile(inp, (IN_FEATURES, 1)) - mask = ops.tile(mask, (IN_FEATURES, 1)) - threshold_point = int(OUT_FEATURES * sparsity) - # In the matrix of shape NxM, prune over the M dimension + threshold_value = sorted(ops.abs(row_scales))[threshold_point] + row_scales = shuffle(row_scales) + mask_1d = ops.cast((ops.abs(row_scales) > threshold_value), row_scales.dtype) + + # Each row i has uniform value row_scales[i], giving distinct per-row norms. + inp = ops.tile(ops.reshape(row_scales, (OUT_FEATURES, 1)), (1, IN_FEATURES)) + mask = ops.tile(ops.reshape(mask_1d, (OUT_FEATURES, 1)), (1, IN_FEATURES)) + pdp = PDP(config, "linear") pdp.post_pre_train_function() pdp.build(inp.shape) @@ -99,17 +103,21 @@ def test_conv_structured(config): config["pruning_parameters"]["structured_pruning"] = True sparsity = config["pruning_parameters"]["sparsity"] - inp = ops.ones(shape=(IN_FEATURES, OUT_FEATURES, KERNEL_SIZE, KERNEL_SIZE)) - channel_dim = ops.linspace(-1, 1, num=OUT_FEATURES) + # PDP structured conv prunes rows (dim 0 = out_channels). + # Weight shape is (OUT_FEATURES, IN_FEATURES, kH, kW) matching the transposed + # convention used by the Keras layer before passing to the pruning layer. + channel_scales = ops.linspace(-1, 1, num=OUT_FEATURES) threshold_point = int(OUT_FEATURES * sparsity) - 1 - threshold_value = sorted(ops.abs(channel_dim))[threshold_point] - channel_dim = shuffle(channel_dim) - mask = ops.cast((ops.abs(channel_dim) > threshold_value), inp.dtype) - mask = ops.expand_dims(ops.expand_dims(ops.expand_dims(mask, axis=0), axis=-1), axis=-1) - mask = ops.tile(mask, (IN_FEATURES, 1, KERNEL_SIZE, KERNEL_SIZE)) - - mult = ops.expand_dims(ops.expand_dims(ops.expand_dims(channel_dim, axis=0), axis=-1), axis=-1) - inp *= mult + threshold_value = sorted(ops.abs(channel_scales))[threshold_point] + channel_scales = shuffle(channel_scales) + mask_1d = ops.cast((ops.abs(channel_scales) > threshold_value), channel_scales.dtype) + + # Each output channel c has all spatial elements equal to channel_scales[c]. + mult = ops.reshape(channel_scales, (OUT_FEATURES, 1, 1, 1)) + inp = ops.ones(shape=(OUT_FEATURES, IN_FEATURES, KERNEL_SIZE, KERNEL_SIZE)) * mult + mask = ops.ones(shape=(OUT_FEATURES, IN_FEATURES, KERNEL_SIZE, KERNEL_SIZE)) * ops.reshape( + mask_1d, (OUT_FEATURES, 1, 1, 1) + ) pdp = PDP(config, "conv") pdp.post_pre_train_function() diff --git a/tests/test_torch_compression_layers.py b/tests/test_torch_compression_layers.py index ccf90a3..cb4d1ff 100644 --- a/tests/test_torch_compression_layers.py +++ b/tests/test_torch_compression_layers.py @@ -1,5 +1,3 @@ -from types import SimpleNamespace - import keras import numpy as np import pytest @@ -18,6 +16,7 @@ from pquant import post_training_prune from pquant.activations import PQActivation +from pquant.core.hyperparameter_optimization import PQConfig from pquant.layers import ( PQAvgPool1d, PQAvgPool2d, @@ -34,15 +33,6 @@ pre_finetune_functions, ) - -def _to_obj(x): - if isinstance(x, dict): - return SimpleNamespace(**{k: _to_obj(v) for k, v in x.items()}) - if isinstance(x, list): - return [_to_obj(v) for v in x] - return x - - BATCH_SIZE = 4 OUT_FEATURES = 32 IN_FEATURES = 16 @@ -84,11 +74,12 @@ def config_pdp(): "round_mode": "RND", "overflow_mode_parameters": "SAT", "overflow_mode_data": "SAT", + "granularity": "per_tensor", }, "training_parameters": {"pruning_first": False}, "fitcompress_parameters": {"enable_fitcompress": False}, } - return _to_obj(cfg) + return PQConfig.load_from_config(cfg) @pytest.fixture @@ -124,11 +115,12 @@ def config_ap(): "round_mode": "RND", "overflow_mode_parameters": "SAT", "overflow_mode_data": "SAT", + "granularity": "per_tensor", }, "training_parameters": {"pruning_first": False}, "fitcompress_parameters": {"enable_fitcompress": False}, } - return _to_obj(cfg) + return PQConfig.load_from_config(cfg) @pytest.fixture @@ -167,11 +159,12 @@ def config_wanda(): "round_mode": "RND", "overflow_mode_parameters": "SAT", "overflow_mode_data": "SAT", + "granularity": "per_tensor", }, "training_parameters": {"pruning_first": False}, "fitcompress_parameters": {"enable_fitcompress": False}, } - return _to_obj(cfg) + return PQConfig.load_from_config(cfg) @pytest.fixture @@ -206,11 +199,12 @@ def config_cs(): "round_mode": "RND", "overflow_mode_parameters": "SAT", "overflow_mode_data": "SAT", + "granularity": "per_tensor", }, "training_parameters": {"pruning_first": False}, "fitcompress_parameters": {"enable_fitcompress": False}, } - return _to_obj(cfg) + return PQConfig.load_from_config(cfg) @pytest.fixture @@ -586,17 +580,20 @@ def test_trigger_post_pretraining(config_pdp, dense_input): model = add_compression_layers(model, config_pdp, dense_input.shape) - assert model.submodule.pruning_layer.is_pretraining is True - assert model.activation.is_pretraining is True - assert model.submodule2.pruning_layer.is_pretraining is True - assert model.activation2.is_pretraining is True + def _to_bool(val): + return val.numpy() if hasattr(val, "numpy") else bool(val) + + assert _to_bool(model.submodule.pruning_layer.is_pretraining) + assert _to_bool(model.activation.is_pretraining) + assert _to_bool(model.submodule2.pruning_layer.is_pretraining) + assert _to_bool(model.activation2.is_pretraining) post_pretrain_functions(model, config_pdp) - assert model.submodule.pruning_layer.is_pretraining is False - assert model.activation.is_pretraining is False - assert model.submodule2.pruning_layer.is_pretraining is False - assert model.activation2.is_pretraining is False + assert not _to_bool(model.submodule.pruning_layer.is_pretraining) + assert not _to_bool(model.activation.is_pretraining) + assert not _to_bool(model.submodule2.pruning_layer.is_pretraining) + assert not _to_bool(model.activation2.is_pretraining) def test_hgq_weight_shape(config_pdp, dense_input): @@ -1777,15 +1774,19 @@ def test_hgq_loss_calc_no_qoutput(config_pdp, conv2d_input): for m in model.modules(): if isinstance(m, (PQWeightBiasBase)): m.ebops = dummy_ebops - m.input_quantizer.hgq_loss = dummy_hgq_loss + if hasattr(m, "input_quantizer"): + m.input_quantizer.hgq_loss = dummy_hgq_loss m.weight_quantizer.hgq_loss = dummy_hgq_loss m.bias_quantizer.hgq_loss = dummy_hgq_loss - m.output_quantizer.hgq_loss = dummy_hgq_loss # Won't be called + if hasattr(m, "output_quantizer"): + m.output_quantizer.hgq_loss = dummy_hgq_loss # Won't be called expected_loss += 3.0 elif isinstance(m, (PQAvgPool1d, PQAvgPool2d, PQActivation)): m.ebops = dummy_ebops - m.input_quantizer.hgq_loss = dummy_hgq_loss - m.output_quantizer.hgq_loss = dummy_hgq_loss # Won't be called + if hasattr(m, "input_quantizer"): + m.input_quantizer.hgq_loss = dummy_hgq_loss + if hasattr(m, "output_quantizer"): + m.output_quantizer.hgq_loss = dummy_hgq_loss # Won't be called expected_loss += 1.0 elif isinstance(m, (PQBatchNorm2d)): m.ebops = dummy_ebops @@ -1812,15 +1813,20 @@ def test_hgq_loss_calc_no_bias_no_qoutput(config_pdp, conv2d_input): for m in model.modules(): if isinstance(m, (PQWeightBiasBase)): m.ebops = dummy_ebops - m.input_quantizer.hgq_loss = dummy_hgq_loss + if hasattr(m, "input_quantizer"): + m.input_quantizer.hgq_loss = dummy_hgq_loss m.weight_quantizer.hgq_loss = dummy_hgq_loss - m.bias_quantizer.hgq_loss = dummy_hgq_loss # Won't be called - m.output_quantizer.hgq_loss = dummy_hgq_loss # Won't be called + if hasattr(m, "bias_quantizer"): + m.bias_quantizer.hgq_loss = dummy_hgq_loss # Won't be called + if hasattr(m, "output_quantizer"): + m.output_quantizer.hgq_loss = dummy_hgq_loss # Won't be called expected_loss += 2.0 elif isinstance(m, (PQAvgPool1d, PQAvgPool2d, PQActivation)): m.ebops = dummy_ebops - m.input_quantizer.hgq_loss = dummy_hgq_loss - m.output_quantizer.hgq_loss = dummy_hgq_loss # Won't be called + if hasattr(m, "input_quantizer"): + m.input_quantizer.hgq_loss = dummy_hgq_loss + if hasattr(m, "output_quantizer"): + m.output_quantizer.hgq_loss = dummy_hgq_loss # Won't be called expected_loss += 1.0 elif isinstance(m, (PQBatchNorm2d)): m.ebops = dummy_ebops @@ -1883,19 +1889,24 @@ def test_hgq_loss_calc_no_qinput(config_pdp, conv2d_input): for m in model.modules(): if isinstance(m, (PQWeightBiasBase)): m.ebops = dummy_ebops - m.input_quantizer.hgq_loss = dummy_hgq_loss # Won't be called + if hasattr(m, "input_quantizer"): + m.input_quantizer.hgq_loss = dummy_hgq_loss # Won't be called m.weight_quantizer.hgq_loss = dummy_hgq_loss m.bias_quantizer.hgq_loss = dummy_hgq_loss - m.output_quantizer.hgq_loss = dummy_hgq_loss + if hasattr(m, "output_quantizer"): + m.output_quantizer.hgq_loss = dummy_hgq_loss expected_loss += 3.0 elif isinstance(m, (PQAvgPool1d, PQAvgPool2d, PQActivation)): m.ebops = dummy_ebops - m.input_quantizer.hgq_loss = dummy_hgq_loss # Won't be called - m.output_quantizer.hgq_loss = dummy_hgq_loss + if hasattr(m, "input_quantizer"): + m.input_quantizer.hgq_loss = dummy_hgq_loss # Won't be called + if hasattr(m, "output_quantizer"): + m.output_quantizer.hgq_loss = dummy_hgq_loss expected_loss += 1.0 elif isinstance(m, (PQBatchNorm2d)): m.ebops = dummy_ebops - m.input_quantizer.hgq_loss = dummy_hgq_loss # Won't be called + if hasattr(m, "input_quantizer"): + m.input_quantizer.hgq_loss = dummy_hgq_loss # Won't be called m.weight_quantizer.hgq_loss = dummy_hgq_loss m.bias_quantizer.hgq_loss = dummy_hgq_loss expected_loss += 2.0