@@ -244,25 +244,41 @@ template <class AttrElementT,
244
244
Attribute constFoldBinaryOp (ArrayRef<Attribute> operands,
245
245
const CalculationT &calculate) {
246
246
assert (operands.size () == 2 && " binary op takes two operands" );
247
+ if (!operands[0 ] || !operands[1 ])
248
+ return {};
249
+ if (operands[0 ].getType () != operands[1 ].getType ())
250
+ return {};
247
251
248
- if (auto lhs = operands[0 ].dyn_cast_or_null <AttrElementT>()) {
249
- auto rhs = operands[1 ].dyn_cast_or_null <AttrElementT>();
250
- if (!rhs || lhs.getType () != rhs.getType ())
251
- return {};
252
+ if (operands[0 ].isa <AttrElementT>() && operands[1 ].isa <AttrElementT>()) {
253
+ auto lhs = operands[0 ].cast <AttrElementT>();
254
+ auto rhs = operands[1 ].cast <AttrElementT>();
252
255
253
256
return AttrElementT::get (lhs.getType (),
254
257
calculate (lhs.getValue (), rhs.getValue ()));
255
- } else if (auto lhs = operands[0 ].dyn_cast_or_null <SplatElementsAttr>()) {
256
- auto rhs = operands[1 ].dyn_cast_or_null <SplatElementsAttr>();
257
- if (!rhs || lhs.getType () != rhs.getType ())
258
- return {};
259
-
260
- auto elementResult = constFoldBinaryOp<AttrElementT>(
261
- {lhs.getSplatValue (), rhs.getSplatValue ()}, calculate);
262
- if (!elementResult)
263
- return {};
264
-
258
+ } else if (operands[0 ].isa <SplatElementsAttr>() &&
259
+ operands[1 ].isa <SplatElementsAttr>()) {
260
+ // Both operands are splats so we can avoid expanding the values out and
261
+ // just fold based on the splat value.
262
+ auto lhs = operands[0 ].cast <SplatElementsAttr>();
263
+ auto rhs = operands[1 ].cast <SplatElementsAttr>();
264
+
265
+ auto elementResult = calculate (lhs.getSplatValue <ElementValueT>(),
266
+ rhs.getSplatValue <ElementValueT>());
265
267
return DenseElementsAttr::get (lhs.getType (), elementResult);
268
+ } else if (operands[0 ].isa <ElementsAttr>() &&
269
+ operands[1 ].isa <ElementsAttr>()) {
270
+ // Operands are ElementsAttr-derived; perform an element-wise fold by
271
+ // expanding the values.
272
+ auto lhs = operands[0 ].cast <ElementsAttr>();
273
+ auto rhs = operands[1 ].cast <ElementsAttr>();
274
+
275
+ auto lhsIt = lhs.getValues <ElementValueT>().begin ();
276
+ auto rhsIt = rhs.getValues <ElementValueT>().begin ();
277
+ SmallVector<ElementValueT, 4 > elementResults;
278
+ elementResults.reserve (lhs.getNumElements ());
279
+ for (size_t i = 0 , e = lhs.getNumElements (); i < e; ++i, ++lhsIt, ++rhsIt)
280
+ elementResults.push_back (calculate (*lhsIt, *rhsIt));
281
+ return DenseElementsAttr::get (lhs.getType (), elementResults);
266
282
}
267
283
return {};
268
284
}
0 commit comments