@@ -154,39 +154,21 @@ class ConvertAtenFakeQuantizePerTensorAffineTensorQparamsOp
154
154
helper.addIntAttr (" quant_max" , op.getQuantMax ());
155
155
156
156
// scale
157
- auto scaleOp = op.getScale ().getDefiningOp ();
158
- if (!scaleOp)
159
- return rewriter.notifyMatchFailure (op, " Missing scale operation" );
160
- auto scaleTensor = dyn_cast<torch::Torch::ValueTensorLiteralOp>(scaleOp);
161
- if (!scaleTensor)
162
- return rewriter.notifyMatchFailure (
163
- op, " Scale operation is not ValueTensorLiteralOp" );
164
- auto scaleElements =
165
- dyn_cast<DenseFPElementsAttr>(scaleTensor.getValueAttr ());
166
- // scale should be a [1] tensor.
167
- if (!scaleElements || scaleElements.getNumElements () != 1 )
157
+ auto scaleTy = adaptor.getScale ().getType ().dyn_cast <RankedTensorType>();
158
+ if (!scaleTy || scaleTy.getShape ().size () != 1 ||
159
+ scaleTy.getNumElements () != 1 )
160
+ // scale should be a [1] tensor.
168
161
return rewriter.notifyMatchFailure (op, " Unsupported scale type or size" );
169
162
helper.addOperand (" scale" , adaptor.getScale ());
170
163
171
164
// zero_point
172
- auto zeroPointOp = op.getZeroPoint ().getDefiningOp ();
173
- if (!zeroPointOp)
174
- return rewriter.notifyMatchFailure (op, " Missing zero point operation" );
175
- if (auto zeroPointTensor =
176
- dyn_cast<torch::Torch::ValueTensorLiteralOp>(zeroPointOp)) {
177
- auto zeroPointElements =
178
- dyn_cast<DenseIntElementsAttr>(zeroPointTensor.getValueAttr ());
165
+ auto zeroPointTy =
166
+ adaptor.getZeroPoint ().getType ().dyn_cast <RankedTensorType>();
167
+ if (!zeroPointTy || zeroPointTy.getShape ().size () != 1 ||
168
+ zeroPointTy.getNumElements () != scaleTy.getNumElements ())
179
169
// zero_point should be a [1] tensor.
180
- if (!zeroPointElements || zeroPointElements.getNumElements () != 1 )
181
- return rewriter.notifyMatchFailure (
182
- op, " Unsupported zero point type or size" );
183
- } else if (!dyn_cast<torch::Torch::AtenZerosOp>(zeroPointOp) &&
184
- !dyn_cast<torch::Torch::AtenZerosLikeOp>(zeroPointOp)) {
185
- // zero like operations are converted through torch-to-tcp
186
- return rewriter.notifyMatchFailure (
187
- op, " Zero point operation is not ValueTensorLiteralOp or Zero "
188
- " operation" );
189
- }
170
+ return rewriter.notifyMatchFailure (op,
171
+ " Unsupported zero point type or size" );
190
172
helper.addOperand (" zero_point" , adaptor.getZeroPoint ());
191
173
192
174
return helper.replace ();
@@ -209,40 +191,20 @@ class ConvertAtenFakeQuantizePerChannelAffineOp
209
191
helper.addIntAttr (" quant_max" , op.getQuantMax ());
210
192
211
193
// scale
212
- auto scaleOp = op.getScale ().getDefiningOp ();
213
- if (!scaleOp)
214
- return rewriter.notifyMatchFailure (op, " Missing scale operation" );
215
- auto scaleTensor = dyn_cast<torch::Torch::ValueTensorLiteralOp>(scaleOp);
216
- if (!scaleTensor)
217
- return rewriter.notifyMatchFailure (
218
- op, " Scale operation is not ValueTensorLiteralOp" );
219
- auto scaleElements =
220
- dyn_cast<DenseFPElementsAttr>(scaleTensor.getValueAttr ());
221
- // scale should be a [C] tensor.
222
- if (!scaleElements || scaleElements.getType ().getShape ().size () != 1 )
194
+ auto scaleTy = adaptor.getScale ().getType ().dyn_cast <RankedTensorType>();
195
+ if (!scaleTy || scaleTy.getShape ().size () != 1 )
196
+ // scale should be a [C] tensor.
223
197
return rewriter.notifyMatchFailure (op, " Unsupported scale type or size" );
224
198
helper.addOperand (" scale" , adaptor.getScale ());
225
199
226
200
// zero_point
227
- auto zeroPointOp = op.getZeroPoint ().getDefiningOp ();
228
- if (!zeroPointOp)
229
- return rewriter.notifyMatchFailure (op, " Missing zero point operation" );
230
- if (auto zeroPointTensor =
231
- dyn_cast<torch::Torch::ValueTensorLiteralOp>(zeroPointOp)) {
232
- auto zeroPointElements =
233
- dyn_cast<DenseIntElementsAttr>(zeroPointTensor.getValueAttr ());
201
+ auto zeroPointTy =
202
+ adaptor.getZeroPoint ().getType ().dyn_cast <RankedTensorType>();
203
+ if (!zeroPointTy || zeroPointTy.getShape ().size () != 1 ||
204
+ zeroPointTy.getNumElements () != scaleTy.getNumElements ())
234
205
// zero_point should be a [C] tensor.
235
- if (!zeroPointElements ||
236
- zeroPointElements.getType ().getShape ().size () != 1 )
237
- return rewriter.notifyMatchFailure (
238
- op, " Unsupported zero point type or size" );
239
- } else if (!dyn_cast<torch::Torch::AtenZerosOp>(zeroPointOp) &&
240
- !dyn_cast<torch::Torch::AtenZerosLikeOp>(zeroPointOp)) {
241
- // zero like operations are converted through torch-to-tcp
242
- return rewriter.notifyMatchFailure (
243
- op, " Zero point operation is not ValueTensorLiteralOp or Zero "
244
- " operation" );
245
- }
206
+ return rewriter.notifyMatchFailure (op,
207
+ " Unsupported zero point type or size" );
246
208
helper.addOperand (" zero_point" , adaptor.getZeroPoint ());
247
209
248
210
return helper.replace ();
0 commit comments