@@ -72,7 +72,25 @@ def get_symmetric_quantization_config(
     act_qmax: int = 127,
     weight_qmin: int = -127,
     weight_qmax: int = 127,
-):
+) -> QuantizationConfig:
+    """Create symmetric quantization config for activations and weights.
+
+    Args:
+        is_per_channel (bool): Whether to use per-channel quantization for
+            weights.
+        is_qat (bool): Whether the configuration targets quantization aware
+            training.
+        is_dynamic (bool): Whether to generate dynamic activation observers.
+        act_qmin (int): Minimum activation quantization value.
+        act_qmax (int): Maximum activation quantization value.
+        weight_qmin (int): Minimum weight quantization value.
+        weight_qmax (int): Maximum weight quantization value.
+
+    Returns:
+        QuantizationConfig: Quantization settings for activations, weights, and
+            bias.
+
+    """
     extra_args: Dict[str, Any] = {"eps": 2**-12}
     if is_qat:
         if is_dynamic:
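A minimal usage sketch of the factory documented above; the import path is an assumption about where this module lives in the ExecuTorch tree, not something stated in this diff.

```python
# Illustrative sketch only; the module path below is an assumption.
from executorch.backends.arm.quantizer.arm_quantizer import (
    get_symmetric_quantization_config,
)

# Default: static, symmetric int8 activations and weights.
static_cfg = get_symmetric_quantization_config()

# Per-channel weights with QAT-style fake-quant observers.
qat_cfg = get_symmetric_quantization_config(is_per_channel=True, is_qat=True)
```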
@@ -169,23 +187,26 @@ def get_symmetric_a16w8_quantization_config(
     weight_qmin: int = -127,
     weight_qmax: int = 127,
     epsilon: float = 2**-12,
-):
-    """
-    16A8W quantization config: 16-bit activations, 8-bit weights.
+) -> QuantizationConfig:
+    """16A8W quantization config: 16-bit activations, 8-bit weights.
 
     This configuration provides better accuracy than 8A8W while maintaining
     reasonable memory usage through 8-bit weights.
 
     Args:
-        is_per_channel: Whether to use per-channel quantization for weights
-        is_qat: Whether this is for Quantization Aware Training
-        is_dynamic: Whether to use dynamic quantization
-        weight_qmin: Minimum quantization value for weights
-        weight_qmax: Maximum quantization value for weights
-        epsilon: Value used to pad observed [qmin, qmax] before initial zero point and scale calculation
+        is_per_channel (bool): Whether to use per-channel quantization for
+            weights.
+        is_qat (bool): Whether this is for quantization aware training.
+        is_dynamic (bool): Whether to use dynamic quantization.
+        weight_qmin (int): Minimum quantization value for weights.
+        weight_qmax (int): Maximum quantization value for weights.
+        epsilon (float): Value used to pad observed [qmin, qmax] before initial
+            zero-point and scale calculation.
 
     Returns:
-        QuantizationConfig with 16-bit activations and 8-bit weights
+        QuantizationConfig: Configuration with 16-bit activations and 8-bit
+            weights.
+
     """
     extra_args: Dict[str, Any] = {"eps": epsilon}
 
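A sketch of the 16A8W variant, under the same assumed import path as the earlier sketch.

```python
from executorch.backends.arm.quantizer.arm_quantizer import (
    get_symmetric_a16w8_quantization_config,
)

# 16-bit activations with 8-bit per-channel weights; epsilon pads the observed
# [qmin, qmax] range before the first scale/zero-point computation.
a16w8_cfg = get_symmetric_a16w8_quantization_config(
    is_per_channel=True,
    epsilon=2**-12,
)
```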
@@ -244,27 +265,39 @@ def get_symmetric_a16w8_quantization_config(
 
 
 NodeFilterType = Callable[[Node], bool]
-"""Type for a Node Filter used by annotators. A Node filter is a function that takes
-a Node and returns whether the node should be annotated or not.
+"""Type for a Node Filter used by annotators.
+
+A Node filter is a function that takes a Node and returns whether the node
+should be annotated or not.
+
 """
 
 
 def _get_module_type_filter(tp: Callable) -> NodeFilterType:
-    """Get the module_type_filter function for a given module type, the filter accepts
-    a node and checks if the node comes from a module that has certain module type
+    """Get the module_type_filter function for a given module type.
 
-    For example:
-        node: linear_op = call_function[...](...)  # comes from a module with type Block -> Sub -> Linear
+    The filter accepts a node and checks if the node comes from a module that
+    has a certain module type.
+
+    Args:
+        tp (Callable): Module class to match against the graph node metadata.
+
+    Returns:
+        NodeFilterType: Predicate that returns True for nodes from the module
+            type.
 
+    For example:
+        node: linear_op = call_function[...](...)  # type Block -> Sub -> Linear
 
-    >> module_type_filter = _get_module_type_filter(Sub)  # submodule with type `Sub`, under the `Block` submodule
+    >> module_type_filter = _get_module_type_filter(Sub)
     >> print(module_type_filter(node))
-    True  # the node is from the submodule `Sub` (same for `Block` and `Linear` as well)
-    """
+    True  # the node is from the submodule `Sub` (same for `Block` and `Linear`)
 
+    """
     tp_str = tp.__module__ + "." + tp.__qualname__
 
     def module_type_filter(n: Node) -> bool:
+        """Return True if the node originates from the target module type."""
         # node_stack example: {
         #     'L__self___sub': ("L['self'].sub", <class '....Sub'>),
         #     'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
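For readers new to NodeFilterType: a hypothetical predicate in the same spirit as the helper above. It assumes only the `nn_module_stack` metadata layout shown in the comment, where each value is a `(qualified_name, module_type)` pair; the function name is illustrative, not part of this file.

```python
from torch.fx import Node


def only_from_linear(node: Node) -> bool:
    # nn_module_stack values look like
    # ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>).
    stack = node.meta.get("nn_module_stack", {})
    return any(
        getattr(module_type, "__qualname__", str(module_type)).endswith("Linear")
        for _, module_type in stack.values()
    )
```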
@@ -279,16 +312,29 @@ def module_type_filter(n: Node) -> bool:
 def _get_not_module_type_or_name_filter(
     tp_list: List[Callable], module_name_list: List[str]
 ) -> NodeFilterType:
+    """Create a filter that excludes provided module types and names.
+
+    Args:
+        tp_list (List[Callable]): Module types to exclude from annotation.
+        module_name_list (List[str]): Module names to exclude from annotation.
+
+    Returns:
+        NodeFilterType: Filter that returns True when the node does not match
+            any provided module type or name.
+
+    """
     module_type_filters = [_get_module_type_filter(tp) for tp in tp_list]
     module_name_list_filters = [get_module_name_filter(m) for m in module_name_list]
 
     def not_module_type_or_name_filter(n: Node) -> bool:
+        """Return True when the node matches none of the blocked filters."""
         return not any(f(n) for f in module_type_filters + module_name_list_filters)
 
     return not_module_type_or_name_filter
 
 
 class TOSAQuantizer(Quantizer):
+    """Manage quantization annotations for TOSA-compatible backends."""
 
     def __init__(
         self, compile_spec_or_tosa_spec: TosaSpecification | ArmCompileSpec
@@ -314,41 +360,47 @@ def __init__(
         self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {}
 
     def set_global(self, quantization_config: QuantizationConfig) -> TOSAQuantizer:
-        """
-        Set quantization_config for submodules that are not already annotated by name or type filters.
+        """Set quantization_config for submodules not matched by other filters.
 
         Args:
-            quantization_config: The QuantizationConfig to set as global configuration.
+            quantization_config (QuantizationConfig): Configuration to apply to
+                modules that are not captured by name or type filters.
+
         """
         self.global_config = quantization_config
         return self
 
     def set_module_type(
         self, module_type: Callable, quantization_config: QuantizationConfig
     ) -> TOSAQuantizer:
-        """
-        Set quantization_config for a submodule with type: `module_type`, for example:
-        quantizer.set_module_name(Sub) or quantizer.set_module_name(nn.Linear), it will quantize all supported operator/operator
-        patterns in the submodule with this module type with the given `quantization_config`.
+        """Set quantization_config for submodules with a given module type.
+
+        For example, calling set_module_type(Sub) quantizes supported patterns
+        in each Sub instance with the provided quantization_config.
 
         Args:
-            module_type: The type of the submodule to set the quantization config for.
-            quantization_config: The QuantizationConfig to set for the submodule.
+            module_type (Callable): Type whose submodules should use the
+                provided quantization configuration.
+            quantization_config (QuantizationConfig): Configuration to apply to
+                submodules of the given type.
+
         """
         self.module_type_config[module_type] = quantization_config
         return self
 
     def set_module_name(
         self, module_name: str, quantization_config: Optional[QuantizationConfig]
     ) -> TOSAQuantizer:
-        """
-        Set quantization_config for a submodule with name: `module_name`, for example:
-        quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator
-        patterns in the submodule with this module name with the given `quantization_config`
+        """Set quantization_config for submodules with a given module name.
+
+        For example, calling set_module_name("blocks.sub") quantizes supported
+        patterns for that submodule with the provided quantization_config.
 
         Args:
-            module_name: The name of the submodule to set the quantization config for.
-            quantization_config: The QuantizationConfig to set for the submodule.
+            module_name (str): Fully qualified module name to configure.
+            quantization_config (QuantizationConfig): Configuration applied to
+                the named submodule.
+
        """
         # Validate that quantization_config is provided
         if quantization_config is None:
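A configuration sketch tying the three setters together. The quantizer construction and import paths are assumptions; `tosa_spec` stands in for a TosaSpecification or ArmCompileSpec built elsewhere. Per set_global's docstring, the name and type configs take precedence and the global config covers everything they do not match.

```python
import torch.nn as nn

from executorch.backends.arm.quantizer.arm_quantizer import (
    TOSAQuantizer,
    get_symmetric_quantization_config,
)

quantizer = TOSAQuantizer(tosa_spec)  # tosa_spec: TosaSpecification or ArmCompileSpec
quantizer.set_global(get_symmetric_quantization_config())
quantizer.set_module_type(
    nn.Linear, get_symmetric_quantization_config(is_per_channel=True)
)
quantizer.set_module_name("blocks.sub", get_symmetric_quantization_config())
```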
@@ -357,26 +409,28 @@ def set_module_name(
         return self
 
     def set_io(self, quantization_config: QuantizationConfig) -> TOSAQuantizer:
-        """
-        Set quantization_config for input and output nodes.
+        """Set quantization_config for input and output nodes.
 
         Args:
-            quantization_config: The QuantizationConfig to set for input and output nodes.
+            quantization_config (QuantizationConfig): Configuration describing
+                activation quantization for model inputs and outputs.
+
         """
         self.io_config = quantization_config
         return self
 
     def transform_for_annotation(self, model: GraphModule) -> GraphModule:
-        """
-        An initial pass for transforming the graph to prepare it for annotation.
+        """Transform the graph to prepare it for quantization annotation.
+
         Currently transforms scalar values to tensor attributes.
 
         Args:
-            model: The model to transform.
+            model (GraphModule): Model whose graph will be transformed.
+
         Returns:
-            The transformed model.
-        """
+            GraphModule: Transformed model prepared for annotation.
 
+        """
         # TODO: Fix the need to lazily import this.
         from executorch.backends.arm._passes import ArmPassManager
 
@@ -385,12 +439,16 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         )
 
     def annotate(self, model: GraphModule) -> GraphModule:
-        """Performs the quantization annotation on the graph.
-        Currently only does static quantization annotation.
+        """Annotate the graph with the configured quantization settings.
+
+        Currently only does static quantization annotation.
+
         Args:
-            model: The model to annotate statically.
+            model (GraphModule): Model to annotate statically.
+
         Returns:
-            The annotated model.
+            GraphModule: Annotated model ready for export.
+
         """
         model = self._annotate_for_static_quantization_config(model)
         return model
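transform_for_annotation and annotate are normally driven by the PT2E workflow rather than called directly. A sketch of that flow, reusing `model`, `example_inputs`, and `quantizer` from the earlier sketches; the exact torch.export entry point varies by torch version, so treat it as an assumption.

```python
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

exported = torch.export.export(model, example_inputs).module()
prepared = prepare_pt2e(exported, quantizer)  # invokes transform_for_annotation/annotate
prepared(*example_inputs)                     # calibration pass
quantized = convert_pt2e(prepared)
```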
@@ -401,14 +459,19 @@ def _annotate_all_static_patterns(
         quantization_config: Optional[QuantizationConfig],
         filter_fn: Optional[Callable[[Node], bool]] = None,
     ) -> GraphModule:
-        """Loops over all STATIC_OPS and runs the corresponding registered annotator.
+        """Annotate all static patterns registered for the backend.
+
         Args:
-            model: The model to annotate statically.
-            quantization_config: Specifies the QuantizationSpecs for the model's
-                input activations, output activations, weights and biases.
-            filter_fn: An optional filter function that takes a node and returns whether the node should be annotated.
+            model (GraphModule): Model to annotate statically.
+            quantization_config (Optional[QuantizationConfig]): Quantization
+                specs for input activations, output activations, weights, and
+                biases.
+            filter_fn (Optional[Callable[[Node], bool]]): Optional node filter
+                specifying which nodes to annotate.
+
         Returns:
-            The annotated model.
+            GraphModule: Model populated with quantization annotations.
+
         """
         # TODO: implement the support for None to be canceling out previous annotations
         if quantization_config is None:
@@ -420,8 +483,15 @@ def _annotate_all_static_patterns(
     def _annotate_for_static_quantization_config(
         self, model: GraphModule
     ) -> GraphModule:
-        """Matches the correct QuantizationConfig with the correct module using a filter
-        when running _annotate_all_static_patterns.
+        """Match QuantizationConfigs to modules before annotating patterns.
+
+        Args:
+            model (GraphModule): Model whose modules are being matched to
+                quantization configs.
+
+        Returns:
+            GraphModule: Annotated model after applying configured filters.
+
         """
         if self.io_config:
             self._annotate_io(model, self.io_config)
@@ -451,6 +521,14 @@ def _annotate_io(
         model: GraphModule,
         quantization_config: QuantizationConfig,
     ):
+        """Annotate graph inputs and outputs with the provided configuration.
+
+        Args:
+            model (GraphModule): GraphModule being annotated.
+            quantization_config (QuantizationConfig): Activation qspecs to apply
+                to IO nodes.
+
+        """
         for node in model.graph.nodes:
             if is_annotated(node):
                 continue
@@ -468,6 +546,7 @@ def _annotate_io(
                 mark_node_as_annotated(node)
 
     def validate(self, model: GraphModule) -> None:
+        """TODO: Implement validation of annotated graph for TOSA backend."""
         pass
 
     def quantize_with_submodules(
@@ -479,10 +558,16 @@ def quantize_with_submodules(
         """Quantizes a GraphModule in a way such that conditional submodules are handled properly.
 
         Args:
-            model: GraphModule, the model to quantize.
-            calibration_samples: list[tuple], a list of inputs to used to calibrate the model during quantization.
-                To properly calibrate a model with submodules, at least one sample per code path is needed.
-            is_qat: bool, whether to do quantization aware training or not.
+            model (GraphModule): The model to quantize.
+            calibration_samples (list[tuple]): A list of inputs used to
+                calibrate the model during quantization. To properly calibrate
+                a model with submodules, at least one sample per code path is
+                needed.
+            is_qat (bool): Whether to do quantization aware training or not.
+
+        Returns:
+            GraphModule: The quantized model.
+
         """
         prepare_fn = prepare_qat_pt2e if is_qat else prepare_pt2e
 
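A hedged usage sketch for quantize_with_submodules. The argument order follows the docstring above, `exported_module` and `quantizer` are assumed to exist from the earlier flow, and the sample shapes are placeholders; the key point is one calibration sample per conditional code path.

```python
import torch

# One sample per conditional branch so every submodule gets observed statistics.
calibration_samples = [
    (torch.randn(1, 16),),
    (-torch.randn(1, 16),),
]
quantized = quantizer.quantize_with_submodules(
    exported_module, calibration_samples, is_qat=False
)
```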
@@ -499,23 +584,25 @@
 
 
 class EthosUQuantizer(TOSAQuantizer):
-    """
-    Quantizer supported by the Arm Ethos-U backend.
+    """Quantizer supported by the Arm Ethos-U backend.
 
     Args:
-        compile_spec: A EthosUCompileSpec instance.
+        compile_spec (EthosUCompileSpec): Backend compile specification for
+            Ethos-U targets.
+
     """
 
     def __init__(self, compile_spec: EthosUCompileSpec) -> None:
         super().__init__(compile_spec)
 
 
 class VgfQuantizer(TOSAQuantizer):
-    """
-    Quantizer supported by the Arm Vgf backend.
+    """Quantizer supported by the Arm Vgf backend.
 
     Args:
-        compile_spec: A VgfCompileSpec instance.
+        compile_spec (VgfCompileSpec): Backend compile specification for Vgf
+            targets.
+
     """
 
     def __init__(self, compile_spec: VgfCompileSpec) -> None:
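Finally, a sketch showing that the backend-specific quantizers differ only in the compile spec they accept. The import path is the same assumption as before, and `ethosu_compile_spec` / `vgf_compile_spec` are assumed to be built with the backends' own compile-spec helpers, which are outside this diff.

```python
from executorch.backends.arm.quantizer.arm_quantizer import (
    EthosUQuantizer,
    VgfQuantizer,
    get_symmetric_quantization_config,
)

# Compile specs are assumed to come from the Ethos-U / VGF backend helpers.
ethosu_quantizer = EthosUQuantizer(ethosu_compile_spec)
vgf_quantizer = VgfQuantizer(vgf_compile_spec)
ethosu_quantizer.set_global(get_symmetric_quantization_config())
```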