Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 122 additions & 34 deletions autogalaxy/profiles/mass/dark/gnfw_virial_mass_conc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,53 +5,141 @@
import numpy as np
from autogalaxy import cosmology as cosmo

def is_jax(x):
try:
import jax
from jax import Array
from jax.core import Tracer
return isinstance(x, (Array, Tracer))
except Exception:
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_jax catches Exception broadly, which can mask real import/runtime errors inside JAX and silently fall back to NumPy. It would be safer to catch only ImportError (and potentially AttributeError for older JAX types) so unexpected failures don’t change numerical backends without warning.

Suggested change
except Exception:
except (ImportError, ModuleNotFoundError, AttributeError):

Copilot uses AI. Check for mistakes.
return False

def kappa_s_and_scale_radius(
cosmology, virial_mass, c_2, overdens, redshift_object, redshift_source, inner_slope
cosmology,
virial_mass,
c_2,
overdens,
redshift_object,
redshift_source,
inner_slope,
):
from scipy.integrate import quad

concentration = (2.0 - inner_slope) * c_2 # gNFW concentration

critical_density = cosmology.critical_density(
redshift_object, xp=np
) # Msun / kpc^3

critical_surface_density = (
cosmology.critical_surface_density_between_redshifts_solar_mass_per_kpc2_from(
redshift_0=redshift_object,
redshift_1=redshift_source,
xp=np,
)
"""
Compute the characteristic convergence and scale radius of a spherical gNFW halo
parameterised by virial mass and concentration.

This routine converts a halo defined by its virial mass and concentration into
the equivalent gNFW parameters (`kappa_s`, `scale_radius`) used in lensing
calculations. The normalization is computed analytically using the closed-form
hypergeometric expression for the enclosed mass integral, ensuring compatibility
with both NumPy and JAX backends (e.g. within `jax.jit`).

The virial radius is defined via:

M_vir = (4/3) π Δ ρ_crit(z_lens) r_vir^3

where Δ is the overdensity with respect to the critical density. If `overdens`
is set to zero, the Bryan & Norman (1998) redshift-dependent overdensity is used.

The gNFW normalization constant is computed as:

d_e = (Δ / 3) (3 − γ) c^γ /
₂F₁(3 − γ, 3 − γ; 4 − γ; −c)

where γ is the inner slope and c is the gNFW concentration.

Parameters
----------
cosmology
Cosmology object providing critical density, angular diameter distance
conversions, and surface mass density calculations. Must support an `xp`
argument for NumPy/JAX interoperability.
virial_mass
Virial mass of the halo in units of solar masses.
c_2
Concentration-like parameter, converted internally to the gNFW
concentration via `(2 - inner_slope) * c_2`.
overdens
Overdensity with respect to the critical density. If zero, the
Bryan & Norman (1998) redshift-dependent overdensity is used.
redshift_object
Redshift of the lens (halo).
redshift_source
Redshift of the background source.
inner_slope
Inner logarithmic density slope γ of the gNFW profile.
xp
Array backend module (`numpy` or `jax.numpy`). All array operations
are dispatched through this module to ensure compatibility with
both standard NumPy execution and JAX tracing / JIT compilation.
Comment on lines +50 to +73
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring documents an xp parameter and states JIT-compatibility when xp=jax.numpy, but kappa_s_and_scale_radius does not accept xp and instead infers the backend from virial_mass. Update the docstring to match the actual API, or add an explicit xp argument (and propagate it through) so callers can control the backend deterministically.

Copilot uses AI. Check for mistakes.

Returns
-------
kappa_s
Dimensionless characteristic convergence of the gNFW profile.
scale_radius
Angular scale radius in arcseconds.
virial_radius
Virial radius in kiloparsecs.
overdens
Final overdensity value used in the calculation.

Notes
-----
- This implementation is fully JIT-compatible when `xp=jax.numpy`.
- No Python-side branching depends on traced values; conditional logic
is implemented via backend array operations.
- The analytic normalization avoids numerical quadrature, improving
both performance and differentiability.
"""
is_jax_bool = is_jax(virial_mass)

if not is_jax_bool:
xp = np
else:
from jax import numpy as jnp
xp = jnp

Comment on lines +94 to +101
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Backend selection is based only on is_jax(virial_mass). If a caller passes virial_mass as a Python/NumPy scalar but supplies other inputs as JAX arrays/tracers (e.g. c_2, overdens, or redshifts), this will select NumPy and then attempt NumPy ops on traced values, breaking JIT/grad. Consider either adding an explicit xp parameter, or inferring the backend from all numeric inputs (or at least from c_2 / overdens as well).

Copilot uses AI. Check for mistakes.
if xp is np:
from scipy.special import hyp2f1
else:
try:
from jax.scipy.special import hyp2f1 # noqa: F401
except Exception as e:
raise RuntimeError(
"This feature requires jax.scipy.special.hyp2f1, which is available in "
"JAX >= 0.6.1. Please upgrade `jax` and `jaxlib`."
) from e

gamma = inner_slope
concentration = (2.0 - gamma) * c_2 # gNFW concentration (your definition)

critical_density = cosmology.critical_density(redshift_object, xp=xp) # Msun / kpc^3

critical_surface_density = cosmology.critical_surface_density_between_redshifts_solar_mass_per_kpc2_from(
redshift_0=redshift_object,
redshift_1=redshift_source,
xp=xp,
) # Msun / kpc^2

kpc_per_arcsec = cosmology.kpc_per_arcsec_from(
redshift=redshift_object, xp=np
) # kpc / arcsec
kpc_per_arcsec = cosmology.kpc_per_arcsec_from(redshift=redshift_object, xp=xp) # kpc / arcsec

if overdens == 0:
x = cosmology.Om(redshift_object, xp=np) - 1.0
overdens = 18.0 * np.pi**2 + 82.0 * x - 39.0 * x**2 # Bryan & Norman (1998)
# Bryan & Norman (1998) overdensity if overdens == 0
x = cosmology.Om(redshift_object, xp=xp) - 1.0
overdens_bn98 = 18.0 * xp.pi**2 + 82.0 * x - 39.0 * x**2
overdens = xp.where(overdens == 0, overdens_bn98, overdens)
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using xp.where here changes the NumPy-path behavior: when overdens is a scalar, np.where(...) returns a 0-d ndarray, which then propagates so the function returns arrays instead of plain scalars under NumPy. If downstream code expects Python / NumPy scalars (e.g. for JSON serialization), consider keeping the original Python if overdens == 0: branch in the NumPy path and reserving xp.where for the JAX path (or explicitly converting 0-d arrays back to scalars in the NumPy backend).

Suggested change
overdens = xp.where(overdens == 0, overdens_bn98, overdens)
# Use a Python scalar branch for NumPy scalars to avoid returning 0-d arrays;
# keep xp.where for JAX and array inputs.
if is_jax(overdens) or not np.isscalar(overdens):
overdens = xp.where(overdens == 0, overdens_bn98, overdens)
else:
if overdens == 0:
overdens = overdens_bn98

Copilot uses AI. Check for mistakes.

# r_vir in kpc
virial_radius = (
virial_mass / (overdens * critical_density * (4.0 * np.pi / 3.0))
) ** (1.0 / 3.0)
virial_radius = (virial_mass / (overdens * critical_density * (4.0 * xp.pi / 3.0))) ** (1.0 / 3.0)

# scale radius in kpc
scale_radius_kpc = virial_radius / concentration

# Normalization integral for gNFW
def integrand(r):
return (r**2 / r**inner_slope) * (1.0 + r / scale_radius_kpc) ** (
inner_slope - 3.0
)
# c = rvir/rs is exactly "concentration" by definition
c = concentration

de_c = (
(overdens / 3.0)
* (virial_radius**3 / scale_radius_kpc**inner_slope)
/ quad(integrand, 0.0, virial_radius)[0]
)
# Analytic normalization
a = 3.0 - gamma
de_c = (overdens / 3.0) * a * (c**gamma) / hyp2f1(a, a, a + 1.0, -c)

rho_s = critical_density * de_c # Msun / kpc^3
kappa_s = rho_s * scale_radius_kpc / critical_surface_density # dimensionless
Expand Down
Loading