Commit 2e45cf2

[WIP][Draft] PP + simplefsdp
stack-info: PR: #253, branch: IvanKobzarev/stack/11
1 parent b1c4909 commit 2e45cf2

5 files changed (+837, -7 lines)


autoparallel/_testing/models/dsv3.py

Lines changed: 2 additions & 0 deletions
@@ -1556,6 +1556,7 @@ def forward(
             torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
         """

+        self.tok_embeddings = None
         h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens

         for layer in self.layers.values():
@@ -1630,6 +1631,7 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None:

 def _init_weights_tok_embeddings(self: Union[DeepSeekV3Model, DeepSeekV3Stage0]):
     if self.tok_embeddings is not None:
+        torch.distributed.breakpoint()
         nn.init.normal_(self.tok_embeddings.weight)
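Context for the first hunk: the `h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens` guard is the usual pipeline-parallelism split, where only the first stage owns the embedding and later stages consume hidden states directly; the added `self.tok_embeddings = None` forces that non-embedding path. A minimal sketch of the general pattern, with illustrative names that are not from this repo:

import torch
import torch.nn as nn


class PipelineStageModel(nn.Module):
    """Illustrative only: stage 0 embeds tokens, later stages pass hidden states through."""

    def __init__(self, vocab_size: int, dim: int, n_layers: int, is_first_stage: bool):
        super().__init__()
        # Non-first stages drop the embedding and expect hidden states as input.
        self.tok_embeddings = nn.Embedding(vocab_size, dim) if is_first_stage else None
        self.layers = nn.ModuleList(
            [nn.TransformerEncoderLayer(dim, nhead=8, batch_first=True) for _ in range(n_layers)]
        )

    def forward(self, tokens_or_hidden: torch.Tensor) -> torch.Tensor:
        # Same guard as in the diff: embed only when this stage owns the embedding.
        h = (
            self.tok_embeddings(tokens_or_hidden)
            if self.tok_embeddings is not None
            else tokens_or_hidden
        )
        for layer in self.layers:
            h = layer(h)
        return h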

autoparallel/_testing/models/llama3.py

Lines changed: 2 additions & 1 deletion
@@ -178,6 +178,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
     assert ndim > 1
     seqlen = x.shape[1]
     freqs_cis = freqs_cis[0:seqlen]
+    print(f"XXX FREQS_CIS.shape:{freqs_cis.shape} assert == {(seqlen, x.shape[-1])}")
     assert freqs_cis.shape == (seqlen, x.shape[-1])
     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
     return freqs_cis.view(*shape)
@@ -297,7 +298,7 @@ def forward(
         xv = xv.view(bs, seqlen, -1, self.head_dim)

         # TODO: uncomment
-        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+        # xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

         # repeat k/v heads if n_kv_heads < n_heads
         keys = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
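For context on the new print/assert: reshape_for_broadcast trims freqs_cis to the current sequence length and views it as (1, seqlen, 1, x.shape[-1]) so it broadcasts over the batch and head dimensions of x. A quick standalone shape check (self-contained copy of the function above; plain real tensors are used here, whereas the model's freqs_cis is complex):

import torch


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Standalone copy of the function shown in the diff, for shape-checking only.
    ndim = x.ndim
    assert ndim > 1
    seqlen = x.shape[1]
    freqs_cis = freqs_cis[0:seqlen]
    assert freqs_cis.shape == (seqlen, x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


x = torch.randn(2, 16, 8, 64)    # (bs, seqlen, n_heads, head_dim)
freqs_cis = torch.randn(32, 64)  # (max_seqlen, head_dim)
print(reshape_for_broadcast(freqs_cis, x).shape)  # torch.Size([1, 16, 1, 64])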

autoparallel/apply_sharding.py

Lines changed: 3 additions & 0 deletions
@@ -339,6 +339,9 @@ def apply_sharding_to_model(gm, sharding_placement, params_spec, buffers_spec):
     fqn_to_param = get_named_param_nodes(gm.graph)
     fqn_to_buffer = get_named_buffer_nodes(gm.graph)

+    # simple_fsdp_param_sharding
+    # simple_fsdp_mesh
+
     for fqn in params_spec:
         n = fqn_to_param[fqn]
         with unset_fake_temporarily():
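The two placeholder comments mark where SimpleFSDP-style parameter sharding over a dedicated mesh dimension would be wired in. A minimal sketch of that idea using the public DTensor API (the mesh layout, the Shard(0) policy, and the helper name are assumptions, not this PR's implementation; it also requires an initialized process group):

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor


def shard_params_simple_fsdp(model: torch.nn.Module, dp_size: int) -> None:
    # Hypothetical helper: shard every parameter on dim 0 across a "dp" mesh
    # dimension, so each rank holds a 1/dp_size slice, SimpleFSDP-style.
    mesh = init_device_mesh("cuda", (dp_size,), mesh_dim_names=("dp",))
    for name, param in list(model.named_parameters()):
        sharded = distribute_tensor(param.data, mesh, [Shard(0)])
        # Re-register the DTensor-backed parameter in place of the original one.
        module_path, _, leaf = name.rpartition(".")
        mod = model.get_submodule(module_path) if module_path else model
        setattr(mod, leaf, torch.nn.Parameter(sharded))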

autoparallel/init_weights.py

Lines changed: 6 additions & 6 deletions
@@ -52,9 +52,9 @@ def getter(self) -> torch.nn.Parameter:

     def setter(self, value: Union[torch.Tensor, torch.nn.Parameter]) -> None:
         parallel_value = parallel_model.get_parameter(fqn)
-        assert isinstance(
-            parallel_value, DTensor
-        ), "Expected parallel_module params to be DTensors"
+        # assert isinstance(
+        #     parallel_value, DTensor
+        # ), "Expected parallel_module params to be DTensors"
         _copy_set_value_to_dtensor(fqn, parallel_value, value)

     return property(getter, setter)
@@ -66,9 +66,9 @@ def getter(self) -> torch.Tensor:

     def setter(self, value: torch.Tensor) -> None:
         parallel_value = parallel_model.get_buffer(fqn)
-        assert isinstance(
-            parallel_value, DTensor
-        ), "Expected parallel_module params to be DTensors"
+        # assert isinstance(
+        #     parallel_value, DTensor
+        # ), "Expected parallel_module params to be DTensors"
         _copy_set_value_to_dtensor(fqn, parallel_value, value)

     return property(getter, setter)
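Relaxing these asserts means the setter can now receive a parallel_value that is either a DTensor or a plain tensor (for example, a pipeline-stage parameter that was never sharded). A minimal sketch of a copy helper handling both cases (an assumption about the intent; this is not the repo's _copy_set_value_to_dtensor):

import torch
from torch.distributed.tensor import DTensor, distribute_tensor


def copy_full_value_into_param(parallel_value: torch.Tensor, value: torch.Tensor) -> None:
    # Copy a full, unsharded `value` into a target that may or may not be a DTensor.
    with torch.no_grad():
        if isinstance(parallel_value, DTensor):
            # Reshard the full value to the target's mesh/placements, then copy in place.
            src = distribute_tensor(value, parallel_value.device_mesh, parallel_value.placements)
            parallel_value.copy_(src)
        else:
            # Plain tensor path, e.g. an unsharded pipeline-stage parameter.
            parallel_value.copy_(value)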
