diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py
index aa9e1caab..98a8c815e 100644
--- a/src/llmcompressor/observers/base.py
+++ b/src/llmcompressor/observers/base.py
@@ -249,6 +249,14 @@ def get_qparams(
                         self._scale[i, j] = scale_bp
                         self._zero_point[i, j] = zp_bp
 
+            elif self.quantization_args.strategy == QuantizationStrategy.ATTN_HEAD:
+                # observed.shape = [batch, num_kv_heads, tokens, head_dim]
+                # tested only for GQA models, add support for others as needed
+                self._scale, self._zero_point = self.get_qparams_along_dim(
+                    observed,
+                    dim=1,
+                )
+
         return self._scale, self._zero_point
 
     def get_qparams_along_dim(
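
For context, here is a minimal standalone sketch of what computing qparams along `dim=1` means for an observed tensor shaped `[batch, num_kv_heads, tokens, head_dim]`: reduce min/max over every dimension except the head dimension so each KV head gets its own scale and zero point. This is not the llmcompressor implementation; the helper name `compute_per_head_qparams` and the generic asymmetric uint8 math are illustrative assumptions only.

```python
# Illustrative only: per-head (dim=1) min/max reduction for asymmetric 8-bit qparams.
# NOT the llmcompressor implementation; it just shows the reduction pattern the
# ATTN_HEAD strategy relies on.
import torch

def compute_per_head_qparams(observed: torch.Tensor, num_bits: int = 8):
    # observed: [batch, num_kv_heads, tokens, head_dim]
    # Reduce over all dims except dim=1 so each KV head gets one scale / zero point.
    reduce_dims = [d for d in range(observed.ndim) if d != 1]
    min_vals = torch.amin(observed, dim=reduce_dims)  # shape: [num_kv_heads]
    max_vals = torch.amax(observed, dim=reduce_dims)  # shape: [num_kv_heads]

    qmin, qmax = 0, 2**num_bits - 1
    scale = (max_vals - min_vals).clamp(min=1e-8) / (qmax - qmin)
    zero_point = (qmin - min_vals / scale).round().clamp(qmin, qmax).to(torch.int32)
    return scale, zero_point

# Example: 2 KV heads -> 2 scales and 2 zero points
observed = torch.randn(1, 2, 16, 64)
scale, zp = compute_per_head_qparams(observed)
print(scale.shape, zp.shape)  # torch.Size([2]) torch.Size([2])
```

The diff itself delegates this reduction to the existing `get_qparams_along_dim` helper with `dim=1`, so the observer only needs to know which axis holds the KV heads.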