Skip to content
This repository was archived by the owner on Apr 23, 2021. It is now read-only.

Commit d342ff9

Browse files
benvanik authored and tensorflower-gardener committed
Support folding of StandardOps with DenseElementsAttr.
PiperOrigin-RevId: 282270243
1 parent c3ae3ae commit d342ff9

File tree

2 files changed

+58
-14
lines changed

2 files changed

+58
-14
lines changed

lib/Dialect/StandardOps/Ops.cpp

Lines changed: 30 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -244,25 +244,41 @@ template <class AttrElementT,
244244
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
245245
const CalculationT &calculate) {
246246
assert(operands.size() == 2 && "binary op takes two operands");
247+
if (!operands[0] || !operands[1])
248+
return {};
249+
if (operands[0].getType() != operands[1].getType())
250+
return {};
247251

248-
if (auto lhs = operands[0].dyn_cast_or_null<AttrElementT>()) {
249-
auto rhs = operands[1].dyn_cast_or_null<AttrElementT>();
250-
if (!rhs || lhs.getType() != rhs.getType())
251-
return {};
252+
if (operands[0].isa<AttrElementT>() && operands[1].isa<AttrElementT>()) {
253+
auto lhs = operands[0].cast<AttrElementT>();
254+
auto rhs = operands[1].cast<AttrElementT>();
252255

253256
return AttrElementT::get(lhs.getType(),
254257
calculate(lhs.getValue(), rhs.getValue()));
255-
} else if (auto lhs = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
256-
auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>();
257-
if (!rhs || lhs.getType() != rhs.getType())
258-
return {};
259-
260-
auto elementResult = constFoldBinaryOp<AttrElementT>(
261-
{lhs.getSplatValue(), rhs.getSplatValue()}, calculate);
262-
if (!elementResult)
263-
return {};
264-
258+
} else if (operands[0].isa<SplatElementsAttr>() &&
259+
operands[1].isa<SplatElementsAttr>()) {
260+
// Both operands are splats so we can avoid expanding the values out and
261+
// just fold based on the splat value.
262+
auto lhs = operands[0].cast<SplatElementsAttr>();
263+
auto rhs = operands[1].cast<SplatElementsAttr>();
264+
265+
auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
266+
rhs.getSplatValue<ElementValueT>());
265267
return DenseElementsAttr::get(lhs.getType(), elementResult);
268+
} else if (operands[0].isa<ElementsAttr>() &&
269+
operands[1].isa<ElementsAttr>()) {
270+
// Operands are ElementsAttr-derived; perform an element-wise fold by
271+
// expanding the values.
272+
auto lhs = operands[0].cast<ElementsAttr>();
273+
auto rhs = operands[1].cast<ElementsAttr>();
274+
275+
auto lhsIt = lhs.getValues<ElementValueT>().begin();
276+
auto rhsIt = rhs.getValues<ElementValueT>().begin();
277+
SmallVector<ElementValueT, 4> elementResults;
278+
elementResults.reserve(lhs.getNumElements());
279+
for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt)
280+
elementResults.push_back(calculate(*lhsIt, *rhsIt));
281+
return DenseElementsAttr::get(lhs.getType(), elementResults);
266282
}
267283
return {};
268284
}

test/Transforms/constant-fold.mlir

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,34 @@ func @addf_splat_tensor() -> tensor<4xf32> {
5050

5151
// -----
5252

53+
// CHECK-LABEL: func @addf_dense_tensor
54+
func @addf_dense_tensor() -> tensor<4xf32> {
55+
%0 = constant dense<[1.5, 2.5, 3.5, 4.5]> : tensor<4xf32>
56+
%1 = constant dense<[1.5, 2.5, 3.5, 4.5]> : tensor<4xf32>
57+
58+
// CHECK-NEXT: [[C:%.+]] = constant dense<[3.{{0*}}e+00, 5.{{0*}}e+00, 7.{{0*}}e+00, 9.{{0*}}e+00]> : tensor<4xf32>
59+
%2 = addf %0, %1 : tensor<4xf32>
60+
61+
// CHECK-NEXT: return [[C]]
62+
return %2 : tensor<4xf32>
63+
}
64+
65+
// -----
66+
67+
// CHECK-LABEL: func @addf_dense_and_splat_tensors
68+
func @addf_dense_and_splat_tensors() -> tensor<4xf32> {
69+
%0 = constant dense<[1.5, 2.5, 3.5, 4.5]> : tensor<4xf32>
70+
%1 = constant dense<1.5> : tensor<4xf32>
71+
72+
// CHECK-NEXT: [[C:%.+]] = constant dense<[3.{{0*}}e+00, 4.{{0*}}e+00, 5.{{0*}}e+00, 6.{{0*}}e+00]> : tensor<4xf32>
73+
%2 = addf %0, %1 : tensor<4xf32>
74+
75+
// CHECK-NEXT: return [[C]]
76+
return %2 : tensor<4xf32>
77+
}
78+
79+
// -----
80+
5381
// CHECK-LABEL: func @simple_addi
5482
func @simple_addi() -> i32 {
5583
%0 = constant 1 : i32

0 commit comments

Comments
 (0)