diff --git a/lagrangebench/runner.py b/lagrangebench/runner.py index 26e1ecc..292a041 100644 --- a/lagrangebench/runner.py +++ b/lagrangebench/runner.py @@ -244,7 +244,8 @@ def model_fn(x): MODEL = models.SEGNN elif model_name == "egnn": - box = cfg.box + bounds = np.array(metadata["bounds"]) + box = bounds[:, 1] - bounds[:, 0] if jnp.array(metadata["periodic_boundary_conditions"]).any(): displacement_fn, shift_fn = space.periodic(jnp.array(box)) else: