Skip to content

Commit 9b0a362

Browse files
authored
Merge branch 'main' into ensemble_model_support
2 parents 417955d + 9546e88 commit 9b0a362

File tree

1 file changed

+20
-22
lines changed

1 file changed

+20
-22
lines changed

tests/test_model.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -116,28 +116,26 @@ def test_cuda_graph_compatible(model_name):
116116
if not torch.cuda.is_available():
117117
pytest.skip("CUDA not available")
118118
z, pos, batch = create_example_batch()
119-
args = {
120-
"model": model_name,
121-
"embedding_dimension": 128,
122-
"num_layers": 2,
123-
"num_rbf": 32,
124-
"rbf_type": "expnorm",
125-
"trainable_rbf": False,
126-
"activation": "silu",
127-
"cutoff_lower": 0.0,
128-
"cutoff_upper": 5.0,
129-
"max_z": 100,
130-
"max_num_neighbors": 128,
131-
"equivariance_invariance_group": "O(3)",
132-
"prior_model": None,
133-
"atom_filter": -1,
134-
"derivative": True,
135-
"check_error": False,
136-
"static_shapes": True,
137-
"output_model": "Scalar",
138-
"reduce_op": "sum",
139-
"precision": 32,
140-
}
119+
args = {"model": model_name,
120+
"embedding_dimension": 128,
121+
"num_layers": 2,
122+
"num_rbf": 32,
123+
"rbf_type": "expnorm",
124+
"trainable_rbf": False,
125+
"activation": "silu",
126+
"cutoff_lower": 0.0,
127+
"cutoff_upper": 5.0,
128+
"max_z": 100,
129+
"max_num_neighbors": 128,
130+
"equivariance_invariance_group": "O(3)",
131+
"prior_model": None,
132+
"atom_filter": -1,
133+
"derivative": True,
134+
"check_errors": False,
135+
"static_shapes": True,
136+
"output_model": "Scalar",
137+
"reduce_op": "sum",
138+
"precision": 32 }
141139
model = create_model(args).to(device="cuda")
142140
model.eval()
143141
z = z.to("cuda")

0 commit comments

Comments
 (0)