@@ -3179,40 +3179,9 @@ def onestep(x, x_tm4):
31793179 f = function ([seq ], results [1 ])
31803180 assert np .all (exp_out == f (inp ))
31813181
3182- def test_shared_borrow (self ):
3183- """
3184- This tests two things. The first is a bug occurring when scan wrongly
3185- used the borrow flag. The second thing it that Scan's infer_shape()
3186- method will be able to remove the Scan node from the graph in this
3187- case.
3188- """
3189-
3190- inp = np .arange (10 ).reshape (- 1 , 1 ).astype (config .floatX )
3191- exp_out = np .zeros ((10 , 1 )).astype (config .floatX )
3192- exp_out [4 :] = inp [:- 4 ]
3193-
3194- def onestep (x , x_tm4 ):
3195- return x , x_tm4
3196-
3197- seq = matrix ()
3198- initial_value = shared (np .zeros ((4 , 1 ), dtype = config .floatX ))
3199- outputs_info = [{"initial" : initial_value , "taps" : [- 4 ]}, None ]
3200- results = scan (
3201- fn = onestep , sequences = seq , outputs_info = outputs_info , return_updates = False
3202- )
3203- sharedvar = shared (np .zeros ((1 , 1 ), dtype = config .floatX ))
3204- updates = {sharedvar : results [0 ][- 1 :]}
3205-
3206- f = function ([seq ], results [1 ], updates = updates )
3207-
3208- # This fails if scan uses wrongly the borrow flag
3209- assert np .all (exp_out == f (inp ))
3210-
3211- # This fails if Scan's infer_shape() is unable to remove the Scan
3212- # node from the graph.
3213- f_infershape = function ([seq ], results [1 ].shape , mode = "FAST_RUN" )
3214- scan_nodes_infershape = scan_nodes_from_fct (f_infershape )
3215- assert len (scan_nodes_infershape ) == 0
3182+ @pytest .mark .parametrize ("static_shape" , (True , False ))
3183+ def test_aliased_inner_outputs (self , static_shape ):
3184+ ScanCompatibilityTests .check_aliased_inner_outputs (static_shape , mode = None )
32163185
32173186 def test_memory_reuse_with_outputs_as_inputs (self ):
32183187 """
@@ -4327,7 +4296,6 @@ def check_higher_order_derivative(mode):
43274296 # FIXME: All implementations of Scan seem to get this one wrong!
43284297 # np.testing.assert_allclose(ggg_res, (16 * 15 * 14) * x_test**13)
43294298
4330-
43314299 @staticmethod
43324300 def check_grad_until_and_truncate_sequence_taps (mode ):
43334301 """Test case where we need special behavior of zeroing out sequences in Scan"""
@@ -4349,3 +4317,66 @@ def check_grad_until_and_truncate_sequence_taps(mode):
43494317 grad_expected = np .array ([0 , 0 , 0 , 5 , 6 , 10 , 4 , 5 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ])
43504318 grad_expected = grad_expected .astype (config .floatX )
43514319 np .testing .assert_allclose (grad_res , grad_expected )
4320+
4321+ @staticmethod
4322+ def check_aliased_inner_outputs (static_shape , mode ):
4323+ """
4324+ This tests two things. The first is a bug occurring when scan wrongly
4325+ used the borrow flag. The second thing it that Scan's infer_shape()
4326+ method will be able to remove the Scan node from the graph in this
4327+ case.
4328+
4329+ Here is pure python equivalent of the problem we want to avoid:
4330+ ```python
4331+ def scan(seq, initval):
4332+ # Due to memory optimization we override values of mitsot as we iterate
4333+ # That's why mitsot has shape (4, 1) and not (14, 1)
4334+ mitsot = np.zeros((4, 1))
4335+ mitsot[:4] = initval
4336+ nitsot = np.zeros((10, 1))
4337+ for i, s in enumerate(seq):
4338+ # Incorrect results
4339+ mitsot[(i+4) % 4], nitsot[i] = s, mitsot[i % 4]
4340+ # Correct results
4341+ # mitsot[(i + 4) % 4], nitsot[i] = s, mitsot[i % 4].copy()
4342+
4343+ return mitsot[(i + 4) % 4: (i+4 + 1) % 4], nitsot
4344+
4345+ scan(np.arange(10), np.zeros((4, 1)))
4346+ ```
4347+ """
4348+
4349+ def onestep (seq , seq_tm4 ):
4350+ # Recurring output is just each value of seq
4351+ # And we further map the tap -4 as a new output
4352+ return seq , seq_tm4
4353+
4354+ # Outer tensors must be atleast matrix, so that they we have vectors in the inner loop
4355+ # Otherwise we would be working with scalars and memory alias wouldn't be a concern
4356+ seq = matrix (shape = (10 , 1 ) if static_shape else (None , None ), name = "seq" )
4357+ init = matrix (shape = (4 , 1 ) if static_shape else (None , None ), name = "init" )
4358+ outputs_info = [{"initial" : init , "taps" : [- 4 ]}, None ]
4359+ [out_seq , out_seq_tm4 ] = scan (
4360+ fn = onestep ,
4361+ sequences = seq ,
4362+ outputs_info = outputs_info ,
4363+ return_updates = False ,
4364+ )
4365+
4366+ f = function ([seq , init ], [out_seq [- 1 ].ravel (), out_seq_tm4 .ravel ()], mode = mode )
4367+
4368+ seq_test_val = np .arange (10 , dtype = config .floatX )[:, None ]
4369+ init_test_val = np .zeros ((4 , 1 ), dtype = config .floatX )
4370+
4371+ res0 , res1 = f (seq_test_val , init_test_val )
4372+ expected_res0 = np .array ([9 ], dtype = config .floatX )
4373+ expected_res1 = np .zeros (10 , dtype = config .floatX )
4374+ expected_res1 [4 :] = np .arange (6 )
4375+ np .testing .assert_array_equal (res0 , expected_res0 )
4376+ np .testing .assert_array_equal (res1 , expected_res1 )
4377+
4378+ # This fails if Scan's infer_shape() is unable to remove the Scan
4379+ # node from the graph.
4380+ f_infershape = function ([seq , init ], out_seq_tm4 [1 ].shape )
4381+ scan_nodes_infershape = scan_nodes_from_fct (f_infershape )
4382+ assert len (scan_nodes_infershape ) == 0
0 commit comments