diff --git a/intermediate_source/TP_tutorial.rst b/intermediate_source/TP_tutorial.rst
index 4108e72b02..6d3e7b60c6 100644
--- a/intermediate_source/TP_tutorial.rst
+++ b/intermediate_source/TP_tutorial.rst
@@ -128,9 +128,9 @@ q/k/v projection and row-wise sharding for the ``wo`` linear projection. So we c
     layer_tp_plan = {
         # by default ColwiseParallel input layouts is replicated
         # and RowwiseParallel output layouts is replicated
-        "attention.wq": ColwiseParallel(),
-        "attention.wk": ColwiseParallel(),
-        "attention.wv": ColwiseParallel(),
+        "attention.wq": ColwiseParallel(use_local_output=False),
+        "attention.wk": ColwiseParallel(use_local_output=False),
+        "attention.wv": ColwiseParallel(use_local_output=False),
         "attention.wo": RowwiseParallel(),
         "feed_forward.w1": ColwiseParallel(),
         "feed_forward.w2": RowwiseParallel(),
@@ -141,7 +141,7 @@ q/k/v projection and row-wise sharding for the ``wo`` linear projection. So we c
 This is almost the ``layer_tp_plan`` we need to apply Tensor Parallelism to the ``TransformerBlock``. However, one thing we should be aware is that when sharding the linear layer column-wise, the output of the linear layers would become sharded on the last tensor dimension, and the row-wise sharding linear layer directly accepts an input that shards on the last dimension.
 If there are any more tensor operations (such as view operations) between the column-wise linear and the row-wise linear, we would need to adjust the relevant shape related ops to sharded shape.
 
-For the Llama model, in the attention layer there are couple of view operations that are shape related. In particular, column-wise parallel for ``wq``/ ``wk``/ ``wv`` linear layers, the activation tensor is sharded on the ``num_heads`` dimension, so we would need to adjust the ``num_heads`` to local ``num_heads``.
+For the Llama model, in the attention layer, there are several view operations related to shape. Specifically, for column-wise parallelism in the ``wq``/``wk``/``wv`` linear layers, the activation tensor is sharded on the ``num_heads`` dimension. To manage the difference between global and local ``num_heads``, we should set ``use_local_output=False`` to ensure the output is a DTensor. Unlike a regular tensor, a DTensor is aware of the parallelism plans and will automatically handle changes in the ``num_heads`` dimension.
 
 Finally, we need to call ``parallelize_module`` API to make the plan for each ``TransformerBlock`` effective. Under the hood, it distributes the model parameters inside ``Attention`` and ``FeedForward`` layers to DTensors, and registers communication hooks for model inputs and outputs (before and after each module respectively), if necessary:
 
@@ -150,11 +150,6 @@ Finally, we need to call ``parallelize_module`` API to make the plan for each ``
     for layer_id, transformer_block in enumerate(model.layers):
         layer_tp_plan = {...}  # i.e. the plan we just generated
 
-        # Adjust attention module to use the local number of heads
-        attn_layer = transformer_block.attention
-        attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
-        attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
-
         parallelize_module(
             module=transformer_block,
             device_mesh=tp_mesh,
@@ -219,12 +214,12 @@ Next let's adjust the ``layer_tp_plan`` to enable sequence parallel on the ``RMS
         # to represent the input/output tensors sharded on the sequence dimension
         "attention_norm": SequenceParallel(),
         "attention": PrepareModuleInput(
-            input_layouts=(Shard(1),),
-            desired_input_layouts=(Replicate(),),
+            input_layouts=(Shard(1), Replicate()),
+            desired_input_layouts=(Replicate(), Replicate()),
         ),
-        "attention.wq": ColwiseParallel(),
-        "attention.wk": ColwiseParallel(),
-        "attention.wv": ColwiseParallel(),
+        "attention.wq": ColwiseParallel(use_local_output=False),
+        "attention.wk": ColwiseParallel(use_local_output=False),
+        "attention.wv": ColwiseParallel(use_local_output=False),
         "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
         "ffn_norm": SequenceParallel(),
         "feed_forward": PrepareModuleInput(
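Note for reviewers: the sketch below illustrates what the per-block parallelization loop in the tutorial reads like once this patch is applied, i.e. with the manual ``n_heads``/``n_kv_heads`` division removed and the q/k/v plans keeping their outputs as DTensors via ``use_local_output=False``. Names such as ``model``, ``tp_mesh``, and the Llama-style submodule paths are assumed from the surrounding tutorial and are not defined here, so this is illustrative rather than runnable on its own.

.. code-block:: python

    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    for layer_id, transformer_block in enumerate(model.layers):
        layer_tp_plan = {
            # use_local_output=False keeps the q/k/v projection outputs as DTensors,
            # so the later view on the num_heads dimension is expressed in global
            # terms and no manual n_heads // tp_mesh.size() adjustment is needed
            "attention.wq": ColwiseParallel(use_local_output=False),
            "attention.wk": ColwiseParallel(use_local_output=False),
            "attention.wv": ColwiseParallel(use_local_output=False),
            "attention.wo": RowwiseParallel(),
            "feed_forward.w1": ColwiseParallel(),
            "feed_forward.w2": RowwiseParallel(),
            "feed_forward.w3": ColwiseParallel(),
        }

        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=layer_tp_plan,
        )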