Skip to content
This repository was archived by the owner on Jan 30, 2025. It is now read-only.

Commit 05798b6

Browse files
authored
update checking logic for scale and zp op (#83)
For fake_quantize_ops, the `scale` and `zero_point` operands may already have been lowered by `TorchToTcp`, so the checking condition here is relaxed: instead of requiring a specific defining op, we only need to verify that the operand's type and shape are valid.
1 parent a4fba88 commit 05798b6

File tree

1 file changed

+19
-57
lines changed

1 file changed

+19
-57
lines changed

lib/Conversion/TorchToTcp/TcpCustomOp.cpp

Lines changed: 19 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -154,39 +154,21 @@ class ConvertAtenFakeQuantizePerTensorAffineTensorQparamsOp
154154
helper.addIntAttr("quant_max", op.getQuantMax());
155155

156156
// scale
157-
auto scaleOp = op.getScale().getDefiningOp();
158-
if (!scaleOp)
159-
return rewriter.notifyMatchFailure(op, "Missing scale operation");
160-
auto scaleTensor = dyn_cast<torch::Torch::ValueTensorLiteralOp>(scaleOp);
161-
if (!scaleTensor)
162-
return rewriter.notifyMatchFailure(
163-
op, "Scale operation is not ValueTensorLiteralOp");
164-
auto scaleElements =
165-
dyn_cast<DenseFPElementsAttr>(scaleTensor.getValueAttr());
166-
// scale should be a [1] tensor.
167-
if (!scaleElements || scaleElements.getNumElements() != 1)
157+
auto scaleTy = adaptor.getScale().getType().dyn_cast<RankedTensorType>();
158+
if (!scaleTy || scaleTy.getShape().size() != 1 ||
159+
scaleTy.getNumElements() != 1)
160+
// scale should be a [1] tensor.
168161
return rewriter.notifyMatchFailure(op, "Unsupported scale type or size");
169162
helper.addOperand("scale", adaptor.getScale());
170163

171164
// zero_point
172-
auto zeroPointOp = op.getZeroPoint().getDefiningOp();
173-
if (!zeroPointOp)
174-
return rewriter.notifyMatchFailure(op, "Missing zero point operation");
175-
if (auto zeroPointTensor =
176-
dyn_cast<torch::Torch::ValueTensorLiteralOp>(zeroPointOp)) {
177-
auto zeroPointElements =
178-
dyn_cast<DenseIntElementsAttr>(zeroPointTensor.getValueAttr());
165+
auto zeroPointTy =
166+
adaptor.getZeroPoint().getType().dyn_cast<RankedTensorType>();
167+
if (!zeroPointTy || zeroPointTy.getShape().size() != 1 ||
168+
zeroPointTy.getNumElements() != scaleTy.getNumElements())
179169
// zero_point should be a [1] tensor.
180-
if (!zeroPointElements || zeroPointElements.getNumElements() != 1)
181-
return rewriter.notifyMatchFailure(
182-
op, "Unsupported zero point type or size");
183-
} else if (!dyn_cast<torch::Torch::AtenZerosOp>(zeroPointOp) &&
184-
!dyn_cast<torch::Torch::AtenZerosLikeOp>(zeroPointOp)) {
185-
// zero like operations are converted through torch-to-tcp
186-
return rewriter.notifyMatchFailure(
187-
op, "Zero point operation is not ValueTensorLiteralOp or Zero "
188-
"operation");
189-
}
170+
return rewriter.notifyMatchFailure(op,
171+
"Unsupported zero point type or size");
190172
helper.addOperand("zero_point", adaptor.getZeroPoint());
191173

192174
return helper.replace();
@@ -209,40 +191,20 @@ class ConvertAtenFakeQuantizePerChannelAffineOp
209191
helper.addIntAttr("quant_max", op.getQuantMax());
210192

211193
// scale
212-
auto scaleOp = op.getScale().getDefiningOp();
213-
if (!scaleOp)
214-
return rewriter.notifyMatchFailure(op, "Missing scale operation");
215-
auto scaleTensor = dyn_cast<torch::Torch::ValueTensorLiteralOp>(scaleOp);
216-
if (!scaleTensor)
217-
return rewriter.notifyMatchFailure(
218-
op, "Scale operation is not ValueTensorLiteralOp");
219-
auto scaleElements =
220-
dyn_cast<DenseFPElementsAttr>(scaleTensor.getValueAttr());
221-
// scale should be a [C] tensor.
222-
if (!scaleElements || scaleElements.getType().getShape().size() != 1)
194+
auto scaleTy = adaptor.getScale().getType().dyn_cast<RankedTensorType>();
195+
if (!scaleTy || scaleTy.getShape().size() != 1)
196+
// scale should be a [C] tensor.
223197
return rewriter.notifyMatchFailure(op, "Unsupported scale type or size");
224198
helper.addOperand("scale", adaptor.getScale());
225199

226200
// zero_point
227-
auto zeroPointOp = op.getZeroPoint().getDefiningOp();
228-
if (!zeroPointOp)
229-
return rewriter.notifyMatchFailure(op, "Missing zero point operation");
230-
if (auto zeroPointTensor =
231-
dyn_cast<torch::Torch::ValueTensorLiteralOp>(zeroPointOp)) {
232-
auto zeroPointElements =
233-
dyn_cast<DenseIntElementsAttr>(zeroPointTensor.getValueAttr());
201+
auto zeroPointTy =
202+
adaptor.getZeroPoint().getType().dyn_cast<RankedTensorType>();
203+
if (!zeroPointTy || zeroPointTy.getShape().size() != 1 ||
204+
zeroPointTy.getNumElements() != scaleTy.getNumElements())
234205
// zero_point should be a [C] tensor.
235-
if (!zeroPointElements ||
236-
zeroPointElements.getType().getShape().size() != 1)
237-
return rewriter.notifyMatchFailure(
238-
op, "Unsupported zero point type or size");
239-
} else if (!dyn_cast<torch::Torch::AtenZerosOp>(zeroPointOp) &&
240-
!dyn_cast<torch::Torch::AtenZerosLikeOp>(zeroPointOp)) {
241-
// zero like operations are converted through torch-to-tcp
242-
return rewriter.notifyMatchFailure(
243-
op, "Zero point operation is not ValueTensorLiteralOp or Zero "
244-
"operation");
245-
}
206+
return rewriter.notifyMatchFailure(op,
207+
"Unsupported zero point type or size");
246208
helper.addOperand("zero_point", adaptor.getZeroPoint());
247209

248210
return helper.replace();

0 commit comments

Comments
 (0)