@@ -3,7 +3,7 @@
 
 import re
 import typing
-from collections.abc import Sequence
+from collections.abc import Generator, Sequence
 from copy import copy
 from functools import reduce, singledispatch
 from math import ceil, log2, prod
@@ -824,6 +824,16 @@ def transform(self, model: 'ModelGraph'):
         return True
 
 
+def get_output_quantizers(node: Layer) -> Generator[FixedPointQuantizer, None, None]:
+    for _node in get_output_layers(node):
+        if isinstance(_node, FixedPointQuantizer):
+            yield _node
+        elif isinstance(_node, (Reshape, Transpose)):
+            yield from get_output_quantizers(_node)
+        else:
+            raise ValueError(f'Layer {node.name} ({node.class_name}) has an unexpected output layer chain.')
+
+
 class FixInputPrecision(OptimizerPass):
     def match(self, node: Layer):
         if not isinstance(node, Input):
@@ -833,11 +843,7 @@ def match(self, node: Layer):
         return node.get_output_variable().type.precision.width > 100
 
     def transform(self, model, node: Layer):
-        out_layers: list[FixedPointQuantizer] = get_output_layers(node)  # type: ignore
-        for layer in out_layers:
-            assert isinstance(
-                layer, FixedPointQuantizer
-            ), f'Input {node.name} connected to non-quantizer {layer.name} with non-trivial configuration'
+        out_layers = list(get_output_quantizers(node))
 
         if len(out_layers) == 0:  # Input connected to nothing
             new_type = to_hls4ml_fixed(0, 0, 1, f'{node.name}_t')
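
For reference, here is a minimal standalone sketch of what the new helper does: it walks a node's downstream layers, yields every `FixedPointQuantizer` it reaches, and recurses through shape-only layers (`Reshape`/`Transpose`) so quantizers sitting behind them are still found. The `Layer`, `Input`, `Reshape`, `Transpose`, `FixedPointQuantizer` classes and `get_output_layers` below are simplified stand-ins for the hls4ml originals, not the real API.

```python
# Self-contained sketch; the classes here are stand-ins, not hls4ml's real ones.
from typing import Generator


class Layer:
    def __init__(self, name: str, outputs: 'list[Layer] | None' = None):
        self.name = name
        self.class_name = type(self).__name__
        self.outputs: 'list[Layer]' = outputs or []


class Input(Layer): ...
class Reshape(Layer): ...
class Transpose(Layer): ...
class FixedPointQuantizer(Layer): ...


def get_output_layers(node: Layer) -> 'list[Layer]':
    # Stand-in for hls4ml's helper: the layers consuming `node`'s output.
    return node.outputs


def get_output_quantizers(node: Layer) -> Generator[FixedPointQuantizer, None, None]:
    # Same traversal as the PR: yield quantizers, look through shape-only layers.
    for _node in get_output_layers(node):
        if isinstance(_node, FixedPointQuantizer):
            yield _node
        elif isinstance(_node, (Reshape, Transpose)):
            yield from get_output_quantizers(_node)
        else:
            raise ValueError(f'Layer {node.name} ({node.class_name}) has an unexpected output layer chain.')


# Input -> Reshape -> Transpose -> q1, and Input -> q2 directly:
q1, q2 = FixedPointQuantizer('q1'), FixedPointQuantizer('q2')
inp = Input('inp', outputs=[Reshape('r', outputs=[Transpose('t', outputs=[q1])]), q2])
print([q.name for q in get_output_quantizers(inp)])  # ['q1', 'q2']
```

Raising on any other layer type preserves the guarantee the old `assert` gave (`transform` only ever sees quantizers), while a `Reshape` or `Transpose` between the `Input` and its quantizers no longer trips the check.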