From c174bc43bd83f04f3369ba7df7f5917cff80976c Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Mon, 2 Mar 2026 16:25:09 +0000 Subject: [PATCH 1/2] hack to determine xp on the fly, until I make higher level decisions architectuallu --- .../mass/dark/gnfw_virial_mass_conc.py | 190 +++++++++++++++--- 1 file changed, 159 insertions(+), 31 deletions(-) diff --git a/autogalaxy/profiles/mass/dark/gnfw_virial_mass_conc.py b/autogalaxy/profiles/mass/dark/gnfw_virial_mass_conc.py index cae42f239..a889c722a 100644 --- a/autogalaxy/profiles/mass/dark/gnfw_virial_mass_conc.py +++ b/autogalaxy/profiles/mass/dark/gnfw_virial_mass_conc.py @@ -5,53 +5,181 @@ import numpy as np from autogalaxy import cosmology as cosmo +def is_jax(x): + try: + import jax + from jax import Array + from jax.core import Tracer + return isinstance(x, (Array, Tracer)) + except Exception: + return False + +def _hyp2f1_jax(xp, *, max_terms: int = 256): + """ + Returns a callable hyp2f1(a,b,c,z) compatible with the backend xp. + + - NumPy: scipy.special.hyp2f1 + - JAX (if available): jax.scipy.special.hyp2f1 + - JAX (fallback): series approximation for 2F1 (sufficient for this gNFW use-case) + """ + import jax + import jax.numpy as jnp + + # Fallback: truncated series for 2F1(a,a;a+1;z) and general 2F1(a,b;c;z) + # We implement general 2F1 series: + # 2F1(a,b;c;z) = sum_{n=0}^{∞} (a)_n (b)_n / (c)_n * z^n / n! + # + # Recurrence for terms: + # t_0 = 1 + # t_{n+1} = t_n * (a+n)(b+n)/((c+n)(n+1)) * z + # + # This is JIT-safe with static max_terms. + def hyp2f1_series(a, b, c, z): + a = jnp.asarray(a) + b = jnp.asarray(b) + c = jnp.asarray(c) + z = jnp.asarray(z) + + def body_fun(n, carry): + t, s = carry + n_f = jnp.asarray(n, dtype=t.dtype) + t = t * (a + n_f) * (b + n_f) / ((c + n_f) * (n_f + 1.0)) * z + s = s + t + return (t, s) + + # Start: t0 = 1, s0 = 1 + t0 = jnp.ones_like(z, dtype=jnp.result_type(a, b, c, z)) + s0 = t0 + + # fori_loop has static iteration count => good under jit/vmap + tN, sN = jax.lax.fori_loop(0, max_terms - 1, body_fun, (t0, s0)) + return sN + + return hyp2f1_series def kappa_s_and_scale_radius( - cosmology, virial_mass, c_2, overdens, redshift_object, redshift_source, inner_slope + cosmology, + virial_mass, + c_2, + overdens, + redshift_object, + redshift_source, + inner_slope, ): - from scipy.integrate import quad + """ + Compute the characteristic convergence and scale radius of a spherical gNFW halo + parameterised by virial mass and concentration. - concentration = (2.0 - inner_slope) * c_2 # gNFW concentration + This routine converts a halo defined by its virial mass and concentration into + the equivalent gNFW parameters (`kappa_s`, `scale_radius`) used in lensing + calculations. The normalization is computed analytically using the closed-form + hypergeometric expression for the enclosed mass integral, ensuring compatibility + with both NumPy and JAX backends (e.g. within `jax.jit`). - critical_density = cosmology.critical_density( - redshift_object, xp=np - ) # Msun / kpc^3 + The virial radius is defined via: - critical_surface_density = ( - cosmology.critical_surface_density_between_redshifts_solar_mass_per_kpc2_from( - redshift_0=redshift_object, - redshift_1=redshift_source, - xp=np, - ) + M_vir = (4/3) π Δ ρ_crit(z_lens) r_vir^3 + + where Δ is the overdensity with respect to the critical density. If `overdens` + is set to zero, the Bryan & Norman (1998) redshift-dependent overdensity is used. + + The gNFW normalization constant is computed as: + + d_e = (Δ / 3) (3 − γ) c^γ / + ₂F₁(3 − γ, 3 − γ; 4 − γ; −c) + + where γ is the inner slope and c is the gNFW concentration. + + Parameters + ---------- + cosmology + Cosmology object providing critical density, angular diameter distance + conversions, and surface mass density calculations. Must support an `xp` + argument for NumPy/JAX interoperability. + virial_mass + Virial mass of the halo in units of solar masses. + c_2 + Concentration-like parameter, converted internally to the gNFW + concentration via `(2 - inner_slope) * c_2`. + overdens + Overdensity with respect to the critical density. If zero, the + Bryan & Norman (1998) redshift-dependent overdensity is used. + redshift_object + Redshift of the lens (halo). + redshift_source + Redshift of the background source. + inner_slope + Inner logarithmic density slope γ of the gNFW profile. + xp + Array backend module (`numpy` or `jax.numpy`). All array operations + are dispatched through this module to ensure compatibility with + both standard NumPy execution and JAX tracing / JIT compilation. + + Returns + ------- + kappa_s + Dimensionless characteristic convergence of the gNFW profile. + scale_radius + Angular scale radius in arcseconds. + virial_radius + Virial radius in kiloparsecs. + overdens + Final overdensity value used in the calculation. + + Notes + ----- + - This implementation is fully JIT-compatible when `xp=jax.numpy`. + - No Python-side branching depends on traced values; conditional logic + is implemented via backend array operations. + - The analytic normalization avoids numerical quadrature, improving + both performance and differentiability. + """ + is_jax_bool = is_jax(virial_mass) + + if not is_jax_bool: + xp = np + else: + from jax import numpy as jnp + xp = jnp + + if xp is np: + from scipy.special import hyp2f1 + else: + try: + from jax.scipy.special import hyp2f1 + except ImportError: + hyp2f1 = _hyp2f1_jax(xp) + + gamma = inner_slope + concentration = (2.0 - gamma) * c_2 # gNFW concentration (your definition) + + critical_density = cosmology.critical_density(redshift_object, xp=xp) # Msun / kpc^3 + + critical_surface_density = cosmology.critical_surface_density_between_redshifts_solar_mass_per_kpc2_from( + redshift_0=redshift_object, + redshift_1=redshift_source, + xp=xp, ) # Msun / kpc^2 - kpc_per_arcsec = cosmology.kpc_per_arcsec_from( - redshift=redshift_object, xp=np - ) # kpc / arcsec + kpc_per_arcsec = cosmology.kpc_per_arcsec_from(redshift=redshift_object, xp=xp) # kpc / arcsec - if overdens == 0: - x = cosmology.Om(redshift_object, xp=np) - 1.0 - overdens = 18.0 * np.pi**2 + 82.0 * x - 39.0 * x**2 # Bryan & Norman (1998) + # Bryan & Norman (1998) overdensity if overdens == 0 + x = cosmology.Om(redshift_object, xp=xp) - 1.0 + overdens_bn98 = 18.0 * xp.pi**2 + 82.0 * x - 39.0 * x**2 + overdens = xp.where(overdens == 0, overdens_bn98, overdens) # r_vir in kpc - virial_radius = ( - virial_mass / (overdens * critical_density * (4.0 * np.pi / 3.0)) - ) ** (1.0 / 3.0) + virial_radius = (virial_mass / (overdens * critical_density * (4.0 * xp.pi / 3.0))) ** (1.0 / 3.0) # scale radius in kpc scale_radius_kpc = virial_radius / concentration - # Normalization integral for gNFW - def integrand(r): - return (r**2 / r**inner_slope) * (1.0 + r / scale_radius_kpc) ** ( - inner_slope - 3.0 - ) + # c = rvir/rs is exactly "concentration" by definition + c = concentration - de_c = ( - (overdens / 3.0) - * (virial_radius**3 / scale_radius_kpc**inner_slope) - / quad(integrand, 0.0, virial_radius)[0] - ) + # Analytic normalization + a = 3.0 - gamma + de_c = (overdens / 3.0) * a * (c**gamma) / hyp2f1(a, a, a + 1.0, -c) rho_s = critical_density * de_c # Msun / kpc^3 kappa_s = rho_s * scale_radius_kpc / critical_surface_density # dimensionless From 9d1b9d8d17a709f6637c8e089eb83d15ece874c1 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Mon, 2 Mar 2026 16:52:04 +0000 Subject: [PATCH 2/2] require JAX version update for hypf1 --- .../mass/dark/gnfw_virial_mass_conc.py | 52 +++---------------- 1 file changed, 6 insertions(+), 46 deletions(-) diff --git a/autogalaxy/profiles/mass/dark/gnfw_virial_mass_conc.py b/autogalaxy/profiles/mass/dark/gnfw_virial_mass_conc.py index a889c722a..12c01b8af 100644 --- a/autogalaxy/profiles/mass/dark/gnfw_virial_mass_conc.py +++ b/autogalaxy/profiles/mass/dark/gnfw_virial_mass_conc.py @@ -14,49 +14,6 @@ def is_jax(x): except Exception: return False -def _hyp2f1_jax(xp, *, max_terms: int = 256): - """ - Returns a callable hyp2f1(a,b,c,z) compatible with the backend xp. - - - NumPy: scipy.special.hyp2f1 - - JAX (if available): jax.scipy.special.hyp2f1 - - JAX (fallback): series approximation for 2F1 (sufficient for this gNFW use-case) - """ - import jax - import jax.numpy as jnp - - # Fallback: truncated series for 2F1(a,a;a+1;z) and general 2F1(a,b;c;z) - # We implement general 2F1 series: - # 2F1(a,b;c;z) = sum_{n=0}^{∞} (a)_n (b)_n / (c)_n * z^n / n! - # - # Recurrence for terms: - # t_0 = 1 - # t_{n+1} = t_n * (a+n)(b+n)/((c+n)(n+1)) * z - # - # This is JIT-safe with static max_terms. - def hyp2f1_series(a, b, c, z): - a = jnp.asarray(a) - b = jnp.asarray(b) - c = jnp.asarray(c) - z = jnp.asarray(z) - - def body_fun(n, carry): - t, s = carry - n_f = jnp.asarray(n, dtype=t.dtype) - t = t * (a + n_f) * (b + n_f) / ((c + n_f) * (n_f + 1.0)) * z - s = s + t - return (t, s) - - # Start: t0 = 1, s0 = 1 - t0 = jnp.ones_like(z, dtype=jnp.result_type(a, b, c, z)) - s0 = t0 - - # fori_loop has static iteration count => good under jit/vmap - tN, sN = jax.lax.fori_loop(0, max_terms - 1, body_fun, (t0, s0)) - return sN - - return hyp2f1_series - def kappa_s_and_scale_radius( cosmology, virial_mass, @@ -146,9 +103,12 @@ def kappa_s_and_scale_radius( from scipy.special import hyp2f1 else: try: - from jax.scipy.special import hyp2f1 - except ImportError: - hyp2f1 = _hyp2f1_jax(xp) + from jax.scipy.special import hyp2f1 # noqa: F401 + except Exception as e: + raise RuntimeError( + "This feature requires jax.scipy.special.hyp2f1, which is available in " + "JAX >= 0.6.1. Please upgrade `jax` and `jaxlib`." + ) from e gamma = inner_slope concentration = (2.0 - gamma) * c_2 # gNFW concentration (your definition)