@@ -423,98 +423,120 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
423423
424424
425425@register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
426- class ReplacePermuteWithTransposePass (ExportPass ):
426+ class ReplacePermuteWithTransposePass (RemoveOrReplacePassInterface ):
427427 """
428428 Replace permute op with transpose if the permutation is only along
429429 two dimensions.
430430 """
431431
432- def call_operator ( self , op , args , kwargs , meta ):
433- if op != exir_ops . edge . aten . permute_copy . default :
434- return super (). call_operator ( op , args , kwargs , meta )
432+ @ property
433+ def targets ( self ) -> list [ EdgeOpOverload ] :
434+ return [ exir_ops . edge . aten . permute_copy . default ]
435435
436+ def maybe_remove_or_replace (self , node : torch .fx .Node ) -> bool :
436437 # Get the old dim and new dim order
437- in_tensor = args [0 ].to_tensor ()
438- old_dims = tuple (range (in_tensor .dim ()))
439- new_dims = args [1 ]
438+ in_tensor = node .args [0 ]
439+ assert isinstance (in_tensor , torch .fx .Node )
440+ in_shape = in_tensor .meta ["val" ].shape
441+ old_dims = tuple (range (len (in_shape )))
442+ new_dims = cast (Sequence [int ], node .args [1 ])
440443
441444 # Compute the number of positions in which the old and new order differ
442445 diff = [od for od , nd in zip (old_dims , new_dims ) if od != nd ]
443446
447+ # If the difference is zero, replace with identity (just the input)
448+ if len (diff ) == 0 :
449+ node .replace_all_uses_with (in_tensor )
450+ return True
451+
444452 # If the difference is in two dimensions, we can replace this permute op
445453 # with transpose op.
446454 if len (diff ) == 2 :
447- new_args = (args [0 ], diff [0 ], diff [1 ])
448- return super ().call_operator (
449- exir_ops .edge .aten .transpose_copy .int , new_args , kwargs , meta
450- )
455+ with node .graph .inserting_before (node ):
456+ new_node = node .graph .call_function (
457+ exir_ops .edge .aten .transpose_copy .int ,
458+ args = (node .args [0 ], diff [0 ], diff [1 ]),
459+ )
460+ new_node .meta = node .meta
461+ node .replace_all_uses_with (new_node )
462+ return True
451463
452- return (
453- args [0 ] if len (diff ) == 0 else super ().call_operator (op , args , kwargs , meta )
454- )
464+ return False
455465
456466
457467@register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
458- class ReplaceConvolutionOptionalArgsWithConcreteArgsPass (ExportPass ):
468+ class ReplaceConvolutionOptionalArgsWithConcreteArgsPass (RemoveOrReplacePassInterface ):
459469 """
460470 Replace optional tensors with concrete tensors. Currently, we
461471 replace the optional bias tensor with a zero tensor.
462472 """
463473
464- def call_operator ( self , op , args , kwargs , meta ):
465- op_packet = get_edge_overload_packet ( op )
466- if op_packet not in {
467- exir_ops .edge .cadence .conv1d ,
468- exir_ops .edge .cadence .conv2d ,
469- exir_ops .edge .cadence .conv3d ,
474+ @ property
475+ def targets ( self ) -> list [ EdgeOpOverload ]:
476+ return [
477+ exir_ops .edge .cadence .conv1d . default ,
478+ exir_ops .edge .cadence .conv2d . default ,
479+ exir_ops .edge .cadence .conv3d . default ,
470480 exir_ops .edge .cadence .transposed_convolution ,
471- }:
472- return super ().call_operator (op , args , kwargs , meta )
481+ ]
473482
483+ def maybe_remove_or_replace (self , node : torch .fx .Node ) -> bool :
484+ # Check if this is a transposed convolution
485+ assert isinstance (node .target , EdgeOpOverload )
486+ op_packet = get_edge_overload_packet (node .target )
474487 is_transposed = op_packet == exir_ops .edge .cadence .transposed_convolution
475488 num_expected_args = 9 if is_transposed else 7
476- assert len (args ) == num_expected_args
477- # Check if the bias is already concrete
478- if args [2 ] is not None :
479- return super (). call_operator ( op , args , kwargs , meta )
489+ assert len (node . args ) == num_expected_args
490+ # Check if the bias is concrete
491+ if node . args [2 ] is not None :
492+ return False
480493
481494 # The bias length is the number of out channels.
482- out_shape = meta ["val" ].shape
495+ out_shape = node . meta ["val" ].shape
483496 bias_size = out_shape [1 ]
484497 # Create a zero bias tensor (bias is not a constant tensor,
485- # so it needs to be the result of a graph operation).
486- zero_bias = super ().call_operator (
487- exir_ops .edge .aten .full .default ,
488- ([bias_size ], 0.0 ),
489- {"dtype" : torch .float32 },
490- meta ,
491- )
498+ with node .graph .inserting_before (node ):
499+ zero_bias = node .graph .call_function (
500+ exir_ops .edge .aten .full .default ,
501+ args = ([bias_size ], 0.0 ),
502+ kwargs = {"dtype" : torch .float32 },
503+ )
504+ zero_bias .meta = node .meta
505+ new_args = list (node .args )
506+ new_args [2 ] = zero_bias
507+ new_args = tuple (new_args )
492508
493- # Replace bias with zero_bias
494- args = list (args )
495- args [2 ] = zero_bias
496- args = tuple (args )
509+ new_node = node .graph .call_function (
510+ # pyre-ignore[6]: Target is a call func, but type is union call func and str
511+ node .target ,
512+ args = new_args ,
513+ kwargs = node .kwargs ,
514+ )
515+ new_node .meta = node .meta
497516
498- return super ().call_operator (op , args , kwargs , meta )
517+ node .replace_all_uses_with (new_node )
518+ return True
499519
500520
501521@register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
502- class ReplaceRepeatWithCatPass (ExportPass ):
522+ class ReplaceRepeatWithCatPass (RemoveOrReplacePassInterface ):
503523 """
504524 Replace repeat op as successive cat ops along different dimensions.
505525 repeat is not supported, so this is an opt_level=0 pass.
506526 """
507527
508- def call_operator ( self , op , args , kwargs , meta ):
509- if op != exir_ops . edge . aten . repeat . default :
510- return super (). call_operator ( op , args , kwargs , meta )
528+ @ property
529+ def targets ( self ) -> list [ EdgeOpOverload ] :
530+ return [ exir_ops . edge . aten . repeat . default ]
511531
532+ def maybe_remove_or_replace (self , node : torch .fx .Node ) -> bool :
512533 # Extract the input tensor, and the repeats from the args
513- in_tensor = args [0 ]
514- repeats = args [1 ]
534+ in_tensor = node .args [0 ]
535+ assert isinstance (in_tensor , torch .fx .Node )
536+ repeats = cast (Sequence [int ], node .args [1 ])
515537
516538 # Glean the shapes of input tensor
517- in_shape = list (in_tensor .to_tensor () .shape )
539+ in_shape = list (in_tensor .meta [ "val" ] .shape )
518540
519541 # If the size of repeats is more than the dimensionality of the tensor,
520542 # the output of repeat will be a higher-dimensional tensor. We reshape
@@ -524,30 +546,36 @@ def call_operator(self, op, args, kwargs, meta):
524546 diff >= 0
525547 ), "Repeat arg malformed: expected a repeat along each dimension of input tensor"
526548
549+ graph = node .graph
550+ result_node = in_tensor
551+
527552 if diff > 0 :
528553 # Extend the input shape with 1's along the higher dimensions
529554 in_shape = ([1 ] * diff ) + in_shape
530555 # Insert a view op that reshapes the input tensor to have same
531556 # dimensionality as the output tensor.
532- in_tensor = super (). call_operator (
533- exir_ops . edge . aten . view_copy . default ,
534- ( in_tensor , in_shape ) ,
535- kwargs ,
536- meta ,
537- )
557+ with graph . inserting_before ( node ):
558+ result_node = graph . call_function (
559+ exir_ops . edge . aten . view_copy . default ,
560+ args = ( in_tensor , in_shape ) ,
561+ )
562+ result_node . meta = node . meta
538563 assert len (repeats ) == len (in_shape )
539564
540565 # Repeat op is nothing but successive cat ops along each dimension.
541566 for dim , repeat in reversed (list (enumerate (repeats ))):
542567 # We do not need to do anything if repeat factor is 1
543568 if repeat == 1 :
544569 continue
545- cat_arg = [in_tensor ] * repeat
546- in_tensor = super ().call_operator (
547- exir_ops .edge .aten .cat .default , (cat_arg , dim ), kwargs , meta
548- )
570+ cat_arg = [result_node ] * repeat
571+ with graph .inserting_before (node ):
572+ result_node = graph .call_function (
573+ exir_ops .edge .aten .cat .default , args = (cat_arg , dim )
574+ )
575+ result_node .meta = node .meta
549576
550- return in_tensor
577+ node .replace_all_uses_with (result_node )
578+ return True
551579
552580
553581@register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
@@ -632,41 +660,48 @@ def call_operator(self, op, args, kwargs, meta):
632660
633661
634662@register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
635- class ReplaceConstantPadNdWithSlicePass (ExportPass ):
663+ class ReplaceConstantPadNdWithSlicePass (RemoveOrReplacePassInterface ):
636664 """
637665 Replace constant pad nd op that does padding on outer-most dimension
638666 with exir_ops slice(left_padding_constant_tensor, X, right_padding_constant_tensor)
639667 """
640668
641- def call_operator ( self , op , args , kwargs , meta ):
642- if op != exir_ops . edge . aten . constant_pad_nd . default :
643- return super (). call_operator ( op , args , kwargs , meta )
669+ @ property
670+ def targets ( self ) -> list [ EdgeOpOverload ] :
671+ return [ exir_ops . edge . aten . constant_pad_nd . default ]
644672
645- assert len (args ) >= 2
646- input_node , orig_padding = args [:2 ]
673+ def maybe_remove_or_replace (self , node : torch .fx .Node ) -> bool :
674+ assert len (node .args ) >= 2
675+ input_node = node .args [0 ]
676+ orig_padding = cast (Sequence [int ], node .args [1 ])
677+ assert isinstance (input_node , torch .fx .Node )
647678
648679 # if there is no padding, this op will be treated in removal pass.
649680 if not orig_padding :
650- return super (). call_operator ( op , args , kwargs , meta )
681+ return False
651682
652- padding = orig_padding + ([0 ] * (len (orig_padding ) % 2 != 0 ))
683+ padding = list ( orig_padding ) + ([0 ] * (len (orig_padding ) % 2 != 0 ))
653684 assert len (padding ) >= 2
685+
686+ # pyre-ignore[6]
654687 (start , diff ) = map (neg , padding [- 2 :])
655688 # Replace only if constant_pad_nd is along the innermost padding dimension.
656689 if any (x != 0 for x in padding [0 :- 2 ]) or start < 0 or diff < 0 :
657- return super (). call_operator ( op , args , kwargs , meta )
690+ return False
658691
659- arg_shape = input_node .to_tensor () .shape
692+ arg_shape = input_node .meta [ "val" ] .shape
660693 dim = len (arg_shape ) - len (padding ) // 2
661694 stop = arg_shape [dim ] - diff
662695 assert start <= stop
663- new_args = (input_node , dim , start , stop )
664- return super ().call_operator (
665- exir_ops .edge .aten .slice .Tensor ,
666- new_args ,
667- kwargs ,
668- meta ,
669- )
696+
697+ with node .graph .inserting_before (node ):
698+ new_node = node .graph .call_function (
699+ exir_ops .edge .aten .slice .Tensor ,
700+ args = (input_node , dim , start , stop ),
701+ )
702+ new_node .meta = node .meta
703+ node .replace_all_uses_with (new_node )
704+ return True
670705
671706
672707# Make that pass runnable standalone at opt level 0.
0 commit comments