Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion autofit/mapper/prior_model/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -1305,7 +1305,6 @@ def instance_for_arguments(
):
self.check_assertions(arguments)

# logger.debug(f"Creating an instance for arguments")
return self._instance_for_arguments(
arguments,
ignore_assertions=ignore_assertions,
Expand Down
8 changes: 6 additions & 2 deletions autofit/mapper/prior_model/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from autoconf.dictable import from_dict
from .abstract import AbstractPriorModel
from autofit.mapper.prior.abstract import Prior
from autofit.jax_wrapper import numpy as jnp, use_jax
import numpy as np

from autofit.jax_wrapper import register_pytree_node_class
Expand Down Expand Up @@ -76,7 +77,7 @@ def _instance_for_arguments(
-------
The array with the priors replaced.
"""
array = np.zeros(self.shape)
array = jnp.zeros(self.shape)
for index in self.indices:
value = self[index]
try:
Expand All @@ -87,7 +88,10 @@ def _instance_for_arguments(
except AttributeError:
pass

array[index] = value
if use_jax:
array = array.at[index].set(value)
else:
array[index] = value
return array

def __setitem__(
Expand Down
Loading