
Commit 6903f28

More Passes updated to use new interface
Differential Revision: D86919809
Pull Request resolved: #16026
1 parent 3b1aeda commit 6903f28
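
Every pass touched in this commit moves from overriding ExportPass.call_operator to the new RemoveOrReplacePassInterface: the pass declares the edge ops it handles through a targets property and performs its graph surgery in maybe_remove_or_replace, returning True only when it changed the graph. The skeleton below is a minimal sketch of that shape, distilled from the diffs that follow; it is not code from this commit. It assumes the base class visits every node whose target appears in targets and calls maybe_remove_or_replace on it, and it leaves the imports for RemoveOrReplacePassInterface, EdgeOpOverload, register_cadence_pass and CadencePassAttribute to whatever replace_ops.py already uses.

```python
import torch

from executorch.exir.dialects._ops import ops as exir_ops

# RemoveOrReplacePassInterface, EdgeOpOverload, register_cadence_pass and
# CadencePassAttribute are assumed to come from the Cadence pass utilities
# that backends/cadence/aot/replace_ops.py already imports.


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplaceFooWithBarPass(RemoveOrReplacePassInterface):  # hypothetical example pass
    """Illustrative skeleton only, not code from this commit."""

    @property
    def targets(self) -> list[EdgeOpOverload]:
        # Edge ops this pass wants to visit.
        return [exir_ops.edge.aten.permute_copy.default]

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        # Build the replacement node in place, copy the metadata over, and
        # reroute every user of the old node to the new one.
        with node.graph.inserting_before(node):
            new_node = node.graph.call_function(
                exir_ops.edge.aten.transpose_copy.int,
                args=(node.args[0], 0, 1),  # illustrative dims only
            )
        new_node.meta = node.meta
        node.replace_all_uses_with(new_node)
        return True  # True means the graph changed; False leaves the node alone
```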

File tree

3 files changed: +193 / -97 lines changed


backends/cadence/aot/ops_registrations.py

Lines changed: 22 additions & 6 deletions
@@ -2178,15 +2178,31 @@ def conv1d_meta(
     dilation: Tuple[int],
     groups: int,
 ) -> torch.Tensor:
+    # Validate tensor dimensions
+    assert len(input.shape) == 3, f"Conv1d expects 3D input, got {len(input.shape)}D"
+    assert len(weight.shape) == 3, f"Conv1d expects 3D weight, got {len(weight.shape)}D"
+
+    # Extract dimensions
+    batch_size, in_channels, length = input.shape
+    out_channels, weight_in_channels, kernel_size = weight.shape
+
+    # Validate groups parameter and channel consistency
+    assert groups > 0, f"groups must be positive, got {groups}"
     assert (
-        len(weight.shape) == 3
-    ), f"Conv1d expects a 3D weight, got {len(weight.shape)}D"
-    out_channels, _, kernel_size = weight.shape
-    in_size = input.shape
-    assert len(in_size) == 3, f"conv1d expects 3D input, got {len(in_size)}D"
+        in_channels % groups == 0
+    ), f"in_channels ({in_channels}) must be divisible by groups ({groups})"
+    assert (
+        out_channels % groups == 0
+    ), f"out_channels ({out_channels}) must be divisible by groups ({groups})"
+
+    # Validate weight channels match input channels divided by groups
+    expected_weight_in_channels = in_channels // groups
+    assert (
+        weight_in_channels == expected_weight_in_channels
+    ), f"Expected weight to have {expected_weight_in_channels} input channels (in_channels/groups), but got {weight_in_channels}"
 
     output_size = get_conv1d_output_size(
-        in_size,
+        input.shape,
         out_channels,
         stride[0],
         padding[0],
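
The new checks in conv1d_meta encode the usual grouped-conv1d shape contract: both channel counts must be divisible by groups, and the weight's second dimension must equal in_channels // groups. A small numeric illustration with plain torch; the shapes are made up for the example and are not taken from this commit:

```python
import torch

# Illustrative grouped conv1d: batch=2, in_channels=8, length=16, groups=2.
x = torch.randn(2, 8, 16)
w = torch.randn(6, 4, 3)  # (out_channels=6, in_channels // groups=4, kernel_size=3)
groups = 2

in_channels, out_channels = x.shape[1], w.shape[0]
assert in_channels % groups == 0 and out_channels % groups == 0
assert w.shape[1] == in_channels // groups  # expected_weight_in_channels in the meta fn

y = torch.nn.functional.conv1d(x, w, groups=groups)
print(y.shape)  # torch.Size([2, 6, 14]) with stride=1, padding=0
```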

backends/cadence/aot/replace_ops.py

Lines changed: 110 additions & 75 deletions
@@ -423,98 +423,120 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class ReplacePermuteWithTransposePass(ExportPass):
+class ReplacePermuteWithTransposePass(RemoveOrReplacePassInterface):
     """
     Replace permute op with transpose if the permutation is only along
     two dimensions.
     """
 
-    def call_operator(self, op, args, kwargs, meta):
-        if op != exir_ops.edge.aten.permute_copy.default:
-            return super().call_operator(op, args, kwargs, meta)
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.permute_copy.default]
 
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
         # Get the old dim and new dim order
-        in_tensor = args[0].to_tensor()
-        old_dims = tuple(range(in_tensor.dim()))
-        new_dims = args[1]
+        in_tensor = node.args[0]
+        assert isinstance(in_tensor, torch.fx.Node)
+        in_shape = in_tensor.meta["val"].shape
+        old_dims = tuple(range(len(in_shape)))
+        new_dims = cast(Sequence[int], node.args[1])
 
         # Compute the number of positions in which the old and new order differ
         diff = [od for od, nd in zip(old_dims, new_dims) if od != nd]
 
+        # If the difference is zero, replace with identity (just the input)
+        if len(diff) == 0:
+            node.replace_all_uses_with(in_tensor)
+            return True
+
         # If the difference is in two dimensions, we can replace this permute op
         # with transpose op.
         if len(diff) == 2:
-            new_args = (args[0], diff[0], diff[1])
-            return super().call_operator(
-                exir_ops.edge.aten.transpose_copy.int, new_args, kwargs, meta
-            )
+            with node.graph.inserting_before(node):
+                new_node = node.graph.call_function(
+                    exir_ops.edge.aten.transpose_copy.int,
+                    args=(node.args[0], diff[0], diff[1]),
+                )
+            new_node.meta = node.meta
+            node.replace_all_uses_with(new_node)
+            return True
 
-        return (
-            args[0] if len(diff) == 0 else super().call_operator(op, args, kwargs, meta)
-        )
+        return False
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplaceConvolutionOptionalArgsWithConcreteArgsPass(ExportPass):
+class ReplaceConvolutionOptionalArgsWithConcreteArgsPass(RemoveOrReplacePassInterface):
     """
     Replace optional tensors with concrete tensors. Currently, we
     replace the optional bias tensor with a zero tensor.
     """
 
-    def call_operator(self, op, args, kwargs, meta):
-        op_packet = get_edge_overload_packet(op)
-        if op_packet not in {
-            exir_ops.edge.cadence.conv1d,
-            exir_ops.edge.cadence.conv2d,
-            exir_ops.edge.cadence.conv3d,
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [
+            exir_ops.edge.cadence.conv1d.default,
+            exir_ops.edge.cadence.conv2d.default,
+            exir_ops.edge.cadence.conv3d.default,
             exir_ops.edge.cadence.transposed_convolution,
-        }:
-            return super().call_operator(op, args, kwargs, meta)
+        ]
 
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        # Check if this is a transposed convolution
+        assert isinstance(node.target, EdgeOpOverload)
+        op_packet = get_edge_overload_packet(node.target)
         is_transposed = op_packet == exir_ops.edge.cadence.transposed_convolution
         num_expected_args = 9 if is_transposed else 7
-        assert len(args) == num_expected_args
-        # Check if the bias is already concrete
-        if args[2] is not None:
-            return super().call_operator(op, args, kwargs, meta)
+        assert len(node.args) == num_expected_args
+        # Check if the bias is concrete
+        if node.args[2] is not None:
+            return False
 
         # The bias length is the number of out channels.
-        out_shape = meta["val"].shape
+        out_shape = node.meta["val"].shape
         bias_size = out_shape[1]
         # Create a zero bias tensor (bias is not a constant tensor,
-        # so it needs to be the result of a graph operation).
-        zero_bias = super().call_operator(
-            exir_ops.edge.aten.full.default,
-            ([bias_size], 0.0),
-            {"dtype": torch.float32},
-            meta,
-        )
+        with node.graph.inserting_before(node):
+            zero_bias = node.graph.call_function(
+                exir_ops.edge.aten.full.default,
+                args=([bias_size], 0.0),
+                kwargs={"dtype": torch.float32},
+            )
+            zero_bias.meta = node.meta
+            new_args = list(node.args)
+            new_args[2] = zero_bias
+            new_args = tuple(new_args)
 
-        # Replace bias with zero_bias
-        args = list(args)
-        args[2] = zero_bias
-        args = tuple(args)
+            new_node = node.graph.call_function(
+                # pyre-ignore[6]: Target is a call func, but type is union call func and str
+                node.target,
+                args=new_args,
+                kwargs=node.kwargs,
+            )
+            new_node.meta = node.meta
 
-        return super().call_operator(op, args, kwargs, meta)
+        node.replace_all_uses_with(new_node)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplaceRepeatWithCatPass(ExportPass):
+class ReplaceRepeatWithCatPass(RemoveOrReplacePassInterface):
     """
     Replace repeat op as successive cat ops along different dimensions.
     repeat is not supported, so this is an opt_level=0 pass.
     """
 
-    def call_operator(self, op, args, kwargs, meta):
-        if op != exir_ops.edge.aten.repeat.default:
-            return super().call_operator(op, args, kwargs, meta)
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.repeat.default]
 
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
         # Extract the input tensor, and the repeats from the args
-        in_tensor = args[0]
-        repeats = args[1]
+        in_tensor = node.args[0]
+        assert isinstance(in_tensor, torch.fx.Node)
+        repeats = cast(Sequence[int], node.args[1])
 
         # Glean the shapes of input tensor
-        in_shape = list(in_tensor.to_tensor().shape)
+        in_shape = list(in_tensor.meta["val"].shape)
 
         # If the size of repeats is more than the dimensionality of the tensor,
         # the output of repeat will be a higher-dimensional tensor. We reshape
@@ -524,30 +546,36 @@ def call_operator(self, op, args, kwargs, meta):
             diff >= 0
         ), "Repeat arg malformed: expected a repeat along each dimension of input tensor"
 
+        graph = node.graph
+        result_node = in_tensor
+
         if diff > 0:
             # Extend the input shape with 1's along the higher dimensions
             in_shape = ([1] * diff) + in_shape
             # Insert a view op that reshapes the input tensor to have same
             # dimensionality as the output tensor.
-            in_tensor = super().call_operator(
-                exir_ops.edge.aten.view_copy.default,
-                (in_tensor, in_shape),
-                kwargs,
-                meta,
-            )
+            with graph.inserting_before(node):
+                result_node = graph.call_function(
+                    exir_ops.edge.aten.view_copy.default,
+                    args=(in_tensor, in_shape),
+                )
+                result_node.meta = node.meta
         assert len(repeats) == len(in_shape)
 
         # Repeat op is nothing but successive cat ops along each dimension.
         for dim, repeat in reversed(list(enumerate(repeats))):
             # We do not need to do anything if repeat factor is 1
             if repeat == 1:
                 continue
-            cat_arg = [in_tensor] * repeat
-            in_tensor = super().call_operator(
-                exir_ops.edge.aten.cat.default, (cat_arg, dim), kwargs, meta
-            )
+            cat_arg = [result_node] * repeat
+            with graph.inserting_before(node):
+                result_node = graph.call_function(
+                    exir_ops.edge.aten.cat.default, args=(cat_arg, dim)
+                )
+                result_node.meta = node.meta
 
-        return in_tensor
+        node.replace_all_uses_with(result_node)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
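
The decomposition performed by ReplaceRepeatWithCatPass is unchanged by this commit; only the mechanics change, from call_operator calls to graph.call_function nodes inserted before the repeat node. As a sanity check of the decomposition itself, the loop below mirrors the pass on plain tensors, with made-up shapes:

```python
import torch

# aten.repeat with repeats (2, 3) on a (4, 5) tensor is a chain of cats,
# innermost dimension first, which is exactly what the pass emits.
x = torch.randn(4, 5)
repeats = (2, 3)

out = x
for dim, repeat in reversed(list(enumerate(repeats))):
    if repeat == 1:
        continue
    out = torch.cat([out] * repeat, dim=dim)

assert out.shape == (8, 15)
assert torch.equal(out, x.repeat(*repeats))
```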
@@ -632,41 +660,48 @@ def call_operator(self, op, args, kwargs, meta):
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class ReplaceConstantPadNdWithSlicePass(ExportPass):
+class ReplaceConstantPadNdWithSlicePass(RemoveOrReplacePassInterface):
     """
     Replace constant pad nd op that does padding on outer-most dimension
     with exir_ops slice(left_padding_constant_tensor, X, right_padding_constant_tensor)
     """
 
-    def call_operator(self, op, args, kwargs, meta):
-        if op != exir_ops.edge.aten.constant_pad_nd.default:
-            return super().call_operator(op, args, kwargs, meta)
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.constant_pad_nd.default]
 
-        assert len(args) >= 2
-        input_node, orig_padding = args[:2]
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        assert len(node.args) >= 2
+        input_node = node.args[0]
+        orig_padding = cast(Sequence[int], node.args[1])
+        assert isinstance(input_node, torch.fx.Node)
 
         # if there is no padding, this op will be treated in removal pass.
         if not orig_padding:
-            return super().call_operator(op, args, kwargs, meta)
+            return False
 
-        padding = orig_padding + ([0] * (len(orig_padding) % 2 != 0))
+        padding = list(orig_padding) + ([0] * (len(orig_padding) % 2 != 0))
         assert len(padding) >= 2
+
+        # pyre-ignore[6]
         (start, diff) = map(neg, padding[-2:])
         # Replace only if constant_pad_nd is along the innermost padding dimension.
         if any(x != 0 for x in padding[0:-2]) or start < 0 or diff < 0:
-            return super().call_operator(op, args, kwargs, meta)
+            return False
 
-        arg_shape = input_node.to_tensor().shape
+        arg_shape = input_node.meta["val"].shape
         dim = len(arg_shape) - len(padding) // 2
         stop = arg_shape[dim] - diff
         assert start <= stop
-        new_args = (input_node, dim, start, stop)
-        return super().call_operator(
-            exir_ops.edge.aten.slice.Tensor,
-            new_args,
-            kwargs,
-            meta,
-        )
+
+        with node.graph.inserting_before(node):
+            new_node = node.graph.call_function(
+                exir_ops.edge.aten.slice.Tensor,
+                args=(input_node, dim, start, stop),
+            )
+        new_node.meta = node.meta
+        node.replace_all_uses_with(new_node)
+        return True
 
 
 # Make that pass runnable standalone at opt level 0.
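
The rewritten ReplaceConstantPadNdWithSlicePass above still fires only when every padding pair except the innermost one is zero and the innermost pair is non-positive, i.e. the padding is really a crop. A quick equivalence check with plain torch; shapes and padding values are illustrative only:

```python
import torch
import torch.nn.functional as F

# Negative constant padding on the innermost dimension is a crop,
# which is what the pass rewrites into aten.slice.
x = torch.randn(2, 3, 10)
padding = (-1, -2)  # drop 1 element at the front, 2 at the back of the last dim

start, diff = -padding[0], -padding[1]   # map(neg, padding[-2:]) in the pass
dim = x.dim() - len(padding) // 2        # innermost padded dim, here 2
stop = x.shape[dim] - diff               # 10 - 2 = 8

cropped = F.pad(x, padding)                   # constant_pad_nd with negative pad
sliced = x.narrow(dim, start, stop - start)   # slice(dim, start, stop)
assert torch.equal(cropped, sliced)
```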

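Likewise for ReplacePermuteWithTransposePass, rewritten in the first hunk of replace_ops.py above, the underlying identity is untouched: a permutation that differs from the identity order in exactly two positions is the same as a transpose of those two dimensions, which a few lines of plain torch (with made-up shapes) confirm:

```python
import torch

x = torch.randn(2, 3, 4, 5)
new_dims = (0, 3, 2, 1)

old_dims = tuple(range(x.dim()))
diff = [od for od, nd in zip(old_dims, new_dims) if od != nd]  # [1, 3]
assert len(diff) == 2
# A permute over two swapped positions equals a transpose of those positions.
assert torch.equal(x.permute(*new_dims), x.transpose(diff[0], diff[1]))
```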