5858import static org .eclipse .collections .impl .tuple .primitive .PrimitiveTuples .pair ;
5959import static org .junit .jupiter .api .Assertions .assertEquals ;
6060import static org .junit .jupiter .params .provider .Arguments .arguments ;
61+ import static org .neo4j .gds .assertj .Extractors .keepingFixedNumberOfDecimals ;
6162import static org .neo4j .gds .assertj .Extractors .removingThreadId ;
6263import static org .neo4j .gds .compat .TestLog .INFO ;
6364import static org .neo4j .gds .core .utils .mem .MemoryEstimations .RESIDENT_MEMORY ;
@@ -461,13 +462,15 @@ void memoryEstimationTreeStructure(boolean isMultiLabel) {
461462 void testLogging () {
462463 var config = ImmutableGraphSageTrainConfig .builder ()
463464 .addFeatureProperties (DUMMY_PROPERTY )
464- .embeddingDimension (64 )
465+ .embeddingDimension (12 )
466+ .aggregator (Aggregator .AggregatorType .POOL )
467+ .tolerance (1e-10 )
468+ .sampleSizes (List .of (5 , 3 ))
469+ .batchSize (5 )
470+ .randomSeed (42L )
465471 .modelName ("model" )
466472 .epochs (2 )
467473 .maxIterations (2 )
468- .tolerance (1e-10 )
469- .learningRate (0.001 )
470- .randomSeed (42L )
471474 .build ();
472475
473476 var log = Neo4jProxy .testLog ();
@@ -485,6 +488,7 @@ void testLogging() {
485488 AssertionsForInterfaceTypes .assertThat (messagesInOrder )
486489 // avoid asserting on the thread id
487490 .extracting (removingThreadId ())
491+ .extracting (keepingFixedNumberOfDecimals (2 ))
488492 .containsExactly (
489493 "GraphSageTrain :: Start" ,
490494 "GraphSageTrain :: Prepare batches :: Start" ,
@@ -493,17 +497,23 @@ void testLogging() {
493497 "GraphSageTrain :: Train model :: Start" ,
494498 "GraphSageTrain :: Train model :: Epoch 1 of 2 :: Start" ,
495499 "GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 1 of 2 :: Start" ,
496- "GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 1 of 2 :: LOSS: 531.5699087433 " ,
500+ "GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 1 of 2 :: LOSS: 132.63 " ,
497501 "GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 1 of 2 100%" ,
498502 "GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 1 of 2 :: Finished" ,
499503 "GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 2 of 2 :: Start" ,
504+ "GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 2 of 2 :: LOSS: 129.13" ,
500505 "GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 2 of 2 100%" ,
501506 "GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 2 of 2 :: Finished" ,
502507 "GraphSageTrain :: Train model :: Epoch 1 of 2 :: Finished" ,
503508 "GraphSageTrain :: Train model :: Epoch 2 of 2 :: Start" ,
504509 "GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 1 of 2 :: Start" ,
510+ "GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 1 of 2 :: LOSS: 123.38" ,
505511 "GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 1 of 2 100%" ,
506512 "GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 1 of 2 :: Finished" ,
513+ "GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 2 of 2 :: Start" ,
514+ "GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 2 of 2 :: LOSS: 116.06" ,
515+ "GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 2 of 2 100%" ,
516+ "GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 2 of 2 :: Finished" ,
507517 "GraphSageTrain :: Train model :: Epoch 2 of 2 :: Finished" ,
508518 "GraphSageTrain :: Train model :: Finished" ,
509519 "GraphSageTrain :: Finished"
0 commit comments