@@ -940,7 +940,6 @@ def test_activation(self, same_inputs, model, phi_name, get, abc):
940940 for get in ['nngp' , 'ntk' ]
941941 for gamma in [1e-6 , 1e-4 , 1e-2 , 1.0 , 2. ]
942942 ))
943-
944943 def test_rbf (self , same_inputs , model , get , gamma ):
945944 activation = stax .Rbf (gamma )
946945 self ._test_activation (activation , same_inputs , model , get ,
@@ -2138,7 +2137,7 @@ def get_attn():
21382137 test_utils .assert_close_matrices (self , empirical , exact , tol )
21392138
21402139
2141- class GNTKTest (test_utils .NeuralTangentsTestCase ):
2140+ class AggregateTest (test_utils .NeuralTangentsTestCase ):
21422141 @jtu .parameterized .named_parameters (
21432142 jtu .cases_from_list ({
21442143 'testcase_name' :
@@ -2157,9 +2156,9 @@ class GNTKTest(test_utils.NeuralTangentsTestCase):
21572156 for test_mask in [True ]
21582157 ))
21592158
2160- def test_GNTK (self , get , readout , same_input , activation , test_mask ):
2159+ def test_aggregate (self , get , readout , same_input , activation , test_mask ):
21612160 batch1 , batch2 = 8 , 6
2162- num_nodes , num_channels = 8 , 12
2161+ num_nodes , num_channels = 4 , 2
21632162 output_dims = 1 if get == 'ntk' else 1024
21642163 key = random .PRNGKey (1 )
21652164 key , split1 , split2 = random .split (key , 3 )
@@ -2183,15 +2182,16 @@ def test_GNTK(self, get, readout, same_input, activation, test_mask):
21832182
21842183 # Build the infinite network.
21852184 init_fn , apply_fn , kernel_fn = stax .serial (
2186- stax .Dense (128 * 8 * 4 ),
2185+ stax .Dense (128 * 8 ),
21872186 activation ,
21882187 stax .Dropout (0.5 , mode = 'train' ),
21892188 stax .Aggregate (),
21902189 readout ,
21912190 stax .Dense (output_dims ))
21922191 kernel_fn = batch .batch (kernel_fn , batch_size = 2 )
21932192 kernel_mc_fn = monte_carlo .monte_carlo_kernel_fn (
2194- init_fn , apply_fn , random .PRNGKey (10 ), 300 )
2193+ init_fn , apply_fn , random .PRNGKey (10 ), 128 ,
2194+ batch_size = 2 if xla_bridge .get_backend ().platform == 'tpu' else 0 )
21952195 empirical = kernel_mc_fn (x1 , x2 , get ,
21962196 mask_constant = mask_constant if test_mask else None ,
21972197 pattern = (pattern1 , pattern2 ))
0 commit comments