diff --git a/autogalaxy/profiles/mass/dark/gnfw_virial_mass_conc.py b/autogalaxy/profiles/mass/dark/gnfw_virial_mass_conc.py index cae42f239..12c01b8af 100644 --- a/autogalaxy/profiles/mass/dark/gnfw_virial_mass_conc.py +++ b/autogalaxy/profiles/mass/dark/gnfw_virial_mass_conc.py @@ -5,53 +5,141 @@ import numpy as np from autogalaxy import cosmology as cosmo +def is_jax(x): + try: + import jax + from jax import Array + from jax.core import Tracer + return isinstance(x, (Array, Tracer)) + except Exception: + return False def kappa_s_and_scale_radius( - cosmology, virial_mass, c_2, overdens, redshift_object, redshift_source, inner_slope + cosmology, + virial_mass, + c_2, + overdens, + redshift_object, + redshift_source, + inner_slope, ): - from scipy.integrate import quad - - concentration = (2.0 - inner_slope) * c_2 # gNFW concentration - - critical_density = cosmology.critical_density( - redshift_object, xp=np - ) # Msun / kpc^3 - - critical_surface_density = ( - cosmology.critical_surface_density_between_redshifts_solar_mass_per_kpc2_from( - redshift_0=redshift_object, - redshift_1=redshift_source, - xp=np, - ) + """ + Compute the characteristic convergence and scale radius of a spherical gNFW halo + parameterised by virial mass and concentration. + + This routine converts a halo defined by its virial mass and concentration into + the equivalent gNFW parameters (`kappa_s`, `scale_radius`) used in lensing + calculations. The normalization is computed analytically using the closed-form + hypergeometric expression for the enclosed mass integral, ensuring compatibility + with both NumPy and JAX backends (e.g. within `jax.jit`). + + The virial radius is defined via: + + M_vir = (4/3) π Δ ρ_crit(z_lens) r_vir^3 + + where Δ is the overdensity with respect to the critical density. If `overdens` + is set to zero, the Bryan & Norman (1998) redshift-dependent overdensity is used. + + The gNFW normalization constant is computed as: + + d_e = (Δ / 3) (3 − γ) c^γ / + ₂F₁(3 − γ, 3 − γ; 4 − γ; −c) + + where γ is the inner slope and c is the gNFW concentration. + + Parameters + ---------- + cosmology + Cosmology object providing critical density, angular diameter distance + conversions, and surface mass density calculations. Must support an `xp` + argument for NumPy/JAX interoperability. + virial_mass + Virial mass of the halo in units of solar masses. + c_2 + Concentration-like parameter, converted internally to the gNFW + concentration via `(2 - inner_slope) * c_2`. + overdens + Overdensity with respect to the critical density. If zero, the + Bryan & Norman (1998) redshift-dependent overdensity is used. + redshift_object + Redshift of the lens (halo). + redshift_source + Redshift of the background source. + inner_slope + Inner logarithmic density slope γ of the gNFW profile. + xp + Array backend module (`numpy` or `jax.numpy`). All array operations + are dispatched through this module to ensure compatibility with + both standard NumPy execution and JAX tracing / JIT compilation. + + Returns + ------- + kappa_s + Dimensionless characteristic convergence of the gNFW profile. + scale_radius + Angular scale radius in arcseconds. + virial_radius + Virial radius in kiloparsecs. + overdens + Final overdensity value used in the calculation. + + Notes + ----- + - This implementation is fully JIT-compatible when `xp=jax.numpy`. + - No Python-side branching depends on traced values; conditional logic + is implemented via backend array operations. + - The analytic normalization avoids numerical quadrature, improving + both performance and differentiability. + """ + is_jax_bool = is_jax(virial_mass) + + if not is_jax_bool: + xp = np + else: + from jax import numpy as jnp + xp = jnp + + if xp is np: + from scipy.special import hyp2f1 + else: + try: + from jax.scipy.special import hyp2f1 # noqa: F401 + except Exception as e: + raise RuntimeError( + "This feature requires jax.scipy.special.hyp2f1, which is available in " + "JAX >= 0.6.1. Please upgrade `jax` and `jaxlib`." + ) from e + + gamma = inner_slope + concentration = (2.0 - gamma) * c_2 # gNFW concentration (your definition) + + critical_density = cosmology.critical_density(redshift_object, xp=xp) # Msun / kpc^3 + + critical_surface_density = cosmology.critical_surface_density_between_redshifts_solar_mass_per_kpc2_from( + redshift_0=redshift_object, + redshift_1=redshift_source, + xp=xp, ) # Msun / kpc^2 - kpc_per_arcsec = cosmology.kpc_per_arcsec_from( - redshift=redshift_object, xp=np - ) # kpc / arcsec + kpc_per_arcsec = cosmology.kpc_per_arcsec_from(redshift=redshift_object, xp=xp) # kpc / arcsec - if overdens == 0: - x = cosmology.Om(redshift_object, xp=np) - 1.0 - overdens = 18.0 * np.pi**2 + 82.0 * x - 39.0 * x**2 # Bryan & Norman (1998) + # Bryan & Norman (1998) overdensity if overdens == 0 + x = cosmology.Om(redshift_object, xp=xp) - 1.0 + overdens_bn98 = 18.0 * xp.pi**2 + 82.0 * x - 39.0 * x**2 + overdens = xp.where(overdens == 0, overdens_bn98, overdens) # r_vir in kpc - virial_radius = ( - virial_mass / (overdens * critical_density * (4.0 * np.pi / 3.0)) - ) ** (1.0 / 3.0) + virial_radius = (virial_mass / (overdens * critical_density * (4.0 * xp.pi / 3.0))) ** (1.0 / 3.0) # scale radius in kpc scale_radius_kpc = virial_radius / concentration - # Normalization integral for gNFW - def integrand(r): - return (r**2 / r**inner_slope) * (1.0 + r / scale_radius_kpc) ** ( - inner_slope - 3.0 - ) + # c = rvir/rs is exactly "concentration" by definition + c = concentration - de_c = ( - (overdens / 3.0) - * (virial_radius**3 / scale_radius_kpc**inner_slope) - / quad(integrand, 0.0, virial_radius)[0] - ) + # Analytic normalization + a = 3.0 - gamma + de_c = (overdens / 3.0) * a * (c**gamma) / hyp2f1(a, a, a + 1.0, -c) rho_s = critical_density * de_c # Msun / kpc^3 kappa_s = rho_s * scale_radius_kpc / critical_surface_density # dimensionless