Skip to content

Commit f167f1a

Browse files
committed
fix flux fit by pasing .,array
1 parent 815d429 commit f167f1a

File tree

3 files changed

+4
-24
lines changed

3 files changed

+4
-24
lines changed

autolens/point/fit/fluxes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,5 +128,5 @@ def chi_squared(self) -> float:
128128
RMS noise-map values squared.
129129
"""
130130
return ag.util.fit.chi_squared_from(
131-
chi_squared_map=self.chi_squared_map,
131+
chi_squared_map=self.chi_squared_map.array,
132132
)

autolens/point/solver/point_solver.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import autoarray as aa
55
from autoarray.structures.triangles.shape import Point
66

7-
from autofit.jax_wrapper import jit, register_pytree_node_class
7+
from autofit.jax_wrapper import register_pytree_node_class
88
from autogalaxy import OperateDeflections
99
from .shape_solver import AbstractSolver
1010

@@ -14,7 +14,7 @@
1414

1515
@register_pytree_node_class
1616
class PointSolver(AbstractSolver):
17-
@jit
17+
1818
def solve(
1919
self,
2020
tracer: OperateDeflections,
@@ -55,21 +55,3 @@ def solve(
5555
)
5656

5757
return aa.Grid2DIrregular([pair for pair in filtered_means])
58-
59-
# filtered_means = [
60-
# pair for pair in filtered_means if not np.any(np.isnan(pair)).all()
61-
# ]
62-
#
63-
# difference = len(kept_triangles.means) - len(filtered_means)
64-
# if difference > 0:
65-
# logger.debug(
66-
# f"Filtered one multiple-image with magnification below threshold."
67-
# )
68-
# elif difference > 1:
69-
# logger.warning(
70-
# f"Filtered {difference} multiple-images with magnification below threshold."
71-
# )
72-
#
73-
# return aa.Grid2DIrregular(
74-
# [pair for pair in filtered_means if not np.isnan(pair).all()]
75-
# )

autolens/point/solver/shape_solver.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import jax.numpy as jnp
2-
from jax import jit
32
import logging
43
import math
54

@@ -208,7 +207,6 @@ def _source_plane_grid(
208207
# noinspection PyTypeChecker
209208
return grid.grid_2d_via_deflection_grid_from(deflection_grid=deflections)
210209

211-
@jit
212210
def solve_triangles(
213211
self,
214212
tracer: OperateDeflections,
@@ -270,7 +268,7 @@ def _filter_low_magnification(
270268
"""
271269
points = jnp.array(points)
272270
magnifications = tracer.magnification_2d_via_hessian_from(
273-
grid=aa.Grid2DIrregular(points),
271+
grid=aa.Grid2DIrregular(points).array,
274272
buffer=self.scale,
275273
)
276274
mask = jnp.abs(magnifications.array) > self.magnification_threshold

0 commit comments

Comments
 (0)