diff --git a/yateto/ast/node.py b/yateto/ast/node.py index 951ccfc..56d611a 100644 --- a/yateto/ast/node.py +++ b/yateto/ast/node.py @@ -280,7 +280,10 @@ def computeSparsityPattern(self, *spps): def nonZeroFlops(self): nzFlops = 0 for child in self: - nzFlops += child.eqspp().count_nonzero() + permuted = self.broadcast(child.indices, self.permute(child.indices, child.eqspp(), False)) + nzFlops += permuted.count_nonzero() + + # ignore all first adds against zero (i.e. those in self.eqspp()) return nzFlops - self.eqspp().count_nonzero() class UnaryOp(Op):