Skip to content

Commit 4d48951

Browse files
committed
[WIP][Draft] PP + simplefsdp
stack-info: PR: #253, branch: IvanKobzarev/stack/11
1 parent b1c4909 commit 4d48951

File tree

2 files changed

+830
-6
lines changed

2 files changed

+830
-6
lines changed

autoparallel/init_weights.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ def getter(self) -> torch.nn.Parameter:
5252

5353
def setter(self, value: Union[torch.Tensor, torch.nn.Parameter]) -> None:
5454
parallel_value = parallel_model.get_parameter(fqn)
55-
assert isinstance(
56-
parallel_value, DTensor
57-
), "Expected parallel_module params to be DTensors"
55+
# assert isinstance(
56+
# parallel_value, DTensor
57+
# ), "Expected parallel_module params to be DTensors"
5858
_copy_set_value_to_dtensor(fqn, parallel_value, value)
5959

6060
return property(getter, setter)
@@ -66,9 +66,9 @@ def getter(self) -> torch.Tensor:
6666

6767
def setter(self, value: torch.Tensor) -> None:
6868
parallel_value = parallel_model.get_buffer(fqn)
69-
assert isinstance(
70-
parallel_value, DTensor
71-
), "Expected parallel_module params to be DTensors"
69+
# assert isinstance(
70+
# parallel_value, DTensor
71+
# ), "Expected parallel_module params to be DTensors"
7272
_copy_set_value_to_dtensor(fqn, parallel_value, value)
7373

7474
return property(getter, setter)

0 commit comments

Comments
 (0)