99import torch
1010import lightning as pl
1111from torchmdnet import models
12- from torchmdnet .models .model import create_model
12+ from torchmdnet .models .model import create_model , load_model
1313from torchmdnet .models import output_modules
1414from torchmdnet .models .utils import dtype_mapping
1515
2323def test_forward (model_name , use_batch , explicit_q_s , precision ):
2424 z , pos , batch = create_example_batch ()
2525 pos = pos .to (dtype = dtype_mapping [precision ])
26- model = create_model (load_example_args (model_name , prior_model = None , precision = precision ))
26+ model = create_model (
27+ load_example_args (model_name , prior_model = None , precision = precision )
28+ )
2729 batch = batch if use_batch else None
2830 if explicit_q_s :
2931 model (z , pos , batch = batch , q = None , s = None )
@@ -33,10 +35,12 @@ def test_forward(model_name, use_batch, explicit_q_s, precision):
3335
3436@mark .parametrize ("model_name" , models .__all_models__ )
3537@mark .parametrize ("output_model" , output_modules .__all__ )
36- @mark .parametrize ("precision" , [32 ,64 ])
38+ @mark .parametrize ("precision" , [32 , 64 ])
3739def test_forward_output_modules (model_name , output_model , precision ):
3840 z , pos , batch = create_example_batch ()
39- args = load_example_args (model_name , remove_prior = True , output_model = output_model , precision = precision )
41+ args = load_example_args (
42+ model_name , remove_prior = True , output_model = output_model , precision = precision
43+ )
4044 model = create_model (args )
4145 model (z , pos , batch = batch )
4246
@@ -61,18 +65,25 @@ def test_torchscript(model_name, device):
6165 grad_outputs = grad_outputs ,
6266 )[0 ]
6367
68+
6469def test_torchscript_output_modification ():
65- model = create_model (load_example_args ("tensornet" , remove_prior = True , derivative = True ))
70+ model = create_model (
71+ load_example_args ("tensornet" , remove_prior = True , derivative = True )
72+ )
73+
6674 class MyModel (torch .nn .Module ):
6775 def __init__ (self ):
6876 super (MyModel , self ).__init__ ()
6977 self .model = model
78+
7079 def forward (self , z , pos , batch ):
7180 y , neg_dy = self .model (z , pos , batch = batch )
7281 # A TorchScript bug is triggered if we modify an output of model marked as Optional[Tensor]
73- return y , 2 * neg_dy
82+ return y , 2 * neg_dy
83+
7484 torch .jit .script (MyModel ())
7585
86+
7687@mark .parametrize ("model_name" , models .__all_models__ )
7788@mark .parametrize ("device" , ["cpu" , "cuda" ])
7889def test_torchscript_dynamic_shapes (model_name , device ):
@@ -84,11 +95,11 @@ def test_torchscript_dynamic_shapes(model_name, device):
8495 model = torch .jit .script (
8596 create_model (load_example_args (model_name , remove_prior = True , derivative = True ))
8697 ).to (device = device )
87- #Repeat the input to make it dynamic
98+ # Repeat the input to make it dynamic
8899 for rep in range (0 , 5 ):
89100 print (rep )
90- zi = z .repeat_interleave (rep + 1 , dim = 0 ).to (device = device )
91- posi = pos .repeat_interleave (rep + 1 , dim = 0 ).to (device = device )
101+ zi = z .repeat_interleave (rep + 1 , dim = 0 ).to (device = device )
102+ posi = pos .repeat_interleave (rep + 1 , dim = 0 ).to (device = device )
92103 batchi = torch .randint (0 , 10 , (zi .shape [0 ],)).sort ()[0 ].to (device = device )
93104 y , neg_dy = model (zi , posi , batch = batchi )
94105 grad_outputs = [torch .ones_like (neg_dy )]
@@ -98,32 +109,35 @@ def test_torchscript_dynamic_shapes(model_name, device):
98109 grad_outputs = grad_outputs ,
99110 )[0 ]
100111
101- #Currently only tensornet is CUDA graph compatible
112+
113+ # Currently only tensornet is CUDA graph compatible
102114@mark .parametrize ("model_name" , ["tensornet" ])
103115def test_cuda_graph_compatible (model_name ):
104116 if not torch .cuda .is_available ():
105117 pytest .skip ("CUDA not available" )
106118 z , pos , batch = create_example_batch ()
107- args = {"model" : model_name ,
108- "embedding_dimension" : 128 ,
109- "num_layers" : 2 ,
110- "num_rbf" : 32 ,
111- "rbf_type" : "expnorm" ,
112- "trainable_rbf" : False ,
113- "activation" : "silu" ,
114- "cutoff_lower" : 0.0 ,
115- "cutoff_upper" : 5.0 ,
116- "max_z" : 100 ,
117- "max_num_neighbors" : 128 ,
118- "equivariance_invariance_group" : "O(3)" ,
119- "prior_model" : None ,
120- "atom_filter" : - 1 ,
121- "derivative" : True ,
122- "check_error" : False ,
123- "static_shapes" : True ,
124- "output_model" : "Scalar" ,
125- "reduce_op" : "sum" ,
126- "precision" : 32 }
119+ args = {
120+ "model" : model_name ,
121+ "embedding_dimension" : 128 ,
122+ "num_layers" : 2 ,
123+ "num_rbf" : 32 ,
124+ "rbf_type" : "expnorm" ,
125+ "trainable_rbf" : False ,
126+ "activation" : "silu" ,
127+ "cutoff_lower" : 0.0 ,
128+ "cutoff_upper" : 5.0 ,
129+ "max_z" : 100 ,
130+ "max_num_neighbors" : 128 ,
131+ "equivariance_invariance_group" : "O(3)" ,
132+ "prior_model" : None ,
133+ "atom_filter" : - 1 ,
134+ "derivative" : True ,
135+ "check_error" : False ,
136+ "static_shapes" : True ,
137+ "output_model" : "Scalar" ,
138+ "reduce_op" : "sum" ,
139+ "precision" : 32 ,
140+ }
127141 model = create_model (args ).to (device = "cuda" )
128142 model .eval ()
129143 z = z .to ("cuda" )
@@ -142,6 +156,7 @@ def test_cuda_graph_compatible(model_name):
142156 assert torch .allclose (y , y2 )
143157 assert torch .allclose (neg_dy , neg_dy2 , atol = 1e-5 , rtol = 1e-5 )
144158
159+
145160@mark .parametrize ("model_name" , models .__all_models__ )
146161def test_seed (model_name ):
147162 args = load_example_args (model_name , remove_prior = True )
@@ -153,6 +168,7 @@ def test_seed(model_name):
153168 for p1 , p2 in zip (m1 .parameters (), m2 .parameters ()):
154169 assert (p1 == p2 ).all (), "Parameters don't match although using the same seed."
155170
171+
156172@mark .parametrize ("model_name" , models .__all_models__ )
157173@mark .parametrize (
158174 "output_model" ,
@@ -199,7 +215,9 @@ def test_forward_output(model_name, output_model, overwrite_reference=False):
199215 ), f"Set new reference outputs for { model_name } with output model { output_model } ."
200216
201217 # compare actual ouput with reference
202- torch .testing .assert_close (pred , expected [model_name ][output_model ]["pred" ], atol = 1e-5 , rtol = 1e-5 )
218+ torch .testing .assert_close (
219+ pred , expected [model_name ][output_model ]["pred" ], atol = 1e-5 , rtol = 1e-5
220+ )
203221 if derivative :
204222 torch .testing .assert_close (
205223 deriv , expected [model_name ][output_model ]["deriv" ], atol = 1e-5 , rtol = 1e-5
@@ -218,7 +236,7 @@ def test_gradients(model_name):
218236 remove_prior = True ,
219237 output_model = output_model ,
220238 derivative = derivative ,
221- precision = precision
239+ precision = precision ,
222240 )
223241 model = create_model (args )
224242 z , pos , batch = create_example_batch (n_atoms = 5 )
@@ -227,3 +245,20 @@ def test_gradients(model_name):
227245 torch .autograd .gradcheck (
228246 model , (z , pos , batch ), eps = 1e-4 , atol = 1e-3 , rtol = 1e-2 , nondet_tol = 1e-3
229247 )
248+
249+
250+ def test_ensemble ():
251+ ckpts = [join (dirname (dirname (__file__ )), "tests" , "example.ckpt" )] * 3
252+ model = load_model (ckpts [0 ])
253+ ensemble_model = load_model (ckpts )
254+ z , pos , batch = create_example_batch (n_atoms = 5 )
255+
256+ pred , deriv = model (z , pos , batch )
257+ pred_ensemble , deriv_ensemble , y_std , neg_dy_std = ensemble_model (z , pos , batch )
258+
259+ torch .testing .assert_close (pred , pred_ensemble , atol = 1e-5 , rtol = 1e-5 )
260+ torch .testing .assert_close (deriv , deriv_ensemble , atol = 1e-5 , rtol = 1e-5 )
261+ assert y_std .shape == pred .shape
262+ assert neg_dy_std .shape == deriv .shape
263+ assert (y_std == 0 ).all ()
264+ assert (neg_dy_std == 0 ).all ()
0 commit comments