diff --git a/autogalaxy/operate/image.py b/autogalaxy/operate/image.py index da457eb8c..44a3ab9c6 100644 --- a/autogalaxy/operate/image.py +++ b/autogalaxy/operate/image.py @@ -1,6 +1,6 @@ from __future__ import annotations +import jax import jax.numpy as jnp -import numpy as np from typing import TYPE_CHECKING, Dict, List, Optional from autoarray import Array2D @@ -10,8 +10,6 @@ import autoarray as aa -from autogalaxy import exc - class OperateImage: """ @@ -189,12 +187,14 @@ def visibilities_from( image_2d = self.image_2d_from(grid=grid) - if not jnp.any(image_2d.array): - return aa.Visibilities.zeros( + return jax.lax.cond( + jnp.any(image_2d.array), + lambda _: transformer.visibilities_from(image=image_2d), + lambda _: aa.Visibilities.zeros( shape_slim=(transformer.uv_wavelengths.shape[0],) - ) - - return transformer.visibilities_from(image=image_2d) + ), + operand=None, + ) class OperateImageList(OperateImage):