21 | 21 | import torch.fx |
22 | 22 | from executorch.backends.cadence.aot.compiler_utils import ( |
23 | 23 | get_shape, |
24 | | - get_tensor_from_attr, |
25 | 24 | get_zero_point, |
26 | 25 | is_node_with_op, |
27 | 26 | quantize_tensor_multiplier, |
@@ -321,90 +320,106 @@ def call_operator(self, op, args, kwargs, meta): |
321 | 320 |
322 | 321 |
323 | 322 | @register_cadence_pass(CadencePassAttribute(opt_level=1)) |
324 | | -class ReplaceAddMMWithLinearPass(ExportPass): |
| 323 | +class ReplaceAddMMWithLinearPass(RemoveOrReplacePassInterface): |
325 | 324 | """ |
326 | 325 | This pass replaces addmm with linear op. |
| 326 | +
| 327 | + AddMM computes: beta*bias + alpha*mm(mat1, mat2) |
| 328 | + Linear computes: mat1 @ weight.T + bias |
| 329 | +
327 | 330 | """ |
328 | 331 |
329 | | - def __init__(self): |
330 | | - super().__init__() |
331 | | - self.counter = 0 |
| 332 | + @property |
| 333 | + def targets(self) -> list[EdgeOpOverload]: |
| 334 | + return [exir_ops.edge.aten.addmm.default] |
332 | 335 |
333 | | - def replace_addmm_with_linear(self, graph_module: torch.fx.GraphModule): |
334 | | - graph = graph_module.graph |
335 | | - for node in graph.nodes: |
336 | | - # We are only interested in addmm nodes
337 | | - if node.target != exir_ops.edge.aten.addmm.default: |
338 | | - continue |
| 336 | + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: |
| 337 | + # The addmm op has three concrete args: bias, mat1, mat2 |
| 338 | + assert len(node.args) >= 3 |
| 339 | + (bias, mat1, mat2) = node.args[0:3] |
339 | 340 |
340 | | - # The addmm op has three concrete args: input, mat1, mat2 |
341 | | - assert len(node.args) >= 3 |
342 | | - (bias, mat1, mat2) = node.args[0:3] |
343 | | - # The other two args are optional scale args |
344 | | - beta = node.kwargs.get("beta", 1.0) |
345 | | - alpha = node.kwargs.get("alpha", 1.0) |
346 | | - |
347 | | - # AddMM performs beta*bias + alpha*mm(mat1, mat2). We can convert |
348 | | - # it to linear op by multiplying beta to bias, and alpha to mat2.t(). |
349 | | - # However, the following two conditions must hold: |
350 | | - # a. If bias is not a param, then beta must be 1.0 |
351 | | - # b. If mat2 is not a param, then mat2 must be a transpose op. Also, |
352 | | - # the input to the transpose must be a param, or alpha must be 1.0. |
353 | | - fit_bias = is_node_with_op(bias, "get_attr") or beta == 1.0 |
354 | | - fit_mat2 = is_node_with_op(mat2, "get_attr") |
355 | | - transposed_mat2 = False |
356 | | - if ( |
357 | | - not fit_mat2 |
358 | | - and is_node_with_op(mat2, "call_function") |
359 | | - and mat2.target == exir_ops.edge.aten.transpose_copy.int |
360 | | - ): |
361 | | - mat2, transposed_mat2 = mat2.args[0], True |
362 | | - fit_mat2 = is_node_with_op(mat2, "get_attr") or alpha == 1.0 |
| 341 | + # The other two args are optional scale args |
| 342 | + beta = float(node.kwargs.get("beta", 1.0)) |
| 343 | + alpha = float(node.kwargs.get("alpha", 1.0)) |
363 | 344 |
364 | | - if not fit_bias or not fit_mat2: |
365 | | - continue |
| 345 | + bias, mat1, mat2 = cast( |
| 346 | + tuple[torch.fx.Node, torch.fx.Node, torch.fx.Node], |
| 347 | + (bias, mat1, mat2), |
| 348 | + ) |
| 349 | + |
| 350 | + graph = node.graph |
| 351 | + |
| 352 | + # Handle transpose: if mat2 is a transpose op, extract the original tensor |
| 353 | + transposed_mat2 = False |
| 354 | + if ( |
| 355 | + mat2.op == "call_function" |
| 356 | + and mat2.target == exir_ops.edge.aten.transpose_copy.int |
| 357 | + ): |
| 358 | + # mat2 is already transposed, so we use the input to the transpose |
| 359 | + mat2 = cast(torch.fx.Node, mat2.args[0]) |
| 360 | + transposed_mat2 = True |
| 361 | + |
| 362 | + # Multiply bias by beta if needed |
| 363 | + if beta != 1.0: |
| 364 | + # Create a scaled bias using element-wise multiplication in the graph |
| 365 | + with graph.inserting_before(node): |
| 366 | + beta_scalar = graph.call_function( |
| 367 | + exir_ops.edge.aten.full.default, |
| 368 | + args=([1], beta), |
| 369 | + kwargs={"dtype": torch.float32}, |
| 370 | + ) |
| 371 | + beta_scalar.meta = node.meta |
| 372 | + bias = graph.call_function( |
| 373 | + exir_ops.edge.aten.mul.Tensor, |
| 374 | + args=(bias, beta_scalar), |
| 375 | + ) |
366 | 376 |
367 | | - # Multiply bias by beta |
368 | | - if beta != 1.0: |
369 | | - assert is_node_with_op(bias, "get_attr") |
370 | | - bias_tensor = get_tensor_from_attr(graph_module, bias) |
371 | | - assert isinstance(bias_tensor, torch.Tensor) |
372 | | - bias_tensor = beta * bias_tensor |
373 | | - with graph.inserting_before(node): |
374 | | - bias_name = f"_bias_addmm_to_linear_{self.counter}" |
375 | | - graph_module.register_buffer(bias_name, bias_tensor) |
376 | | - bias = graph.get_attr(bias_name) |
377 | | - |
378 | | - # Use associativity of scalar multiplication, and multiply alpha to mat2 |
379 | | - if is_node_with_op(mat2, "get_attr"): |
380 | | - mat2_tensor = get_tensor_from_attr(graph_module, mat2) |
381 | | - assert isinstance(mat2_tensor, torch.Tensor) |
382 | | - mat2_tensor = alpha * mat2_tensor |
383 | | - # transpose mat2 |
384 | | - mat2_tensor = mat2_tensor if transposed_mat2 else mat2_tensor.t() |
385 | | - with graph.inserting_before(node): |
386 | | - mat2_name = f"_mat2_addmm_to_linear_{self.counter}" |
387 | | - graph_module.register_buffer(mat2_name, mat2_tensor) |
388 | | - mat2 = graph.get_attr(mat2_name) |
389 | | - |
390 | | - # Construct the linear node |
391 | | - linear_args = (mat1, mat2, bias) |
| 377 | + # Copy the addmm node's metadata onto the bias node so shape/dtype info is preserved
| 378 | + bias.meta = node.meta |
| 379 | + |
| 380 | + # Multiply mat2 by alpha if needed |
| 381 | + if alpha != 1.0: |
392 | 382 | with graph.inserting_before(node): |
393 | | - linear_node = graph.call_function( |
394 | | - exir_ops.edge.aten.linear.default, args=linear_args |
| 383 | + alpha_scalar = graph.call_function( |
| 384 | + exir_ops.edge.aten.full.default, |
| 385 | + args=([1], alpha), |
| 386 | + kwargs={"dtype": torch.float32}, |
| 387 | + ) |
| 388 | + alpha_scalar.meta = node.meta |
| 389 | + mat2 = graph.call_function( |
| 390 | + exir_ops.edge.aten.mul.Tensor, |
| 391 | + args=(mat2, alpha_scalar), |
395 | 392 | ) |
396 | | - linear_node.meta = node.meta |
397 | | - # Replace all the uses of the addmm op with linear op |
398 | | - node.replace_all_uses_with(linear_node) |
399 | | - self.counter += 1 |
400 | 393 |
401 | | - graph_module.recompile() |
402 | | - graph_module.graph.eliminate_dead_code() |
| 394 | + # Copy the addmm node's metadata onto the (possibly scaled) mat2 node
| 395 | + mat2.meta = node.meta |
403 | 396 |
404 | | - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: |
405 | | - self.replace_addmm_with_linear(graph_module) |
406 | | - result = super().call(graph_module) |
407 | | - return result |
| 397 | + # Transpose mat2 if it wasn't already transposed |
| 398 | + if not transposed_mat2: |
| 399 | + with graph.inserting_before(node): |
| 400 | + mat2 = graph.call_function( |
| 401 | + exir_ops.edge.aten.transpose_copy.int, |
| 402 | + args=(mat2, -1, -2), |
| 403 | + ) |
| 404 | + |
| 405 | + # Copy the addmm node's metadata onto the transposed mat2 node
| 406 | + mat2.meta = node.meta |
| 407 | + |
| 408 | + # Construct the linear node: linear(input, weight, bias) |
| 409 | + # linear computes: input @ weight.T + bias |
| 410 | + linear_args = (mat1, mat2, bias) |
| 411 | + with graph.inserting_before(node): |
| 412 | + linear_node = graph.call_function( |
| 413 | + exir_ops.edge.aten.linear.default, |
| 414 | + args=linear_args, |
| 415 | + ) |
| 416 | + |
| 417 | + # Propagate the addmm node's metadata to the new linear node
| 418 | + linear_node.meta = node.meta |
| 419 | + |
| 420 | + # Replace all uses of the addmm op with linear op |
| 421 | + node.replace_all_uses_with(linear_node) |
| 422 | + return True |
408 | 423 |
409 | 424 |
410 | 425 | @register_cadence_pass(CadencePassAttribute(opt_level=1)) |
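
The rewrite hinges on the algebraic identity beta*bias + alpha*(mat1 @ mat2) == linear(mat1, (alpha*mat2).t(), beta*bias), plus the observation that when mat2 is already a transpose, its pre-transpose input can serve directly as the linear weight. A minimal standalone sketch in plain PyTorch (independent of the pass, using made-up shapes) that checks both cases numerically:

import torch
import torch.nn.functional as F

mat1 = torch.randn(4, 8)
mat2 = torch.randn(8, 16)
bias = torch.randn(16)
alpha, beta = 2.0, 0.5

# Case 1: plain mat2. Fold beta into bias and alpha into mat2; the scaled
# mat2, transposed, becomes the linear weight.
addmm_out = torch.addmm(bias, mat1, mat2, beta=beta, alpha=alpha)
linear_out = F.linear(mat1, (alpha * mat2).t(), beta * bias)
assert torch.allclose(addmm_out, linear_out, atol=1e-5)

# Case 2: mat2 is itself a transpose. The pre-transpose tensor (a
# hypothetical stand-in for mat2.args[0] in the pass) is used directly
# as the weight, so no extra transpose is needed.
weight = torch.randn(16, 8)
addmm_out = torch.addmm(bias, mat1, weight.t(), beta=beta, alpha=alpha)
linear_out = F.linear(mat1, alpha * weight, beta * bias)
assert torch.allclose(addmm_out, linear_out, atol=1e-5)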