@@ -299,13 +299,12 @@ def test_dmp_wrapped_model_produces_correct_output(self):
         """
         process_group_init_method = get_process_group_init_method()
         # Initialize distributed
-        if not dist.is_initialized():
-            dist.init_process_group(
-                backend="gloo",
-                init_method=process_group_init_method,
-                rank=0,
-                world_size=1,
-            )
+        dist.init_process_group(
+            backend="gloo",
+            init_method=process_group_init_method,
+            rank=0,
+            world_size=1,
+        )
 
         # Create model
         model = self._create_lightgcn_model(self.num_nodes)
@@ -336,13 +335,12 @@ def test_dmp_gradient_flow(self):
         process_group_init_method = get_process_group_init_method()
 
         # Initialize distributed
-        if not dist.is_initialized():
-            dist.init_process_group(
-                backend="gloo",
-                init_method=process_group_init_method,
-                rank=0,
-                world_size=1,
-            )
+        dist.init_process_group(
+            backend="gloo",
+            init_method=process_group_init_method,
+            rank=0,
+            world_size=1,
+        )
 
         # Create and wrap model
         model = self._create_lightgcn_model(self.num_nodes)
@@ -354,15 +352,13 @@ def test_dmp_gradient_flow(self):
 
         self._set_embeddings(model, "default_homogeneous_node_type")
 
-        model.train()
-
-        # Forward and backward pass
-        output = dmp_model(data=self.data, device=self.device)
+        dmp_model.train()
+        output = dmp_model(self.data, self.device)
         loss = output.sum()
         loss.backward()
 
         # Check that gradients exist and are non-zero
-        embedding_table = model._embedding_bag_collection.embedding_bags[
+        embedding_table = dmp_model._dmp_wrapped_module._embedding_bag_collection.embedding_bags[
             "node_embedding_default_homogeneous_node_type"
         ]
         self.assertIsNotNone(
@@ -383,22 +379,22 @@ def test_dmp_multiprocess(self, _name, world_size):
         """
         Test DMP with multiple processes to verify embedding sharding works correctly.
 
-        Note: Uses CPU/Gloo backend for unit testing. For GPU/NCCL testing, see integration tests.
+        Note: Uses CPU/Gloo backend for unit testing.
         """
         process_group_init_method = get_process_group_init_method()
 
         # Spawn world_size processes
         mp.spawn(
             fn=_run_dmp_multiprocess_test,
             args=(
-                world_size,
-                process_group_init_method,
-                self.num_nodes,
-                self.embedding_dim,
-                self.num_layers,
-                self.edge_index,
-                self.test_embeddings,
-                self.expected_output,
+                world_size,  # total number of processes
+                process_group_init_method,  # initialization method for process group
+                self.num_nodes,  # number of nodes in test graph
+                self.embedding_dim,  # dimension of embeddings
+                self.num_layers,  # number of LightGCN layers
+                self.edge_index,  # edge connectivity
+                self.test_embeddings,  # test embedding values
+                self.expected_output,  # expected model output
             ),
             nprocs=world_size,
         )
@@ -426,7 +422,7 @@ def _run_dmp_multiprocess_test(
         embedding_dim: Dimension of embeddings
         num_layers: Number of LightGCN layers
         edge_index: Edge connectivity
-        test_embeddings: Expected embedding values
+        test_embeddings: Test embedding values
         expected_output: Expected model output
     """
     try: