@@ -2619,22 +2619,7 @@ def test_grad_until_and_truncate(self):
26192619 utt .assert_allclose (pytensor_gradient , self .numpy_gradient )
26202620
26212621 def test_grad_until_and_truncate_sequence_taps (self ):
2622- n = 3
2623- r = scan (
2624- lambda x , y , u : (x * y , until (y > u )),
2625- sequences = dict (input = self .x , taps = [- 2 , 0 ]),
2626- non_sequences = [self .threshold ],
2627- truncate_gradient = n ,
2628- return_updates = False ,
2629- )
2630- g = grad (r .sum (), self .x )
2631- f = function ([self .x , self .threshold ], [r , g ])
2632- _pytensor_output , pytensor_gradient = f (self .seq , 6 )
2633-
2634- # Gradient computed by hand:
2635- numpy_grad = np .array ([0 , 0 , 0 , 5 , 6 , 10 , 4 , 5 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ])
2636- numpy_grad = numpy_grad .astype (config .floatX )
2637- utt .assert_allclose (pytensor_gradient , numpy_grad )
2622+ ScanCompatibilityTests .check_grad_until_and_truncate_sequence_taps (mode = None )
26382623
26392624
26402625def test_mintap_onestep ():
@@ -4341,3 +4326,26 @@ def check_higher_order_derivative(mode):
43414326 np .testing .assert_allclose (gg_res , (16 * 15 ) * x_test ** 14 )
43424327 # FIXME: All implementations of Scan seem to get this one wrong!
43434328 # np.testing.assert_allclose(ggg_res, (16 * 15 * 14) * x_test**13)
4329+
4330+
4331+ @staticmethod
4332+ def check_grad_until_and_truncate_sequence_taps (mode ):
4333+ """Test case where we need special behavior of zeroing out sequences in Scan"""
4334+ x = pt .vector ("x" )
4335+ threshold = pt .scalar (name = "threshold" , dtype = "int64" )
4336+
4337+ r = scan (
4338+ lambda x , y , u : (x * y , until (y > u )),
4339+ sequences = dict (input = x , taps = [- 2 , 0 ]),
4340+ non_sequences = [threshold ],
4341+ truncate_gradient = 3 ,
4342+ return_updates = False ,
4343+ )
4344+ g = grad (r .sum (), x )
4345+ f = function ([x , threshold ], [r , g ], mode = mode )
4346+ _ , grad_res = f (np .arange (15 , dtype = x .dtype ), 6 )
4347+
4348+ # Gradient computed by hand:
4349+ grad_expected = np .array ([0 , 0 , 0 , 5 , 6 , 10 , 4 , 5 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ])
4350+ grad_expected = grad_expected .astype (config .floatX )
4351+ np .testing .assert_allclose (grad_res , grad_expected )
0 commit comments