
Commit e2c8c60

Arm backend: Align docstrings in arm_quantizer.py with backend standard (#15975)
cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai

Signed-off-by: Sebastian Larsson <sebastian.larsson@arm.com>
Co-authored-by: Zingo Andersen <zingo.andersen@arm.com>
1 parent 7e869be commit e2c8c60

File tree: 3 files changed (+238, -99 lines)

backends/arm/quantizer/arm_quantizer.py

Lines changed: 152 additions & 65 deletions
@@ -72,7 +72,25 @@ def get_symmetric_quantization_config(
     act_qmax: int = 127,
     weight_qmin: int = -127,
     weight_qmax: int = 127,
-):
+) -> QuantizationConfig:
+    """Create symmetric quantization config for activations and weights.
+
+    Args:
+        is_per_channel (bool): Whether to use per-channel quantization for
+            weights.
+        is_qat (bool): Whether the configuration targets quantization aware
+            training.
+        is_dynamic (bool): Whether to generate dynamic activation observers.
+        act_qmin (int): Minimum activation quantization value.
+        act_qmax (int): Maximum activation quantization value.
+        weight_qmin (int): Minimum weight quantization value.
+        weight_qmax (int): Maximum weight quantization value.
+
+    Returns:
+        QuantizationConfig: Quantization settings for activations, weights, and
+            bias.
+
+    """
     extra_args: Dict[str, Any] = {"eps": 2**-12}
     if is_qat:
         if is_dynamic:
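Not part of this commit: a minimal sketch of calling the function documented above. The import path simply mirrors this file's location and is an assumption.

# Sketch: build a per-channel symmetric PTQ config as documented above.
from executorch.backends.arm.quantizer.arm_quantizer import (
    get_symmetric_quantization_config,
)

config = get_symmetric_quantization_config(is_per_channel=True)
# `config` bundles the activation, weight, and bias quantization specs.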
@@ -169,23 +187,26 @@ def get_symmetric_a16w8_quantization_config(
     weight_qmin: int = -127,
     weight_qmax: int = 127,
     epsilon: float = 2**-12,
-):
-    """
-    16A8W quantization config: 16-bit activations, 8-bit weights.
+) -> QuantizationConfig:
+    """16A8W quantization config: 16-bit activations, 8-bit weights.
 
     This configuration provides better accuracy than 8A8W while maintaining
     reasonable memory usage through 8-bit weights.
 
     Args:
-        is_per_channel: Whether to use per-channel quantization for weights
-        is_qat: Whether this is for Quantization Aware Training
-        is_dynamic: Whether to use dynamic quantization
-        weight_qmin: Minimum quantization value for weights
-        weight_qmax: Maximum quantization value for weights
-        epsilon: Value used to pad observed [qmin, qmax] before initial zero point and scale calculation
+        is_per_channel (bool): Whether to use per-channel quantization for
+            weights.
+        is_qat (bool): Whether this is for quantization aware training.
+        is_dynamic (bool): Whether to use dynamic quantization.
+        weight_qmin (int): Minimum quantization value for weights.
+        weight_qmax (int): Maximum quantization value for weights.
+        epsilon (float): Value used to pad observed [qmin, qmax] before initial
+            zero-point and scale calculation.
 
     Returns:
-        QuantizationConfig with 16-bit activations and 8-bit weights
+        QuantizationConfig: Configuration with 16-bit activations and 8-bit
+            weights.
+
     """
     extra_args: Dict[str, Any] = {"eps": epsilon}
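Not part of this commit: a sketch of requesting the 16A8W configuration described above; the import path is again assumed to mirror the file location.

# Sketch: 16-bit activations with 8-bit weights for accuracy-sensitive models.
from executorch.backends.arm.quantizer.arm_quantizer import (
    get_symmetric_a16w8_quantization_config,
)

a16w8_config = get_symmetric_a16w8_quantization_config(is_per_channel=False)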

@@ -244,27 +265,39 @@ def get_symmetric_a16w8_quantization_config(
 
 
 NodeFilterType = Callable[[Node], bool]
-"""Type for a Node Filter used by annotators. A Node filter is a function that takes
-a Node and returns whether the node should be annotated or not.
+"""Type for a Node Filter used by annotators.
+
+A Node filter is a function that takes a Node and returns whether the node
+should be annotated or not.
+
 """
 
 
 def _get_module_type_filter(tp: Callable) -> NodeFilterType:
-    """Get the module_type_filter function for a given module type, the filter accepts
-    a node and checks if the node comes from a module that has certain module type
+    """Get the module_type_filter function for a given module type.
 
-    For example:
-        node: linear_op = call_function[...](...) # comes from a module with type Block -> Sub -> Linear
+    The filter accepts a node and checks if the node comes from a module that
+    has a certain module type.
+
+    Args:
+        tp (Callable): Module class to match against the graph node metadata.
+
+    Returns:
+        NodeFilterType: Predicate that returns True for nodes from the module
+            type.
 
+    For example:
+        node: linear_op = call_function[...](...) # type Block -> Sub -> Linear
 
-    >> module_type_filter = _get_module_type_filter(Sub) # submodule with type `Sub`, under the `Block` submodule
+    >> module_type_filter = _get_module_type_filter(Sub)
     >> print(module_type_filter(node))
-    True # the node is from the submodule `Sub` (same for `Block` and `Linear` as well)
-    """
+    True # the node is from the submodule `Sub` (same for `Block` and `Linear`)
 
+    """
     tp_str = tp.__module__ + "." + tp.__qualname__
 
     def module_type_filter(n: Node) -> bool:
+        """Return True if the node originates from the target module type."""
         # node_stack example: {
         #     'L__self___sub': ("L['self'].sub", <class '....Sub'>),
         #     'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
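Not part of this commit: since NodeFilterType is just Callable[[Node], bool], a custom filter can be any plain predicate over torch.fx nodes. The example below, a sketch only, accepts just aten.linear call_function nodes.

import torch
from torch.fx import Node

def linear_only_filter(n: Node) -> bool:
    # Annotate only call_function nodes that target aten.linear.
    return n.op == "call_function" and n.target == torch.ops.aten.linear.default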
@@ -279,16 +312,29 @@ def module_type_filter(n: Node) -> bool:
 def _get_not_module_type_or_name_filter(
     tp_list: List[Callable], module_name_list: List[str]
 ) -> NodeFilterType:
+    """Create a filter that excludes provided module types and names.
+
+    Args:
+        tp_list (List[Callable]): Module types to exclude from annotation.
+        module_name_list (List[str]): Module names to exclude from annotation.
+
+    Returns:
+        NodeFilterType: Filter that returns True when the node does not match
+            any provided module type or name.
+
+    """
     module_type_filters = [_get_module_type_filter(tp) for tp in tp_list]
     module_name_list_filters = [get_module_name_filter(m) for m in module_name_list]
 
     def not_module_type_or_name_filter(n: Node) -> bool:
+        """Return True when the node matches none of the blocked filters."""
         return not any(f(n) for f in module_type_filters + module_name_list_filters)
 
     return not_module_type_or_name_filter
 
 
 class TOSAQuantizer(Quantizer):
+    """Manage quantization annotations for TOSA-compatible backends."""
 
     def __init__(
         self, compile_spec_or_tosa_spec: TosaSpecification | ArmCompileSpec
@@ -314,41 +360,47 @@ def __init__(
         self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {}
 
     def set_global(self, quantization_config: QuantizationConfig) -> TOSAQuantizer:
-        """
-        Set quantization_config for submodules that are not already annotated by name or type filters.
+        """Set quantization_config for submodules not matched by other filters.
 
         Args:
-            quantization_config: The QuantizationConfig to set as global configuration.
+            quantization_config (QuantizationConfig): Configuration to apply to
+                modules that are not captured by name or type filters.
+
         """
         self.global_config = quantization_config
         return self
 
     def set_module_type(
         self, module_type: Callable, quantization_config: QuantizationConfig
     ) -> TOSAQuantizer:
-        """
-        Set quantization_config for a submodule with type: `module_type`, for example:
-        quantizer.set_module_name(Sub) or quantizer.set_module_name(nn.Linear), it will quantize all supported operator/operator
-        patterns in the submodule with this module type with the given `quantization_config`.
+        """Set quantization_config for submodules with a given module type.
+
+        For example, calling set_module_type(Sub) quantizes supported patterns
+        in each Sub instance with the provided quantization_config.
 
         Args:
-            module_type: The type of the submodule to set the quantization config for.
-            quantization_config: The QuantizationConfig to set for the submodule.
+            module_type (Callable): Type whose submodules should use the
+                provided quantization configuration.
+            quantization_config (QuantizationConfig): Configuration to apply to
+                submodules of the given type.
+
         """
         self.module_type_config[module_type] = quantization_config
         return self
 
     def set_module_name(
         self, module_name: str, quantization_config: Optional[QuantizationConfig]
     ) -> TOSAQuantizer:
-        """
-        Set quantization_config for a submodule with name: `module_name`, for example:
-        quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator
-        patterns in the submodule with this module name with the given `quantization_config`
+        """Set quantization_config for submodules with a given module name.
+
+        For example, calling set_module_name("blocks.sub") quantizes supported
+        patterns for that submodule with the provided quantization_config.
 
         Args:
-            module_name: The name of the submodule to set the quantization config for.
-            quantization_config: The QuantizationConfig to set for the submodule.
+            module_name (str): Fully qualified module name to configure.
+            quantization_config (QuantizationConfig): Configuration applied to
+                the named submodule.
+
         """
         # Validate that quantization_config is provided
         if quantization_config is None:
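Not part of this commit: the setters documented above each return the quantizer, so they chain naturally. The TosaSpecification string and import paths below are assumptions based on other Arm backend examples, not confirmed by this diff.

import torch.nn as nn
from executorch.backends.arm.quantizer.arm_quantizer import (
    TOSAQuantizer,
    get_symmetric_quantization_config,
)
from executorch.backends.arm.tosa_specification import TosaSpecification

# Assumed TOSA spec string; pick the one matching your target.
quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))
quantizer.set_module_type(nn.Linear, get_symmetric_quantization_config())
quantizer.set_module_name("blocks.sub", get_symmetric_quantization_config())
quantizer.set_io(get_symmetric_quantization_config())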
@@ -357,26 +409,28 @@ def set_module_name(
         return self
 
     def set_io(self, quantization_config: QuantizationConfig) -> TOSAQuantizer:
-        """
-        Set quantization_config for input and output nodes.
+        """Set quantization_config for input and output nodes.
 
         Args:
-            quantization_config: The QuantizationConfig to set for input and output nodes.
+            quantization_config (QuantizationConfig): Configuration describing
+                activation quantization for model inputs and outputs.
+
         """
         self.io_config = quantization_config
         return self
 
     def transform_for_annotation(self, model: GraphModule) -> GraphModule:
-        """
-        An initial pass for transforming the graph to prepare it for annotation.
+        """Transform the graph to prepare it for quantization annotation.
+
         Currently transforms scalar values to tensor attributes.
 
         Args:
-            model: The model to transform.
+            model (GraphModule): Model whose graph will be transformed.
+
         Returns:
-            The transformed model.
-        """
+            GraphModule: Transformed model prepared for annotation.
+
+        """
         # TODO: Fix the need to lazily import this.
         from executorch.backends.arm._passes import ArmPassManager
 

@@ -385,12 +439,16 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         )
 
     def annotate(self, model: GraphModule) -> GraphModule:
-        """Performs the quantization annotation on the graph.
-        Currently only does static quantization annotation.
+        """Annotate the graph with the configured quantization settings.
+
+        Currently only does static quantization annotation.
+
         Args:
-            model: The model to annotate statically.
+            model (GraphModule): Model to annotate statically.
+
         Returns:
-            The annotated model.
+            GraphModule: Annotated model ready for export.
+
         """
         model = self._annotate_for_static_quantization_config(model)
         return model
@@ -401,14 +459,19 @@ def _annotate_all_static_patterns(
         quantization_config: Optional[QuantizationConfig],
         filter_fn: Optional[Callable[[Node], bool]] = None,
     ) -> GraphModule:
-        """Loops over all STATIC_OPS and runs the corresponding registered annotator.
+        """Annotate all static patterns registered for the backend.
+
         Args:
-            model: The model to annotate statically.
-            quantization_config: Specifies the QuantizationSpecs for the model's
-                input activations, output activations, weights and biases.
-            filter_fn: An optional filter function that takes a node and returns whether the node should be annotated.
+            model (GraphModule): Model to annotate statically.
+            quantization_config (Optional[QuantizationConfig]): Quantization
+                specs for input activations, output activations, weights, and
+                biases.
+            filter_fn (Optional[Callable[[Node], bool]]): Optional node filter
+                specifying which nodes to annotate.
+
         Returns:
-            The annotated model.
+            GraphModule: Model populated with quantization annotations.
+
         """
         # TODO: implement the support for None to be canceling out previous annotations
         if quantization_config is None:
@@ -420,8 +483,15 @@ def _annotate_all_static_patterns(
     def _annotate_for_static_quantization_config(
         self, model: GraphModule
     ) -> GraphModule:
-        """Matches the correct QuantizationConfig with the correct module using a filter
-        when running _annotate_all_static_patterns.
+        """Match QuantizationConfigs to modules before annotating patterns.
+
+        Args:
+            model (GraphModule): Model whose modules are being matched to
+                quantization configs.
+
+        Returns:
+            GraphModule: Annotated model after applying configured filters.
+
         """
         if self.io_config:
             self._annotate_io(model, self.io_config)
@@ -451,6 +521,14 @@ def _annotate_io(
         model: GraphModule,
         quantization_config: QuantizationConfig,
     ):
+        """Annotate graph inputs and outputs with the provided configuration.
+
+        Args:
+            model (GraphModule): GraphModule being annotated.
+            quantization_config (QuantizationConfig): Activation qspecs to apply
+                to IO nodes.
+
+        """
         for node in model.graph.nodes:
             if is_annotated(node):
                 continue
@@ -468,6 +546,7 @@ def _annotate_io(
             mark_node_as_annotated(node)
 
     def validate(self, model: GraphModule) -> None:
+        """TODO: Implement validation of annotated graph for TOSA backend."""
        pass
 
def quantize_with_submodules(
@@ -479,10 +558,16 @@ def quantize_with_submodules(
         """Quantizes a GraphModule in a way such that conditional submodules are handled properly.
 
         Args:
-            model: GraphModule, the model to quantize.
-            calibration_samples: list[tuple], a list of inputs to used to calibrate the model during quantization.
-                To properly calibrate a model with submodules, at least one sample per code path is needed.
-            is_qat: bool, whether to do quantization aware training or not.
+            model (GraphModule): The model to quantize.
+            calibration_samples (list[tuple]): A list of inputs to used to
+                calibrate the model during quantization. To properly calibrate a
+                model with submodules, at least one sample per code path is
+                needed.
+            is_qat (bool): Whether to do quantization aware training or not.
+
+        Returns:
+            GraphModule: The quantized model.
+
         """
         prepare_fn = prepare_qat_pt2e if is_qat else prepare_pt2e
 
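Not part of this commit: a sketch of how quantize_with_submodules might sit in a PT2E-style flow. The model, inputs, and torch.export usage are placeholders; only the argument names come from the docstring above.

import torch

# Placeholder module with conditional submodules, and one sample per code path.
model = MyModel().eval()
example_inputs = (torch.randn(1, 3, 224, 224),)

# `quantizer` is a configured TOSAQuantizer (see the earlier sketch).
graph_module = torch.export.export(model, example_inputs).module()
quantized = quantizer.quantize_with_submodules(
    graph_module,
    calibration_samples=[example_inputs],
    is_qat=False,
)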

@@ -499,23 +584,25 @@ def quantize_with_submodules(
 
 
 class EthosUQuantizer(TOSAQuantizer):
-    """
-    Quantizer supported by the Arm Ethos-U backend.
+    """Quantizer supported by the Arm Ethos-U backend.
 
     Args:
-        compile_spec: A EthosUCompileSpec instance.
+        compile_spec (EthosUCompileSpec): Backend compile specification for
+            Ethos-U targets.
+
     """
 
     def __init__(self, compile_spec: EthosUCompileSpec) -> None:
         super().__init__(compile_spec)
 
 
 class VgfQuantizer(TOSAQuantizer):
-    """
-    Quantizer supported by the Arm Vgf backend.
+    """Quantizer supported by the Arm Vgf backend.
 
     Args:
-        compile_spec: A VgfCompileSpec instance.
+        compile_spec (VgfCompileSpec): Backend compile specification for Vgf
+            targets.
+
     """
 
     def __init__(self, compile_spec: VgfCompileSpec) -> None:
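Not part of this commit: both subclasses are constructed directly from their backend compile spec. The EthosUCompileSpec import path and constructor arguments below are illustrative placeholders, not a confirmed API; consult the Arm backend docs for the real signature.

# Hypothetical compile-spec construction for illustration only.
from executorch.backends.arm.ethosu import EthosUCompileSpec
from executorch.backends.arm.quantizer.arm_quantizer import EthosUQuantizer

compile_spec = EthosUCompileSpec(target="ethos-u55-128")  # placeholder arguments
ethos_quantizer = EthosUQuantizer(compile_spec)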
