|
| 1 | +import jax |
| 2 | +import jax.numpy as jnp |
| 3 | +from jax import ShapeDtypeStruct |
1 | 4 | import numpy as np |
2 | 5 | import warnings |
3 | 6 |
|
@@ -56,63 +59,147 @@ def kappa_s_and_scale_radius_for_duffy(mass_at_200, redshift_object, redshift_so |
56 | 59 | return kappa_s, scale_radius, radius_at_200 |
57 | 60 |
|
58 | 61 |
|
59 | | -def kappa_s_and_scale_radius_for_ludlow( |
60 | | - mass_at_200, scatter_sigma, redshift_object, redshift_source |
| 62 | +def _ludlow16_cosmology_callback( |
| 63 | + mass_at_200, |
| 64 | + redshift_object, |
| 65 | + redshift_source, |
61 | 66 | ): |
62 | 67 | """ |
63 | | - Computes the AutoGalaxy NFW parameters (kappa_s, scale_radius) for an NFW halo of the given |
64 | | - mass, enforcing the Ludlow '16 mass-concentration relation. |
65 | | -
|
66 | | - Interprets mass as *`M_{200c}`*, not `M_{200m}`. |
| 68 | + Pure NumPy / Python function. |
| 69 | + Must NEVER see JAX tracers. |
67 | 70 | """ |
| 71 | + |
| 72 | + import numpy as np |
68 | 73 | from astropy import units |
69 | 74 | from colossus.cosmology import cosmology as col_cosmology |
70 | 75 | from colossus.halo.concentration import concentration as col_concentration |
71 | | - |
72 | | - warnings.filterwarnings("ignore") |
73 | | - |
74 | 76 | from autogalaxy.cosmology.wrap import Planck15 |
75 | 77 |
|
76 | | - cosmology = Planck15() |
77 | | - |
| 78 | + # ----------------------- |
| 79 | + # Colossus cosmology |
| 80 | + # ----------------------- |
78 | 81 | col_cosmo = col_cosmology.setCosmology("planck15") |
| 82 | + |
79 | 83 | m_input = mass_at_200 * col_cosmo.h |
80 | 84 | concentration = col_concentration( |
81 | | - m_input, "200c", redshift_object, model="ludlow16" |
| 85 | + m_input, |
| 86 | + "200c", |
| 87 | + redshift_object, |
| 88 | + model="ludlow16", |
82 | 89 | ) |
83 | 90 |
|
84 | | - concentration = 10.0 ** (np.log10(concentration) + scatter_sigma * 0.15) |
| 91 | + # ----------------------- |
| 92 | + # Astropy cosmology |
| 93 | + # ----------------------- |
| 94 | + cosmology = Planck15() |
85 | 95 |
|
86 | 96 | cosmic_average_density = ( |
87 | | - cosmology.critical_density(redshift_object).to(units.solMass / units.kpc**3) |
88 | | - ).value |
| 97 | + cosmology.critical_density(redshift_object) |
| 98 | + .to(units.solMass / units.kpc**3) |
| 99 | + .value |
| 100 | + ) |
89 | 101 |
|
90 | 102 | critical_surface_density = ( |
91 | 103 | cosmology.critical_surface_density_between_redshifts_solar_mass_per_kpc2_from( |
92 | | - redshift_0=redshift_object, redshift_1=redshift_source |
| 104 | + redshift_0=redshift_object, |
| 105 | + redshift_1=redshift_source, |
93 | 106 | ) |
94 | 107 | ) |
95 | 108 |
|
96 | 109 | kpc_per_arcsec = cosmology.kpc_per_arcsec_from(redshift=redshift_object) |
97 | 110 |
|
| 111 | + return ( |
| 112 | + np.float64(concentration), |
| 113 | + np.float64(cosmic_average_density), |
| 114 | + np.float64(critical_surface_density), |
| 115 | + np.float64(kpc_per_arcsec), |
| 116 | + ) |
| 117 | + |
| 118 | + |
| 119 | +def ludlow16_cosmology_jax( |
| 120 | + mass_at_200, |
| 121 | + redshift_object, |
| 122 | + redshift_source, |
| 123 | +): |
| 124 | + """ |
| 125 | + JAX-safe wrapper around Colossus + Astropy cosmology. |
| 126 | + """ |
| 127 | + |
| 128 | + return jax.pure_callback( |
| 129 | + _ludlow16_cosmology_callback, |
| 130 | + ( |
| 131 | + ShapeDtypeStruct((), jnp.float64), # concentration |
| 132 | + ShapeDtypeStruct((), jnp.float64), # rho_crit(z) |
| 133 | + ShapeDtypeStruct((), jnp.float64), # Sigma_crit |
| 134 | + ShapeDtypeStruct((), jnp.float64), # kpc/arcsec |
| 135 | + ), |
| 136 | + mass_at_200, |
| 137 | + redshift_object, |
| 138 | + redshift_source, |
| 139 | + ) |
| 140 | + |
| 141 | + |
| 142 | +def kappa_s_and_scale_radius_for_ludlow( |
| 143 | + mass_at_200, |
| 144 | + scatter_sigma, |
| 145 | + redshift_object, |
| 146 | + redshift_source, |
| 147 | +): |
| 148 | + |
| 149 | + if isinstance(mass_at_200, (float, np.ndarray, np.float64)): |
| 150 | + xp = np |
| 151 | + else: |
| 152 | + xp = jnp |
| 153 | + |
| 154 | + # ------------------------------------ |
| 155 | + # Cosmology + concentration (callback) |
| 156 | + # ------------------------------------ |
| 157 | + |
| 158 | + if xp is np: |
| 159 | + ( |
| 160 | + concentration, |
| 161 | + cosmic_average_density, |
| 162 | + critical_surface_density, |
| 163 | + kpc_per_arcsec, |
| 164 | + ) = _ludlow16_cosmology_callback( |
| 165 | + mass_at_200, |
| 166 | + redshift_object, |
| 167 | + redshift_source, |
| 168 | + ) |
| 169 | + else: |
| 170 | + ( |
| 171 | + concentration, |
| 172 | + cosmic_average_density, |
| 173 | + critical_surface_density, |
| 174 | + kpc_per_arcsec, |
| 175 | + ) = ludlow16_cosmology_jax( |
| 176 | + mass_at_200, |
| 177 | + redshift_object, |
| 178 | + redshift_source, |
| 179 | + ) |
| 180 | + |
| 181 | + # Apply scatter (JAX-safe) |
| 182 | + concentration = 10.0 ** (xp.log10(concentration) + scatter_sigma * 0.15) |
| 183 | + |
| 184 | + # ------------------------------------ |
| 185 | + # JAX-native algebra |
| 186 | + # ------------------------------------ |
98 | 187 | radius_at_200 = ( |
99 | | - mass_at_200 / (200.0 * cosmic_average_density * (4.0 * np.pi / 3.0)) |
100 | | - ) ** ( |
101 | | - 1.0 / 3.0 |
102 | | - ) # r200 |
| 188 | + mass_at_200 / (200.0 * cosmic_average_density * (4.0 * xp.pi / 3.0)) |
| 189 | + ) ** (1.0 / 3.0) |
103 | 190 |
|
104 | 191 | de_c = ( |
105 | 192 | 200.0 |
106 | 193 | / 3.0 |
107 | 194 | * ( |
108 | 195 | concentration**3 |
109 | | - / (np.log(1.0 + concentration) - concentration / (1.0 + concentration)) |
| 196 | + / (xp.log(1.0 + concentration) - concentration / (1.0 + concentration)) |
110 | 197 | ) |
111 | | - ) # rho_c |
| 198 | + ) |
112 | 199 |
|
113 | | - scale_radius_kpc = radius_at_200 / concentration # scale radius in kpc |
114 | | - rho_s = cosmic_average_density * de_c # rho_s |
115 | | - kappa_s = rho_s * scale_radius_kpc / critical_surface_density # kappa_s |
116 | | - scale_radius = scale_radius_kpc / kpc_per_arcsec # scale radius in arcsec |
| 200 | + scale_radius_kpc = radius_at_200 / concentration |
| 201 | + rho_s = cosmic_average_density * de_c |
| 202 | + kappa_s = rho_s * scale_radius_kpc / critical_surface_density |
| 203 | + scale_radius = scale_radius_kpc / kpc_per_arcsec |
117 | 204 |
|
118 | 205 | return kappa_s, scale_radius, radius_at_200 |
0 commit comments