Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,25 @@ These inherit from `AnalysisDataset` → `Analysis` (in `analysis/analysis/`), w

### JAX Support

JAX is integrated via the `xp` parameter pattern throughout the codebase. Fit classes accept `xp=np` (NumPy, default) or `xp=jnp` (JAX). The `AbstractFitInversion.use_jax` property tracks which backend is active. The `AnalysisImaging.__init__` has `use_jax: bool = True`. The conftest.py forces JAX backend initialization before tests run.
The codebase is designed so that **NumPy is the default everywhere and JAX is opt-in**. JAX is never imported at module level — it is only imported locally inside functions when explicitly requested.

The `xp` parameter pattern is the single point of control:
- `xp=np` (default throughout) — pure NumPy path, no JAX dependency at runtime
- `xp=jnp` — JAX path, imports `jax` / `jax.numpy` locally inside the function

This means:
- **Unit tests** (`test_autogalaxy/`) always run on the NumPy path. No test should import JAX or pass `xp=jnp` unless it is explicitly testing the JAX path.
- **Integration tests** (in `autogalaxy_workspace_test/`) are where the JAX path is exercised, typically wrapped in `jax.jit` to test both correctness and compilation.
- `conftest.py` forces JAX backend initialisation before the test suite runs, but this only ensures JAX is available — it does not switch the default backend.

`AbstractFitInversion.use_jax` tracks whether a fit was constructed with JAX. `AnalysisImaging` has `use_jax: bool = True` to opt into the JAX path for model-fitting.

When adding a new function that should support JAX:
1. Default the parameter to `xp=np`
2. Guard any JAX imports with `if xp is not np:` and import `jax` / `jax.numpy` locally inside that branch
3. Add the NumPy implementation as the default path (finite-difference, `np.*` calls, etc.)
4. Add a JAX implementation in the guarded branch (e.g. `jax.jacfwd`, `jnp.vectorize`)
5. Verify correctness by comparing both paths in `autogalaxy_workspace_test/scripts/`

### Linear Light Profiles & Inversions

Expand All @@ -101,6 +119,10 @@ Default priors, visualization settings, and general config live in `autogalaxy/c

Both are mixin classes inherited by `LightProfile`, `MassProfile`, `Galaxy`, and `Galaxies`.

### Workspace Script Style

Scripts in `autogalaxy_workspace` and `autogalaxy_workspace_test` use `"""..."""` docstring blocks as prose commentary throughout — **not** `#` comments. Every script opens with a module-level docstring (title + underline + description), and each logical section of code is preceded by a `"""..."""` block with a `__Section Name__` header explaining what follows. See any script in `autogalaxy_workspace/scripts/` for examples of this style.

### Workspace (Examples & Notebooks)

The `autogalaxy_workspace` at `/mnt/c/Users/Jammy/Code/PyAutoJAX/autogalaxy_workspace` contains runnable examples and tutorials. Key locations:
Expand Down
2 changes: 1 addition & 1 deletion autogalaxy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
from .operate.image import OperateImage
from .operate.image import OperateImageList
from .operate.image import OperateImageGalaxies
from .operate.deflections import OperateDeflections
from .operate.lens_calc import LensCalc
from .gui.scribbler import Scribbler
from .imaging.fit_imaging import FitImaging
from .imaging.model.analysis import AnalysisImaging
Expand Down
18 changes: 5 additions & 13 deletions autogalaxy/analysis/model_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,22 +178,14 @@ def mge_point_model_from(
# and twice the pixel scale, with a floor to avoid taking log10 of
# very small or non-positive values.
min_log10_sigma = -2.0 # corresponds to 0.01 arcsec
max_sigma = max(2.0 * pixel_scales, 10 ** min_log10_sigma)
max_sigma = max(2.0 * pixel_scales, 10**min_log10_sigma)
max_log10_sigma = np.log10(max_sigma)

log10_sigma_list = np.linspace(
min_log10_sigma, max_log10_sigma, total_gaussians
)
centre_0 = af.UniformPrior(
lower_limit=centre[0] - 0.1, upper_limit=centre[0] + 0.1
)
centre_1 = af.UniformPrior(
lower_limit=centre[1] - 0.1, upper_limit=centre[1] + 0.1
)
log10_sigma_list = np.linspace(min_log10_sigma, max_log10_sigma, total_gaussians)
centre_0 = af.UniformPrior(lower_limit=centre[0] - 0.1, upper_limit=centre[0] + 0.1)
centre_1 = af.UniformPrior(lower_limit=centre[1] - 0.1, upper_limit=centre[1] + 0.1)

gaussian_list = af.Collection(
af.Model(Gaussian) for _ in range(total_gaussians)
)
gaussian_list = af.Collection(af.Model(Gaussian) for _ in range(total_gaussians))

for i, gaussian in enumerate(gaussian_list):
gaussian.centre.centre_0 = centre_0
Expand Down
3 changes: 1 addition & 2 deletions autogalaxy/galaxy/galaxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
from autogalaxy.profiles.basis import Basis
from autogalaxy.profiles.light.linear import LightProfileLinear
from autogalaxy.operate.image import OperateImageGalaxies
from autogalaxy.operate.deflections import OperateDeflections


class Galaxies(List, OperateImageGalaxies, OperateDeflections):
class Galaxies(List, OperateImageGalaxies):
def __init__(
self,
galaxies: List[Galaxy],
Expand Down
3 changes: 1 addition & 2 deletions autogalaxy/galaxy/galaxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import autofit as af

from autogalaxy import exc
from autogalaxy.operate.deflections import OperateDeflections
from autogalaxy.operate.image import OperateImageList
from autogalaxy.profiles.geometry_profiles import GeometryProfile
from autogalaxy.profiles.light.abstract import LightProfile
Expand All @@ -17,7 +16,7 @@
from autogalaxy.profiles.mass.abstract.abstract import MassProfile


class Galaxy(af.ModelObject, OperateImageList, OperateDeflections):
class Galaxy(af.ModelObject, OperateImageList):
"""
@DynamicAttrs
"""
Expand Down
Loading
Loading