Skip to content

Commit 434a177

Browse files
committed
Expand logging test
1 parent 270ca18 commit 434a177

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

algo/src/test/java/org/neo4j/gds/embeddings/graphsage/algo/GraphSageTrainAlgorithmFactoryTest.java

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import static org.eclipse.collections.impl.tuple.primitive.PrimitiveTuples.pair;
5959
import static org.junit.jupiter.api.Assertions.assertEquals;
6060
import static org.junit.jupiter.params.provider.Arguments.arguments;
61+
import static org.neo4j.gds.assertj.Extractors.keepingFixedNumberOfDecimals;
6162
import static org.neo4j.gds.assertj.Extractors.removingThreadId;
6263
import static org.neo4j.gds.compat.TestLog.INFO;
6364
import static org.neo4j.gds.core.utils.mem.MemoryEstimations.RESIDENT_MEMORY;
@@ -461,13 +462,15 @@ void memoryEstimationTreeStructure(boolean isMultiLabel) {
461462
void testLogging() {
462463
var config = ImmutableGraphSageTrainConfig.builder()
463464
.addFeatureProperties(DUMMY_PROPERTY)
464-
.embeddingDimension(64)
465+
.embeddingDimension(12)
466+
.aggregator(Aggregator.AggregatorType.POOL)
467+
.tolerance(1e-10)
468+
.sampleSizes(List.of(5, 3))
469+
.batchSize(5)
470+
.randomSeed(42L)
465471
.modelName("model")
466472
.epochs(2)
467473
.maxIterations(2)
468-
.tolerance(1e-10)
469-
.learningRate(0.001)
470-
.randomSeed(42L)
471474
.build();
472475

473476
var log = Neo4jProxy.testLog();
@@ -485,6 +488,7 @@ void testLogging() {
485488
AssertionsForInterfaceTypes.assertThat(messagesInOrder)
486489
// avoid asserting on the thread id
487490
.extracting(removingThreadId())
491+
.extracting(keepingFixedNumberOfDecimals(2))
488492
.containsExactly(
489493
"GraphSageTrain :: Start",
490494
"GraphSageTrain :: Prepare batches :: Start",
@@ -493,17 +497,23 @@ void testLogging() {
493497
"GraphSageTrain :: Train model :: Start",
494498
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Start",
495499
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 1 of 2 :: Start",
496-
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 1 of 2 :: LOSS: 531.5699087433",
500+
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 1 of 2 :: LOSS: 132.63",
497501
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 1 of 2 100%",
498502
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 1 of 2 :: Finished",
499503
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 2 of 2 :: Start",
504+
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 2 of 2 :: LOSS: 129.13",
500505
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 2 of 2 100%",
501506
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 2 of 2 :: Finished",
502507
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Finished",
503508
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Start",
504509
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 1 of 2 :: Start",
510+
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 1 of 2 :: LOSS: 123.38",
505511
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 1 of 2 100%",
506512
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 1 of 2 :: Finished",
513+
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 2 of 2 :: Start",
514+
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 2 of 2 :: LOSS: 116.06",
515+
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 2 of 2 100%",
516+
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 2 of 2 :: Finished",
507517
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Finished",
508518
"GraphSageTrain :: Train model :: Finished",
509519
"GraphSageTrain :: Finished"

0 commit comments

Comments
 (0)