From 908663c163143b5986b23a8fb46f9ccbfc8f28f0 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 7 Apr 2026 17:17:43 +0100 Subject: [PATCH] fix: PYAUTO_DISABLE_JAX env var was being overridden by stale _use_jax assignment MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit AnalysisDataset.__init__ had `self._use_jax = use_jax` at line 112 which overwrote the env var check in Analysis.__init__. This meant PYAUTO_DISABLE_JAX=1 had no effect — JAX JIT compilation still ran, adding ~3.5s of overhead per likelihood evaluation. Fix: apply the env var check at the top of AnalysisDataset.__init__ before passing use_jax to super(). Remove the stale override (its comment said "Can be deleted after relevant AutoFIT PR merged" — PR #1184 is now merged). Co-Authored-By: Claude Opus 4.6 (1M context) --- autolens/analysis/analysis/dataset.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/autolens/analysis/analysis/dataset.py b/autolens/analysis/analysis/dataset.py index a74b9d300..41bb80847 100644 --- a/autolens/analysis/analysis/dataset.py +++ b/autolens/analysis/analysis/dataset.py @@ -84,6 +84,11 @@ def __init__( anyway. """ + import os + + if os.environ.get("PYAUTO_DISABLE_JAX") == "1": + use_jax = False + super().__init__( dataset=dataset, adapt_images=adapt_images, @@ -108,9 +113,6 @@ def __init__( if is_test_mode(): self.raise_inversion_positions_likelihood_exception = False - # Can be deleted after relevent AutoFIT PR merged - self._use_jax = use_jax - def modify_before_fit(self, paths: af.DirectoryPaths, model: af.Collection): """ This function is called immediately before the non-linear search begins and performs final tasks and checks