-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathquantizers.py
More file actions
104 lines (90 loc) · 4.09 KB
/
quantizers.py
File metadata and controls
104 lines (90 loc) · 4.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from brevitas.core.zero_point import ZeroZeroPoint
from brevitas.inject import ExtendedInjector
from brevitas.inject.enum import BitWidthImplType
from brevitas.inject.enum import FloatToIntImplType
from brevitas.inject.enum import QuantType
from brevitas.inject.enum import RestrictValueType
from brevitas.inject.enum import ScalingImplType
from brevitas.inject.enum import StatsOp
from brevitas.quant.solver import ActQuantSolver
from brevitas.quant.solver import BiasQuantSolver
from brevitas.quant.solver import WeightQuantSolver
from dependencies import value
class CommonQuant(ExtendedInjector):
    """Base quantizer settings shared by ``CommonWeightQuant``.

    Selects a constant bit width and a constant scale (scale restriction
    ``RestrictValueType.FP``), round-to-nearest float-to-int conversion, a
    zero point of 0, a single (non per-output-channel) scale, and a signed,
    narrow quantization range. The quantization type is derived from
    ``bit_width`` by ``quant_type``.
    """

    # Bit width and scale are both constant implementations.
    bit_width_impl_type = BitWidthImplType.CONST
    scaling_impl_type = ScalingImplType.CONST
    restrict_scaling_type = RestrictValueType.FP
    scaling_per_output_channel = False

    # Rounding, zero point and range of the integer grid.
    float_to_int_impl_type = FloatToIntImplType.ROUND
    zero_point_impl = ZeroZeroPoint
    signed = True
    narrow_range = True

    @value
    def quant_type(bit_width):
        """Map ``bit_width`` to a ``QuantType``.

        ``None`` -> ``QuantType.FP``, ``1`` -> ``QuantType.BINARY``,
        any other width -> ``QuantType.INT``.
        """
        if bit_width is None:
            return QuantType.FP
        if bit_width == 1:
            return QuantType.BINARY
        return QuantType.INT
class CommonWeightQuant(CommonQuant, WeightQuantSolver):
    """Weight quantizer combining ``CommonQuant`` settings with a fixed scale of 1.0."""

    # Value for the constant scale; CommonQuant selects ScalingImplType.CONST.
    scaling_const = 1.0
class WeightPerChannelQuant(WeightQuantSolver):
    """Signed integer weight quantizer with a power-of-two scale per output channel.

    The scale is a parameter initialized from statistics (absmax) of the
    weights. ``bit_width`` is not defined on this class, so it presumably has
    to be supplied by whoever uses the quantizer (e.g. a layer keyword
    argument) -- TODO confirm.
    """

    quant_type = QuantType.INT  # integer quantization
    bit_width_impl_type = BitWidthImplType.CONST  # constant bit width
    float_to_int_impl_type = FloatToIntImplType.ROUND  # round to nearest
    # scale is a parameter initialized from statistics
    scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS
    scaling_stats_op = StatsOp.MAX  # scale statistics is the absmax value
    restrict_scaling_type = RestrictValueType.POWER_OF_TWO  # power-of-two scale
    scaling_per_output_channel = True  # one scale per output channel
    signed = True  # quantization range is signed
    # narrow_range=False keeps the full signed range, i.e. [-128, 127] at
    # 8 bits, rather than the narrow [-127, 127].
    narrow_range = False
    zero_point_impl = ZeroZeroPoint  # zero point is 0.
