
Commit 94f94d9

Add test for ensemble
1 parent d071cd3 commit 94f94d9
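
Most of this diff is mechanical reformatting of tests/test_model.py along black-style conventions (spaces around operators, wrapped calls, trailing commas, two blank lines between definitions); the substantive change is the new test_ensemble added at the end of the file.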


tests/test_model.py (1 file changed: 67 additions, 32 deletions)
@@ -9,7 +9,7 @@
 import torch
 import lightning as pl
 from torchmdnet import models
-from torchmdnet.models.model import create_model
+from torchmdnet.models.model import create_model, load_model
 from torchmdnet.models import output_modules
 from torchmdnet.models.utils import dtype_mapping
 

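The newly imported load_model is what the new test at the bottom of this diff exercises: it is called both with a single checkpoint path and with a list of paths, the latter returning an ensemble.
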
@@ -23,7 +23,9 @@
 def test_forward(model_name, use_batch, explicit_q_s, precision):
     z, pos, batch = create_example_batch()
     pos = pos.to(dtype=dtype_mapping[precision])
-    model = create_model(load_example_args(model_name, prior_model=None, precision=precision))
+    model = create_model(
+        load_example_args(model_name, prior_model=None, precision=precision)
+    )
     batch = batch if use_batch else None
     if explicit_q_s:
         model(z, pos, batch=batch, q=None, s=None)
@@ -33,10 +35,12 @@ def test_forward(model_name, use_batch, explicit_q_s, precision):
 
 @mark.parametrize("model_name", models.__all_models__)
 @mark.parametrize("output_model", output_modules.__all__)
-@mark.parametrize("precision", [32,64])
+@mark.parametrize("precision", [32, 64])
 def test_forward_output_modules(model_name, output_model, precision):
     z, pos, batch = create_example_batch()
-    args = load_example_args(model_name, remove_prior=True, output_model=output_model, precision=precision)
+    args = load_example_args(
+        model_name, remove_prior=True, output_model=output_model, precision=precision
+    )
     model = create_model(args)
     model(z, pos, batch=batch)
 

@@ -61,18 +65,25 @@ def test_torchscript(model_name, device):
         grad_outputs=grad_outputs,
     )[0]
 
+
 def test_torchscript_output_modification():
-    model = create_model(load_example_args("tensornet", remove_prior=True, derivative=True))
+    model = create_model(
+        load_example_args("tensornet", remove_prior=True, derivative=True)
+    )
+
     class MyModel(torch.nn.Module):
         def __init__(self):
             super(MyModel, self).__init__()
             self.model = model
+
         def forward(self, z, pos, batch):
             y, neg_dy = self.model(z, pos, batch=batch)
             # A TorchScript bug is triggered if we modify an output of model marked as Optional[Tensor]
-            return y, 2*neg_dy
+            return y, 2 * neg_dy
+
     torch.jit.script(MyModel())
 
+
 @mark.parametrize("model_name", models.__all_models__)
 @mark.parametrize("device", ["cpu", "cuda"])
 def test_torchscript_dynamic_shapes(model_name, device):
@@ -84,11 +95,11 @@ def test_torchscript_dynamic_shapes(model_name, device):
     model = torch.jit.script(
         create_model(load_example_args(model_name, remove_prior=True, derivative=True))
     ).to(device=device)
-    #Repeat the input to make it dynamic
+    # Repeat the input to make it dynamic
     for rep in range(0, 5):
         print(rep)
-        zi = z.repeat_interleave(rep+1, dim=0).to(device=device)
-        posi = pos.repeat_interleave(rep+1, dim=0).to(device=device)
+        zi = z.repeat_interleave(rep + 1, dim=0).to(device=device)
+        posi = pos.repeat_interleave(rep + 1, dim=0).to(device=device)
         batchi = torch.randint(0, 10, (zi.shape[0],)).sort()[0].to(device=device)
         y, neg_dy = model(zi, posi, batch=batchi)
         grad_outputs = [torch.ones_like(neg_dy)]
@@ -98,32 +109,35 @@ def test_torchscript_dynamic_shapes(model_name, device):
             grad_outputs=grad_outputs,
         )[0]
 
-#Currently only tensornet is CUDA graph compatible
+
+# Currently only tensornet is CUDA graph compatible
 @mark.parametrize("model_name", ["tensornet"])
 def test_cuda_graph_compatible(model_name):
     if not torch.cuda.is_available():
         pytest.skip("CUDA not available")
     z, pos, batch = create_example_batch()
-    args = {"model": model_name,
-            "embedding_dimension": 128,
-            "num_layers": 2,
-            "num_rbf": 32,
-            "rbf_type": "expnorm",
-            "trainable_rbf": False,
-            "activation": "silu",
-            "cutoff_lower": 0.0,
-            "cutoff_upper": 5.0,
-            "max_z": 100,
-            "max_num_neighbors": 128,
-            "equivariance_invariance_group": "O(3)",
-            "prior_model": None,
-            "atom_filter": -1,
-            "derivative": True,
-            "check_error": False,
-            "static_shapes": True,
-            "output_model": "Scalar",
-            "reduce_op": "sum",
-            "precision": 32 }
+    args = {
+        "model": model_name,
+        "embedding_dimension": 128,
+        "num_layers": 2,
+        "num_rbf": 32,
+        "rbf_type": "expnorm",
+        "trainable_rbf": False,
+        "activation": "silu",
+        "cutoff_lower": 0.0,
+        "cutoff_upper": 5.0,
+        "max_z": 100,
+        "max_num_neighbors": 128,
+        "equivariance_invariance_group": "O(3)",
+        "prior_model": None,
+        "atom_filter": -1,
+        "derivative": True,
+        "check_error": False,
+        "static_shapes": True,
+        "output_model": "Scalar",
+        "reduce_op": "sum",
+        "precision": 32,
+    }
     model = create_model(args).to(device="cuda")
     model.eval()
     z = z.to("cuda")
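
"CUDA graph compatible" here means the forward pass (and, with derivative=True, the internal backward) can be captured once and replayed, which requires fixed tensor shapes; that is what the "static_shapes": True and "check_error": False options above arrange. As a reference for the mechanism, here is a minimal sketch of PyTorch's standard capture pattern; capture_forward is a hypothetical helper, not code from this commit, and the test is only assumed to drive the model through something like it.

import torch

def capture_forward(model, z, pos, batch, warmup_steps=3):
    # Warm up on a side stream so one-time allocations and lazy
    # initializations happen outside the capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(warmup_steps):
            y, neg_dy = model(z, pos, batch=batch)
    torch.cuda.current_stream().wait_stream(s)

    # Capture a single forward pass. Replaying the graph reuses the same
    # input/output buffers, which is why input shapes must never change.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        y, neg_dy = model(z, pos, batch=batch)
    return g, y, neg_dy

# Usage: copy fresh data into the captured input tensors, call g.replay(),
# and read the refreshed y and neg_dy in place.
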
@@ -142,6 +156,7 @@ def test_cuda_graph_compatible(model_name):
     assert torch.allclose(y, y2)
     assert torch.allclose(neg_dy, neg_dy2, atol=1e-5, rtol=1e-5)
 
+
 @mark.parametrize("model_name", models.__all_models__)
 def test_seed(model_name):
     args = load_example_args(model_name, remove_prior=True)
@@ -153,6 +168,7 @@ def test_seed(model_name):
     for p1, p2 in zip(m1.parameters(), m2.parameters()):
         assert (p1 == p2).all(), "Parameters don't match although using the same seed."
 
+
 @mark.parametrize("model_name", models.__all_models__)
 @mark.parametrize(
     "output_model",
@@ -199,7 +215,9 @@ def test_forward_output(model_name, output_model, overwrite_reference=False):
     ), f"Set new reference outputs for {model_name} with output model {output_model}."
 
     # compare actual ouput with reference
-    torch.testing.assert_close(pred, expected[model_name][output_model]["pred"], atol=1e-5, rtol=1e-5)
+    torch.testing.assert_close(
+        pred, expected[model_name][output_model]["pred"], atol=1e-5, rtol=1e-5
+    )
     if derivative:
         torch.testing.assert_close(
             deriv, expected[model_name][output_model]["deriv"], atol=1e-5, rtol=1e-5
@@ -218,7 +236,7 @@ def test_gradients(model_name):
         remove_prior=True,
         output_model=output_model,
         derivative=derivative,
-        precision=precision
+        precision=precision,
     )
     model = create_model(args)
     z, pos, batch = create_example_batch(n_atoms=5)
@@ -227,3 +245,20 @@ def test_gradients(model_name):
     torch.autograd.gradcheck(
         model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3
     )
+
+
+def test_ensemble():
+    ckpts = [join(dirname(dirname(__file__)), "tests", "example.ckpt")] * 3
+    model = load_model(ckpts[0])
+    ensemble_model = load_model(ckpts)
+    z, pos, batch = create_example_batch(n_atoms=5)
+
+    pred, deriv = model(z, pos, batch)
+    pred_ensemble, deriv_ensemble, y_std, neg_dy_std = ensemble_model(z, pos, batch)
+
+    torch.testing.assert_close(pred, pred_ensemble, atol=1e-5, rtol=1e-5)
+    torch.testing.assert_close(deriv, deriv_ensemble, atol=1e-5, rtol=1e-5)
+    assert y_std.shape == pred.shape
+    assert neg_dy_std.shape == deriv.shape
+    assert (y_std == 0).all()
+    assert (neg_dy_std == 0).all()
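
The test pins down the contract of the new ensemble path: load_model called with a list of checkpoints returns a model whose forward yields the per-member mean of y and neg_dy plus their standard deviations. Because all three list entries point at the same example.ckpt, the means must match the single model and both stds must be exactly zero, which is what the assertions check. Below is a minimal sketch of a wrapper with that contract; it is a hypothetical illustration, since the commit shows only the test and not the actual implementation in torchmdnet.models.model.

import torch

class Ensemble(torch.nn.Module):
    # Hypothetical wrapper illustrating the behavior test_ensemble implies.
    def __init__(self, members):
        super().__init__()
        self.members = torch.nn.ModuleList(members)

    def forward(self, z, pos, batch=None):
        ys, neg_dys = [], []
        for member in self.members:
            y, neg_dy = member(z, pos, batch=batch)
            ys.append(y)
            neg_dys.append(neg_dy)
        ys, neg_dys = torch.stack(ys), torch.stack(neg_dys)
        # The mean over members is the ensemble prediction; the per-element
        # standard deviation quantifies member disagreement. With three
        # identical checkpoints both stds are exactly zero, as asserted above.
        return ys.mean(0), neg_dys.mean(0), ys.std(0), neg_dys.std(0)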
