@@ -50,14 +50,22 @@ void shouldApproximateGradient() {
5050 finiteDifferenceShouldApproximateGradient (weights , new ElementSum (List .of (new Relu <>(weights ))));
5151 }
5252
53+ @ Test
54+ void considerSelfGradient () {
55+ Weights <Vector > weights = new Weights <>(new Vector (-1 , 5 , 2 ));
56+ var chainedRelu = new Sigmoid <>(new Relu <>(weights ));
57+
58+ finiteDifferenceShouldApproximateGradient (weights , new ElementSum (List .of (chainedRelu )));
59+ }
60+
5361 @ Test
5462 void shouldComputeRelu () {
5563 double [] vectorData = {14 , -5 , 36 , 0 };
5664 Constant <Vector > p = Constant .vector (vectorData );
5765
5866 Variable <Vector > relu = new Relu <>(p );
5967
60- var expected = new Vector (new double []{ 14 , 0.01 * -5 , 36 , 0 } );
68+ var expected = new Vector (14 , 0.01 * -5 , 36 , 0 );
6169 assertThat (ctx .forward (relu )).isEqualTo (expected );
6270 }
6371
@@ -68,7 +76,7 @@ void returnsEmptyDataForEmptyVariable() {
6876
6977 Variable <Vector > relu = new Relu <>(p );
7078
71- var expected = new Vector (new double []{} );
79+ var expected = new Vector ();
7280 assertThat (ctx .forward (relu )).isEqualTo (expected );
7381 }
7482
0 commit comments