class WeightPerTensorQuant(WeightQuantSolver):
    """Signed integer weight quantizer with a single power-of-two scale per tensor.

    Identical to ``WeightPerChannelQuant`` except that one scale is shared by
    the whole tensor. The scale is a parameter initialized from statistics
    (absmax) of the weights. ``bit_width`` is not defined on this class, so it
    presumably has to be supplied by whoever uses the quantizer -- TODO confirm.
    """

    quant_type = QuantType.INT  # integer quantization
    bit_width_impl_type = BitWidthImplType.CONST  # constant bit width
    float_to_int_impl_type = FloatToIntImplType.ROUND  # round to nearest
    # scale is a parameter initialized from statistics
    scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS
    scaling_stats_op = StatsOp.MAX  # scale statistics is the absmax value
    restrict_scaling_type = RestrictValueType.POWER_OF_TWO  # power-of-two scale
    scaling_per_output_channel = False  # scale is per tensor
    signed = True  # quantization range is signed
    # narrow_range=False keeps the full signed range, i.e. [-128, 127] at
    # 8 bits, rather than the narrow [-127, 127].
    narrow_range = False
    zero_point_impl = ZeroZeroPoint  # zero point is 0.
class IntBiasQuant(BiasQuantSolver):
    """Signed integer bias quantizer with a per-tensor power-of-two scale parameter.

    Configured with ``requires_input_scale`` and ``requires_input_bit_width``
    set to False, so it uses its own ``bit_width`` (8 by default) and a scale
    parameter initialized from ``scaling_init``.
    """

    quant_type = QuantType.INT  # integer quantization
    bit_width_impl_type = BitWidthImplType.CONST  # constant bit width
    float_to_int_impl_type = FloatToIntImplType.ROUND  # round to nearest
    # scale is a parameter initialized from ``scaling_init``
    scaling_impl_type = ScalingImplType.PARAMETER
    restrict_scaling_type = RestrictValueType.POWER_OF_TWO  # power-of-two scale
    requires_input_bit_width = False
    requires_input_scale = False
    scaling_per_output_channel = False  # scale is per tensor
    bit_width = 8
    signed = True  # quantization range is signed
    # narrow_range=False keeps the full signed range, i.e. [-128, 127] at
    # 8 bits, rather than the narrow [-127, 127].
    narrow_range = False
    zero_point_impl = ZeroZeroPoint  # zero point is 0.

    @value
    def scaling_init(bit_width):
        """Return the initial scaling value, ``2 ** (bit_width - 1)``.

        With the default ``bit_width = 8`` this is 128, as before. Resolving it
        from the injected ``bit_width`` (instead of a class-body expression
        evaluated once at class creation) keeps the relationship intact when
        ``bit_width`` is overridden on the injector.
        NOTE(review): presumably chosen so the effective scale is 1
        (integer-valued bias) for a signed, non-narrow range -- confirm
        against the brevitas scaling implementation.
        """
        return 2 ** (bit_width - 1)
class IntActQuant(ActQuantSolver):
    """Signed integer activation quantizer with a constant power-of-two scale.

    Neither ``bit_width`` nor the constant scale value is defined on this
    class; they presumably come from the quantized layer's keyword arguments
    -- TODO confirm. The quantization type follows from ``bit_width`` via
    ``quant_type``.
    """

    bit_width_impl_type = BitWidthImplType.CONST  # constant bit width
    float_to_int_impl_type = FloatToIntImplType.ROUND  # round to nearest
    # scale is a fixed constant (not a parameter initialized from statistics)
    scaling_impl_type = ScalingImplType.CONST
    restrict_scaling_type = RestrictValueType.POWER_OF_TWO  # power-of-two scale
    scaling_per_output_channel = False  # scale is per tensor
    signed = True  # quantization range is signed
    narrow_range = False  # quantization range is [-128, 127] rather than [-127, 127]
    zero_point_impl = ZeroZeroPoint  # zero point is 0.

    @value
    def quant_type(bit_width):
        """Map ``bit_width`` to a ``QuantType``.

        ``None`` -> ``QuantType.FP``, ``1`` -> ``QuantType.BINARY``,
        any other width -> ``QuantType.INT``.
        """
        if bit_width is None:
            return QuantType.FP
        elif bit_width == 1:
            return QuantType.BINARY
        else:
            return QuantType.INT