Skip to content

Commit 7a8a31e

Browse files
Jammy2211Jammy2211
authored andcommitted
callback implemented and tested
1 parent 8ab428f commit 7a8a31e

File tree

1 file changed

+113
-26
lines changed

1 file changed

+113
-26
lines changed

autogalaxy/profiles/mass/dark/mcr_util.py

Lines changed: 113 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import jax
2+
import jax.numpy as jnp
3+
from jax import ShapeDtypeStruct
14
import numpy as np
25
import warnings
36

@@ -56,63 +59,147 @@ def kappa_s_and_scale_radius_for_duffy(mass_at_200, redshift_object, redshift_so
5659
return kappa_s, scale_radius, radius_at_200
5760

5861

59-
def kappa_s_and_scale_radius_for_ludlow(
60-
mass_at_200, scatter_sigma, redshift_object, redshift_source
62+
def _ludlow16_cosmology_callback(
63+
mass_at_200,
64+
redshift_object,
65+
redshift_source,
6166
):
6267
"""
63-
Computes the AutoGalaxy NFW parameters (kappa_s, scale_radius) for an NFW halo of the given
64-
mass, enforcing the Ludlow '16 mass-concentration relation.
65-
66-
Interprets mass as *`M_{200c}`*, not `M_{200m}`.
68+
Pure NumPy / Python function.
69+
Must NEVER see JAX tracers.
6770
"""
71+
72+
import numpy as np
6873
from astropy import units
6974
from colossus.cosmology import cosmology as col_cosmology
7075
from colossus.halo.concentration import concentration as col_concentration
71-
72-
warnings.filterwarnings("ignore")
73-
7476
from autogalaxy.cosmology.wrap import Planck15
7577

76-
cosmology = Planck15()
77-
78+
# -----------------------
79+
# Colossus cosmology
80+
# -----------------------
7881
col_cosmo = col_cosmology.setCosmology("planck15")
82+
7983
m_input = mass_at_200 * col_cosmo.h
8084
concentration = col_concentration(
81-
m_input, "200c", redshift_object, model="ludlow16"
85+
m_input,
86+
"200c",
87+
redshift_object,
88+
model="ludlow16",
8289
)
8390

84-
concentration = 10.0 ** (np.log10(concentration) + scatter_sigma * 0.15)
91+
# -----------------------
92+
# Astropy cosmology
93+
# -----------------------
94+
cosmology = Planck15()
8595

8696
cosmic_average_density = (
87-
cosmology.critical_density(redshift_object).to(units.solMass / units.kpc**3)
88-
).value
97+
cosmology.critical_density(redshift_object)
98+
.to(units.solMass / units.kpc**3)
99+
.value
100+
)
89101

90102
critical_surface_density = (
91103
cosmology.critical_surface_density_between_redshifts_solar_mass_per_kpc2_from(
92-
redshift_0=redshift_object, redshift_1=redshift_source
104+
redshift_0=redshift_object,
105+
redshift_1=redshift_source,
93106
)
94107
)
95108

96109
kpc_per_arcsec = cosmology.kpc_per_arcsec_from(redshift=redshift_object)
97110

111+
return (
112+
np.float64(concentration),
113+
np.float64(cosmic_average_density),
114+
np.float64(critical_surface_density),
115+
np.float64(kpc_per_arcsec),
116+
)
117+
118+
119+
def ludlow16_cosmology_jax(
120+
mass_at_200,
121+
redshift_object,
122+
redshift_source,
123+
):
124+
"""
125+
JAX-safe wrapper around Colossus + Astropy cosmology.
126+
"""
127+
128+
return jax.pure_callback(
129+
_ludlow16_cosmology_callback,
130+
(
131+
ShapeDtypeStruct((), jnp.float64), # concentration
132+
ShapeDtypeStruct((), jnp.float64), # rho_crit(z)
133+
ShapeDtypeStruct((), jnp.float64), # Sigma_crit
134+
ShapeDtypeStruct((), jnp.float64), # kpc/arcsec
135+
),
136+
mass_at_200,
137+
redshift_object,
138+
redshift_source,
139+
)
140+
141+
142+
def kappa_s_and_scale_radius_for_ludlow(
143+
mass_at_200,
144+
scatter_sigma,
145+
redshift_object,
146+
redshift_source,
147+
):
148+
149+
if isinstance(mass_at_200, (float, np.ndarray, np.float64)):
150+
xp = np
151+
else:
152+
xp = jnp
153+
154+
# ------------------------------------
155+
# Cosmology + concentration (callback)
156+
# ------------------------------------
157+
158+
if xp is np:
159+
(
160+
concentration,
161+
cosmic_average_density,
162+
critical_surface_density,
163+
kpc_per_arcsec,
164+
) = _ludlow16_cosmology_callback(
165+
mass_at_200,
166+
redshift_object,
167+
redshift_source,
168+
)
169+
else:
170+
(
171+
concentration,
172+
cosmic_average_density,
173+
critical_surface_density,
174+
kpc_per_arcsec,
175+
) = ludlow16_cosmology_jax(
176+
mass_at_200,
177+
redshift_object,
178+
redshift_source,
179+
)
180+
181+
# Apply scatter (JAX-safe)
182+
concentration = 10.0 ** (xp.log10(concentration) + scatter_sigma * 0.15)
183+
184+
# ------------------------------------
185+
# JAX-native algebra
186+
# ------------------------------------
98187
radius_at_200 = (
99-
mass_at_200 / (200.0 * cosmic_average_density * (4.0 * np.pi / 3.0))
100-
) ** (
101-
1.0 / 3.0
102-
) # r200
188+
mass_at_200 / (200.0 * cosmic_average_density * (4.0 * xp.pi / 3.0))
189+
) ** (1.0 / 3.0)
103190

104191
de_c = (
105192
200.0
106193
/ 3.0
107194
* (
108195
concentration**3
109-
/ (np.log(1.0 + concentration) - concentration / (1.0 + concentration))
196+
/ (xp.log(1.0 + concentration) - concentration / (1.0 + concentration))
110197
)
111-
) # rho_c
198+
)
112199

113-
scale_radius_kpc = radius_at_200 / concentration # scale radius in kpc
114-
rho_s = cosmic_average_density * de_c # rho_s
115-
kappa_s = rho_s * scale_radius_kpc / critical_surface_density # kappa_s
116-
scale_radius = scale_radius_kpc / kpc_per_arcsec # scale radius in arcsec
200+
scale_radius_kpc = radius_at_200 / concentration
201+
rho_s = cosmic_average_density * de_c
202+
kappa_s = rho_s * scale_radius_kpc / critical_surface_density
203+
scale_radius = scale_radius_kpc / kpc_per_arcsec
117204

118205
return kappa_s, scale_radius, radius_at_200

0 commit comments

Comments
 (0)