
Control flow support #124

Open
HeydrichBeillschmidt wants to merge 13 commits into alpa-projects:master from HeydrichBeillschmidt:control-flow-support

Conversation

@HeydrichBeillschmidt

[WIP] support tuple-shaped parameters for while instruction

@tdietert

Hi @HeydrichBeillschmidt, when I merge your changes into my fork and try to call run_auto_sharding_pass on a simple MNIST model, I get this error:

  File "/workspaces/alpa/alpa/shard_parallel/auto_sharding.py", line 355, in run_auto_sharding_pass
    xe.run_auto_sharding(hlo_module, compile_options)
IndexError: absl::container_internal::raw_hash_map<>::at

The source of the error is the CreateStrategyVector code: apparently a select operation was never added to the strategy_map, so the lookup fails when iterating through the operands of the dot.268 instruction. Below is some HLO from an intermediate stage of compilation, after the spmd_simplify pipeline and before the spmd_pipeline that runs the auto sharding pass:

  broadcast.6 = f32[2048,1600]{1,0} broadcast(constant.171), dimensions={}
  select = f32[2048,1600]{1,0} select(compare.183, reshape.29, broadcast.6), metadata={op_type="Mul" op_name="mnist/sequential/dropout/dropout/Mul_1" source_file="/home/vscode/.local/lib/python3.8/site-packages/keras/backend.py" source_line=1940}
  arg34.35 = f32[1600,10]{1,0} parameter(34), parameter_replication={false}, metadata={op_name="XLA_Args"}
  dot.268 = f32[2048,10]{1,0} dot(select, arg34.35), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_type="MatMul" op_name="mnist/sequential/dense/MatMul" source_file="/home/vscode/.local/lib/python3.8/site-packages/keras/layers/core/dense.py" source_line=221}

And finally, here is some logging output I've generated that shows the sequence of events leading up to this failed indexing into the strategy map:

HandleDot[0]: dot.268
CreateLeafStrategyVector: dot.268
Potential Failing operand instruction: %select = f32[2048,1600]{1,0} select(pred[2048,1600]{1,0} %compare.183, f32[2048,1600]{1,0} %reshape.29, f32[2048,1600]{1,0} %broadcast.6), metadata={op_type="Mul" op_name="mnist/sequential/dropout/dropout/Mul_1" source_file="/home/vscode/.local/lib/python3.8/site-packages/keras/backend.py" source_line=1940}

Do you have any idea what could be the problem?

@tdietert

@HeydrichBeillschmidt I've solved this problem by undoing the part of the diff where you build an instruction sequence from the entry_computation->instructions() list. You passed this entry_sequence value to BuildStrategyAndCost, instead of the sequence value constructed from the hlo_live_range, but it doesn't actually contain all the instructions in the computation: https://github.com/alpa-projects/tensorflow-alpa/pull/124/files#diff-83aa23c5123bde398bcd2002e8bf5d5bdf79341e11f461715a127f9547357a13R2806

Is there a reason you did this? Replacing entry_sequence with sequence (from the hlo_live_range value, like in the master branch) passed to BuildStrategyAndCost solved my issue.

@HeydrichBeillschmidt
Author

> @HeydrichBeillschmidt I've solved this problem by undoing the part of the diff where you build an instruction sequence from the entry_computation->instructions() list. You passed this entry_sequence value to BuildStrategyAndCost, instead of the sequence value constructed from the hlo_live_range, but it doesn't actually contain all the instructions in the computation: https://github.com/alpa-projects/tensorflow-alpa/pull/124/files#diff-83aa23c5123bde398bcd2002e8bf5d5bdf79341e11f461715a127f9547357a13R2806
>
> Is there a reason you did this? Replacing entry_sequence with sequence (from the hlo_live_range value, like in the master branch) passed to BuildStrategyAndCost solved my issue.

Hi @tdietert, thank you for reporting this issue. BuildStrategyAndCost is designed as a recursive structure, and entry_sequence is passed to avoid repeatedly constructing strategies for instructions in nested computations such as a while body. However, simply setting entry_sequence = entry_computation->instructions() was incorrect. The problem is addressed in the latest commit.

merrymercy and others added 9 commits August 30, 2022 11:12
Co-authored-by: Yonghao Zhuang <zhuangyh@sjtu.edu.cn>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Co-authored-by: Hexu Zhao <zhaohx19@mails.tsinghua.edu.cn>
Co-authored-by: Yonghao Zhuang <zhuangyh@sjtu.edu.cn>
Co-authored-by: Hexu Zhao <zhaohx19@mails.tsinghua.edu.cn>
Co-authored-by: Yonghao Zhuang <zhuangyh@sjtu.edu.cn>
@tdietert
Copy link
Copy Markdown

tdietert commented Sep 2, 2022

@HeydrichBeillschmidt Thanks for your response! We have tried your latest changes and they work well for us, thank you. We have not yet validated that the while loops are parallelized correctly, but we no longer experience any of the issues we saw before.
