diff --git a/bdsf/rmsimage.py b/bdsf/rmsimage.py index e73567e..157ce5e 100644 --- a/bdsf/rmsimage.py +++ b/bdsf/rmsimage.py @@ -27,6 +27,50 @@ pass from .functions import read_image_from_file +from concurrent.futures import ThreadPoolExecutor +import numpy +from scipy import ndimage + + +def mapcoord_threaded(a, axs, *args, ncores=8, **kwargs): + """Threaded map_coordinates on cartesian coordinate grid (meshgrid) + + :param a: Array to be regridded + + :param axs: List of axes onto which to regrid. Result is gridded + to equivalent of meshgrid(*axs) + + """ + output = kwargs.get("output", None) + kwargs["output"] = None + # Prefilter only once to awoid repeated work in the workers. See + # _interpolation.py in scipy + order = kwargs.get("order", 3) + if order > 1: + a = ndimage.spline_filter(a, order, + output=numpy.float64, + mode=kwargs.get("mode", "constant")) + + def tworker(cl1): + # Construct the subset of meshgrid to which worker has been + # applied. + # NB: The axis reversal is specific to this program. The indexing parameter + # does not do exactly the same thing + cl = numpy.meshgrid( * ([cl1]+axs[1:]))[-1::-1] + return ndimage.map_coordinates(a, + cl, + # NB we pulled the pre-filter out + prefilter=False, + *args, **kwargs) + + with ThreadPoolExecutor(max_workers=ncores) as te: + res=te.map(tworker, + axs[0]) + res=numpy.hstack(list(res)) + if output is not None: + numpy.copyto(output, res) + return res + class Op_rmsimage(Op): """Calculate rms & noise maps @@ -555,8 +599,7 @@ def map_2d(self, arr, out_mean, out_rms, mask=False, """ mask_small = mask axes, mean_map1, rms_map1 = self.rms_mean_map(arr, mask_small, kappa, box, ncores) - ax = [self.remap_axis(ashp, axv) for (ashp, axv) in zip(arr.shape, axes)] - ax = N.meshgrid(*ax[-1::-1]) + ax = [self.remap_axis(ashp, axv) for (ashp, axv) in zip(arr.shape, axes)][-1::-1] pt_src_scale = box[0] if do_adapt: out_rms2 = N.zeros(rms_map1.shape, dtype=N.float32) @@ -569,10 +612,9 @@ def map_2d(self, arr, out_mean, out_rms, mask=False, axes2mod = axes2[:] axes2mod[0] = axes2[0]/arr.shape[0]*mean_map1.shape[0] axes2mod[1] = axes2[1]/arr.shape[1]*mean_map1.shape[1] - ax2 = [self.remap_axis(ashp, axv) for (ashp, axv) in zip(out_rms2.shape, axes2mod)] - ax2 = N.meshgrid(*ax2[-1::-1]) - nd.map_coordinates(rms_map2, ax2[-1::-1], order=interp, output=out_rms2) - nd.map_coordinates(mean_map2, ax2[-1::-1], order=interp, output=out_mean2) + ax2 = [self.remap_axis(ashp, axv) for (ashp, axv) in zip(out_rms2.shape, axes2mod)][-1::-1] + mapcoord_threaded(rms_map2, ax2, order=interp, output=out_rms2, ncores=ncores) + mapcoord_threaded(mean_map2, ax2, order=interp, output=out_mean2, ncores=ncores) rms_map = out_rms2 mean_map = out_mean2 @@ -615,8 +657,8 @@ def map_2d(self, arr, out_mean, out_rms, mask=False, # Interpolate to image coords mylog = mylogger.logging.getLogger(logname+"Rmsimage") - nd.map_coordinates(rms_map, ax[-1::-1], order=interp, output=out_rms) - nd.map_coordinates(mean_map, ax[-1::-1], order=interp, output=out_mean) + mapcoord_threaded(rms_map, ax, order=interp, output=out_rms, ncores=ncores) + mapcoord_threaded(mean_map, ax, order=interp, output=out_mean, ncores=ncores) # Apply mask to mean_map and rms_map by setting masked values to NaN if isinstance(mask, N.ndarray):