Skip to content

Commit af3d278

Browse files
committed
Minor fixes for review
1 parent 636186e commit af3d278

File tree

1 file changed

+25
-29
lines changed

1 file changed

+25
-29
lines changed

python/tests/unit/module/models_test.py

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -299,13 +299,12 @@ def test_dmp_wrapped_model_produces_correct_output(self):
299299
"""
300300
process_group_init_method = get_process_group_init_method()
301301
# Initialize distributed
302-
if not dist.is_initialized():
303-
dist.init_process_group(
304-
backend="gloo",
305-
init_method=process_group_init_method,
306-
rank=0,
307-
world_size=1,
308-
)
302+
dist.init_process_group(
303+
backend="gloo",
304+
init_method=process_group_init_method,
305+
rank=0,
306+
world_size=1,
307+
)
309308

310309
# Create model
311310
model = self._create_lightgcn_model(self.num_nodes)
@@ -336,13 +335,12 @@ def test_dmp_gradient_flow(self):
336335
process_group_init_method = get_process_group_init_method()
337336

338337
# Initialize distributed
339-
if not dist.is_initialized():
340-
dist.init_process_group(
341-
backend="gloo",
342-
init_method=process_group_init_method,
343-
rank=0,
344-
world_size=1,
345-
)
338+
dist.init_process_group(
339+
backend="gloo",
340+
init_method=process_group_init_method,
341+
rank=0,
342+
world_size=1,
343+
)
346344

347345
# Create and wrap model
348346
model = self._create_lightgcn_model(self.num_nodes)
@@ -354,15 +352,13 @@ def test_dmp_gradient_flow(self):
354352

355353
self._set_embeddings(model, "default_homogeneous_node_type")
356354

357-
model.train()
358-
359-
# Forward and backward pass
360-
output = dmp_model(data=self.data, device=self.device)
355+
dmp_model.train()
356+
output = dmp_model(self.data, self.device)
361357
loss = output.sum()
362358
loss.backward()
363359

364360
# Check that gradients exist and are non-zero
365-
embedding_table = model._embedding_bag_collection.embedding_bags[
361+
embedding_table = dmp_model._dmp_wrapped_module._embedding_bag_collection.embedding_bags[
366362
"node_embedding_default_homogeneous_node_type"
367363
]
368364
self.assertIsNotNone(
@@ -383,22 +379,22 @@ def test_dmp_multiprocess(self, _name, world_size):
383379
"""
384380
Test DMP with multiple processes to verify embedding sharding works correctly.
385381
386-
Note: Uses CPU/Gloo backend for unit testing. For GPU/NCCL testing, see integration tests.
382+
Note: Uses CPU/Gloo backend for unit testing.
387383
"""
388384
process_group_init_method = get_process_group_init_method()
389385

390386
# Spawn world_size processes
391387
mp.spawn(
392388
fn=_run_dmp_multiprocess_test,
393389
args=(
394-
world_size,
395-
process_group_init_method,
396-
self.num_nodes,
397-
self.embedding_dim,
398-
self.num_layers,
399-
self.edge_index,
400-
self.test_embeddings,
401-
self.expected_output,
390+
world_size, # total number of processes
391+
process_group_init_method, # initialization method for process group
392+
self.num_nodes, # number of nodes in test graph
393+
self.embedding_dim, # dimension of embeddings
394+
self.num_layers, # number of LightGCN layers
395+
self.edge_index, # edge connectivity
396+
self.test_embeddings, # test embedding values
397+
self.expected_output, # expected model output
402398
),
403399
nprocs=world_size,
404400
)
@@ -426,7 +422,7 @@ def _run_dmp_multiprocess_test(
426422
embedding_dim: Dimension of embeddings
427423
num_layers: Number of LightGCN layers
428424
edge_index: Edge connectivity
429-
test_embeddings: Expected embedding values
425+
test_embeddings: Test embedding values
430426
expected_output: Expected model output
431427
"""
432428
try:

0 commit comments

Comments
 (0)