Skip to content

Commit d16f016

Browse files
authored
Improve function to identify if an all-gather should be marked as must-recompute (#186)
Instead of marking all all-gather nodes as must-recompute, we now only mark those that can be arbitrarily prefetched, i.e., those whose recursive inputs are all single-input operators forming a chain that leads to a graph input.
1 parent 4391c38 commit d16f016

File tree

1 file changed

+10
-12
lines changed

1 file changed

+10
-12
lines changed

autoparallel/activation_checkpointing.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,19 +49,17 @@ def is_all_gather_into_tensor(node: torch.fx.Node) -> bool:
4949

5050

5151
def is_wait_tensor_from_fsdp(node: torch.fx.Node) -> bool:
    """
    Return True if ``node`` is a wait_tensor whose input is an all_gather
    that can be arbitrarily prefetched, i.e. every node on the producer
    chain has exactly one input and the chain terminates at a graph input.
    """
    # Only wait_tensor nodes sitting directly on top of an
    # all_gather_into_tensor are candidates at all.
    if not (is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0])):
        return False
    # Walk backwards through the producers of the all_gather, following
    # single-input operators only; succeed as soon as the chain bottoms
    # out at a graph input. A multi-input producer ends the walk.
    cursor: torch.fx.Node = node.all_input_nodes[0]
    while len(cursor.all_input_nodes) == 1:
        producer = cursor.all_input_nodes[0]
        if is_graph_input(producer):
            return True
        cursor = producer
    return False
6664

6765

0 commit comments

Comments
 (0)