-
Notifications
You must be signed in to change notification settings - Fork 14
feature/jaxify_gnfw_conc #286
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -5,53 +5,141 @@ | |||||||||||||||||
| import numpy as np | ||||||||||||||||||
| from autogalaxy import cosmology as cosmo | ||||||||||||||||||
|
|
||||||||||||||||||
| def is_jax(x): | ||||||||||||||||||
| try: | ||||||||||||||||||
| import jax | ||||||||||||||||||
| from jax import Array | ||||||||||||||||||
| from jax.core import Tracer | ||||||||||||||||||
| return isinstance(x, (Array, Tracer)) | ||||||||||||||||||
| except Exception: | ||||||||||||||||||
| return False | ||||||||||||||||||
|
|
||||||||||||||||||
| def kappa_s_and_scale_radius( | ||||||||||||||||||
| cosmology, virial_mass, c_2, overdens, redshift_object, redshift_source, inner_slope | ||||||||||||||||||
| cosmology, | ||||||||||||||||||
| virial_mass, | ||||||||||||||||||
| c_2, | ||||||||||||||||||
| overdens, | ||||||||||||||||||
| redshift_object, | ||||||||||||||||||
| redshift_source, | ||||||||||||||||||
| inner_slope, | ||||||||||||||||||
| ): | ||||||||||||||||||
| from scipy.integrate import quad | ||||||||||||||||||
|
|
||||||||||||||||||
| concentration = (2.0 - inner_slope) * c_2 # gNFW concentration | ||||||||||||||||||
|
|
||||||||||||||||||
| critical_density = cosmology.critical_density( | ||||||||||||||||||
| redshift_object, xp=np | ||||||||||||||||||
| ) # Msun / kpc^3 | ||||||||||||||||||
|
|
||||||||||||||||||
| critical_surface_density = ( | ||||||||||||||||||
| cosmology.critical_surface_density_between_redshifts_solar_mass_per_kpc2_from( | ||||||||||||||||||
| redshift_0=redshift_object, | ||||||||||||||||||
| redshift_1=redshift_source, | ||||||||||||||||||
| xp=np, | ||||||||||||||||||
| ) | ||||||||||||||||||
| """ | ||||||||||||||||||
| Compute the characteristic convergence and scale radius of a spherical gNFW halo | ||||||||||||||||||
| parameterised by virial mass and concentration. | ||||||||||||||||||
|
|
||||||||||||||||||
| This routine converts a halo defined by its virial mass and concentration into | ||||||||||||||||||
| the equivalent gNFW parameters (`kappa_s`, `scale_radius`) used in lensing | ||||||||||||||||||
| calculations. The normalization is computed analytically using the closed-form | ||||||||||||||||||
| hypergeometric expression for the enclosed mass integral, ensuring compatibility | ||||||||||||||||||
| with both NumPy and JAX backends (e.g. within `jax.jit`). | ||||||||||||||||||
|
|
||||||||||||||||||
| The virial radius is defined via: | ||||||||||||||||||
|
|
||||||||||||||||||
| M_vir = (4/3) π Δ ρ_crit(z_lens) r_vir^3 | ||||||||||||||||||
|
|
||||||||||||||||||
| where Δ is the overdensity with respect to the critical density. If `overdens` | ||||||||||||||||||
| is set to zero, the Bryan & Norman (1998) redshift-dependent overdensity is used. | ||||||||||||||||||
|
|
||||||||||||||||||
| The gNFW normalization constant is computed as: | ||||||||||||||||||
|
|
||||||||||||||||||
| d_e = (Δ / 3) (3 − γ) c^γ / | ||||||||||||||||||
| ₂F₁(3 − γ, 3 − γ; 4 − γ; −c) | ||||||||||||||||||
|
|
||||||||||||||||||
| where γ is the inner slope and c is the gNFW concentration. | ||||||||||||||||||
|
|
||||||||||||||||||
| Parameters | ||||||||||||||||||
| ---------- | ||||||||||||||||||
| cosmology | ||||||||||||||||||
| Cosmology object providing critical density, angular diameter distance | ||||||||||||||||||
| conversions, and surface mass density calculations. Must support an `xp` | ||||||||||||||||||
| argument for NumPy/JAX interoperability. | ||||||||||||||||||
| virial_mass | ||||||||||||||||||
| Virial mass of the halo in units of solar masses. | ||||||||||||||||||
| c_2 | ||||||||||||||||||
| Concentration-like parameter, converted internally to the gNFW | ||||||||||||||||||
| concentration via `(2 - inner_slope) * c_2`. | ||||||||||||||||||
| overdens | ||||||||||||||||||
| Overdensity with respect to the critical density. If zero, the | ||||||||||||||||||
| Bryan & Norman (1998) redshift-dependent overdensity is used. | ||||||||||||||||||
| redshift_object | ||||||||||||||||||
| Redshift of the lens (halo). | ||||||||||||||||||
| redshift_source | ||||||||||||||||||
| Redshift of the background source. | ||||||||||||||||||
| inner_slope | ||||||||||||||||||
| Inner logarithmic density slope γ of the gNFW profile. | ||||||||||||||||||
| xp | ||||||||||||||||||
| Array backend module (`numpy` or `jax.numpy`). All array operations | ||||||||||||||||||
| are dispatched through this module to ensure compatibility with | ||||||||||||||||||
| both standard NumPy execution and JAX tracing / JIT compilation. | ||||||||||||||||||
|
Comment on lines
+50
to
+73
|
||||||||||||||||||
|
|
||||||||||||||||||
| Returns | ||||||||||||||||||
| ------- | ||||||||||||||||||
| kappa_s | ||||||||||||||||||
| Dimensionless characteristic convergence of the gNFW profile. | ||||||||||||||||||
| scale_radius | ||||||||||||||||||
| Angular scale radius in arcseconds. | ||||||||||||||||||
| virial_radius | ||||||||||||||||||
| Virial radius in kiloparsecs. | ||||||||||||||||||
| overdens | ||||||||||||||||||
| Final overdensity value used in the calculation. | ||||||||||||||||||
|
|
||||||||||||||||||
| Notes | ||||||||||||||||||
| ----- | ||||||||||||||||||
| - This implementation is fully JIT-compatible when `xp=jax.numpy`. | ||||||||||||||||||
| - No Python-side branching depends on traced values; conditional logic | ||||||||||||||||||
| is implemented via backend array operations. | ||||||||||||||||||
| - The analytic normalization avoids numerical quadrature, improving | ||||||||||||||||||
| both performance and differentiability. | ||||||||||||||||||
| """ | ||||||||||||||||||
| is_jax_bool = is_jax(virial_mass) | ||||||||||||||||||
|
|
||||||||||||||||||
| if not is_jax_bool: | ||||||||||||||||||
| xp = np | ||||||||||||||||||
| else: | ||||||||||||||||||
| from jax import numpy as jnp | ||||||||||||||||||
| xp = jnp | ||||||||||||||||||
|
|
||||||||||||||||||
|
Comment on lines
+94
to
+101
|
||||||||||||||||||
| if xp is np: | ||||||||||||||||||
| from scipy.special import hyp2f1 | ||||||||||||||||||
| else: | ||||||||||||||||||
| try: | ||||||||||||||||||
| from jax.scipy.special import hyp2f1 # noqa: F401 | ||||||||||||||||||
| except Exception as e: | ||||||||||||||||||
| raise RuntimeError( | ||||||||||||||||||
| "This feature requires jax.scipy.special.hyp2f1, which is available in " | ||||||||||||||||||
| "JAX >= 0.6.1. Please upgrade `jax` and `jaxlib`." | ||||||||||||||||||
| ) from e | ||||||||||||||||||
|
|
||||||||||||||||||
| gamma = inner_slope | ||||||||||||||||||
| concentration = (2.0 - gamma) * c_2 # gNFW concentration (your definition) | ||||||||||||||||||
|
|
||||||||||||||||||
| critical_density = cosmology.critical_density(redshift_object, xp=xp) # Msun / kpc^3 | ||||||||||||||||||
|
|
||||||||||||||||||
| critical_surface_density = cosmology.critical_surface_density_between_redshifts_solar_mass_per_kpc2_from( | ||||||||||||||||||
| redshift_0=redshift_object, | ||||||||||||||||||
| redshift_1=redshift_source, | ||||||||||||||||||
| xp=xp, | ||||||||||||||||||
| ) # Msun / kpc^2 | ||||||||||||||||||
|
|
||||||||||||||||||
| kpc_per_arcsec = cosmology.kpc_per_arcsec_from( | ||||||||||||||||||
| redshift=redshift_object, xp=np | ||||||||||||||||||
| ) # kpc / arcsec | ||||||||||||||||||
| kpc_per_arcsec = cosmology.kpc_per_arcsec_from(redshift=redshift_object, xp=xp) # kpc / arcsec | ||||||||||||||||||
|
|
||||||||||||||||||
| if overdens == 0: | ||||||||||||||||||
| x = cosmology.Om(redshift_object, xp=np) - 1.0 | ||||||||||||||||||
| overdens = 18.0 * np.pi**2 + 82.0 * x - 39.0 * x**2 # Bryan & Norman (1998) | ||||||||||||||||||
| # Bryan & Norman (1998) overdensity if overdens == 0 | ||||||||||||||||||
| x = cosmology.Om(redshift_object, xp=xp) - 1.0 | ||||||||||||||||||
| overdens_bn98 = 18.0 * xp.pi**2 + 82.0 * x - 39.0 * x**2 | ||||||||||||||||||
| overdens = xp.where(overdens == 0, overdens_bn98, overdens) | ||||||||||||||||||
|
||||||||||||||||||
| overdens = xp.where(overdens == 0, overdens_bn98, overdens) | |
| # Use a Python scalar branch for NumPy scalars to avoid returning 0-d arrays; | |
| # keep xp.where for JAX and array inputs. | |
| if is_jax(overdens) or not np.isscalar(overdens): | |
| overdens = xp.where(overdens == 0, overdens_bn98, overdens) | |
| else: | |
| if overdens == 0: | |
| overdens = overdens_bn98 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is_jaxcatchesExceptionbroadly, which can mask real import/runtime errors inside JAX and silently fall back to NumPy. It would be safer to catch onlyImportError(and potentiallyAttributeErrorfor older JAX types) so unexpected failures don’t change numerical backends without warning.