From fab90f0fa1762ef644156604e3af84a4c7ee714f Mon Sep 17 00:00:00 2001
From: Jammy2211
Date: Thu, 12 Jun 2025 15:25:22 +0100
Subject: [PATCH] gradient requires use of jax

---
 autofit/non_linear/search/nest/dynesty/search/static.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/autofit/non_linear/search/nest/dynesty/search/static.py b/autofit/non_linear/search/nest/dynesty/search/static.py
index f5c221881..7beade101 100644
--- a/autofit/non_linear/search/nest/dynesty/search/static.py
+++ b/autofit/non_linear/search/nest/dynesty/search/static.py
@@ -6,6 +6,8 @@ from dynesty import NestedSampler as StaticSampler
 
 from autofit.database.sqlalchemy_ import sa
 
+from autofit import jax_wrapper
+
 from autofit.mapper.prior_model.abstract import AbstractPriorModel
 
 from .abstract import AbstractDynesty, prior_transform
@@ -109,6 +111,8 @@ def search_internal_from(
             in the dynesty queue for samples.
         """
+        gradient = fitness.grad if self.use_gradient else None
+
         if checkpoint_exists:
             search_internal = StaticSampler.restore(
                 fname=self.checkpoint_file, pool=pool
             )
@@ -127,7 +131,7 @@
             self.write_uses_pool(uses_pool=True)
             return StaticSampler(
                 loglikelihood=pool.loglike,
-                gradient=fitness.grad,
+                gradient=gradient,
                 prior_transform=pool.prior_transform,
                 ndim=model.prior_count,
                 live_points=live_points,
@@ -139,7 +143,7 @@
             self.write_uses_pool(uses_pool=False)
             return StaticSampler(
                 loglikelihood=fitness,
-                gradient=fitness.grad,
+                gradient=gradient,
                 prior_transform=prior_transform,
                 ndim=model.prior_count,
                 logl_args=[model, fitness],
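
The net effect of the patch: dynesty's StaticSampler is only handed a gradient callable when the search is configured to use one, and receives gradient=None otherwise, so fitness.grad (which, per the commit subject, requires JAX) is never touched on non-JAX runs. Below is a minimal, self-contained sketch of the same gating pattern written directly against dynesty's public API. The Gaussian likelihood, the prior transform, and the USE_JAX flag are illustrative stand-ins, not PyAutoFit's actual fitness object or use_gradient setting.

import numpy as np
import jax
import jax.numpy as jnp
from dynesty import NestedSampler

NDIM = 2
USE_JAX = True  # stand-in for the search's use_gradient / jax_wrapper check


def loglike_jax(x):
    # Simple Gaussian log-likelihood, written with jnp so jax.grad works.
    return -0.5 * jnp.sum(x ** 2)


def loglike(x):
    # dynesty expects a plain float, so cast out of JAX's array types.
    return float(loglike_jax(jnp.asarray(x)))


def prior_transform(u):
    # Map the unit cube to a uniform prior on [-10, 10] in each dimension.
    return 20.0 * u - 10.0


# The gated gradient, mirroring the patch: only hand dynesty a gradient
# callable when JAX is actually differentiating the likelihood; otherwise
# pass None so dynesty stays on its gradient-free sampling paths.
if USE_JAX:
    _grad = jax.grad(loglike_jax)

    def gradient(x):
        return np.asarray(_grad(jnp.asarray(x)))
else:
    gradient = None

sampler = NestedSampler(
    loglike,
    prior_transform,
    ndim=NDIM,
    nlive=100,
    # 'hslice' is the sampling mode in which dynesty exploits gradients.
    sample="hslice" if gradient is not None else "auto",
    gradient=gradient,
    # This sketch's gradient is w.r.t. the model parameters, so ask dynesty
    # to apply the unit-cube Jacobian itself (dynesty otherwise assumes the
    # gradient is already w.r.t. the unit cube).
    compute_jac=True,
)
sampler.run_nested(maxiter=200, print_progress=False)
print(sampler.results.logz[-1])

Gating on configuration, rather than passing fitness.grad unconditionally as before, matters because dynesty only uses the callable for gradient-based ('hslice') sampling anyway, so None is the appropriate value whenever JAX is not driving the likelihood.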