From aa61e5078634231ac5e8cb93fa038c9d7faf37a4 Mon Sep 17 00:00:00 2001 From: Richard Date: Fri, 14 Mar 2025 16:37:56 +0000 Subject: [PATCH] proper jax handling in Array --- autofit/mapper/prior_model/abstract.py | 1 - autofit/mapper/prior_model/array.py | 8 ++++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/autofit/mapper/prior_model/abstract.py b/autofit/mapper/prior_model/abstract.py index c79060b04..a2659aa04 100644 --- a/autofit/mapper/prior_model/abstract.py +++ b/autofit/mapper/prior_model/abstract.py @@ -1305,7 +1305,6 @@ def instance_for_arguments( ): self.check_assertions(arguments) - # logger.debug(f"Creating an instance for arguments") return self._instance_for_arguments( arguments, ignore_assertions=ignore_assertions, diff --git a/autofit/mapper/prior_model/array.py b/autofit/mapper/prior_model/array.py index c37c786b2..3ad71aec2 100644 --- a/autofit/mapper/prior_model/array.py +++ b/autofit/mapper/prior_model/array.py @@ -3,6 +3,7 @@ from autoconf.dictable import from_dict from .abstract import AbstractPriorModel from autofit.mapper.prior.abstract import Prior +from autofit.jax_wrapper import numpy as jnp, use_jax import numpy as np from autofit.jax_wrapper import register_pytree_node_class @@ -76,7 +77,7 @@ def _instance_for_arguments( ------- The array with the priors replaced. """ - array = np.zeros(self.shape) + array = jnp.zeros(self.shape) for index in self.indices: value = self[index] try: @@ -87,7 +88,10 @@ def _instance_for_arguments( except AttributeError: pass - array[index] = value + if use_jax: + array = array.at[index].set(value) + else: + array[index] = value return array def __setitem__(