diff --git a/bdsf/gaul2srl.py b/bdsf/gaul2srl.py index 8dbeafa..b28bde7 100644 --- a/bdsf/gaul2srl.py +++ b/bdsf/gaul2srl.py @@ -369,26 +369,60 @@ def process_Multiple(self, img, g_sublist, mask, src_index, isrc, subim, isl, de # posn from gaussian fit instead. if N.isnan(mompara[1]): mompara[1] = posn[0] - delc[0] - x1 = int(N.floor(mompara[1])) if N.isnan(mompara[2]): mompara[2] = posn[1] - delc[1] - y1 = int(N.floor(mompara[2])) - xind = slice(x1, x1+2, 1); yind = slice(y1, y1+2, 1) - if img.opts.flag_smallsrc and (N.sum(mask[xind, yind]==N.ones((2,2))*isrc) != 4): + + interp_x = float(mompara[1]) + interp_y = float(mompara[2]) + orig_interp_x = interp_x + orig_interp_y = interp_y + can_bilinear = subim_src.shape[0] >= 2 and subim_src.shape[1] >= 2 + if can_bilinear: + interp_x = min(max(interp_x, 0.0), subim_src.shape[0] - 1.0) + interp_y = min(max(interp_y, 0.0), subim_src.shape[1] - 1.0) + x1 = int(N.floor(interp_x)) + y1 = int(N.floor(interp_y)) + x1 = min(max(x1, 0), subim_src.shape[0] - 2) + y1 = min(max(y1, 0), subim_src.shape[1] - 2) + xind = slice(x1, x1+2, 1) + yind = slice(y1, y1+2, 1) + mask_patch = mask[xind, yind] + t = interp_x - x1 + u = interp_y - y1 + else: + x1 = min(max(int(N.round(interp_x)), 0), subim_src.shape[0] - 1) + y1 = min(max(int(N.round(interp_y)), 0), subim_src.shape[1] - 1) + xind = slice(x1, x1+1, 1) + yind = slice(y1, y1+1, 1) + mask_patch = mask[xind, yind] + t = 0.0 + u = 0.0 + + if (orig_interp_x != interp_x) or (orig_interp_y != interp_y): + mylog.debug('Clipped source centroid for interpolation from ' + + repr((orig_interp_x, orig_interp_y)) + ' to ' + + repr((interp_x, interp_y)) + ' in island ' + + str(isl.island_id)) + + patch_is_full = mask_patch.shape == (2, 2) and N.all(mask_patch == isrc) + if img.opts.flag_smallsrc and (not patch_is_full): mylog.debug('Island = '+str(isl.island_id)) - mylog.debug('Mask = '+repr(mask[xind, yind])+'xind, yind, x1, y1 = '+repr(xind)+' '+repr(yind)+' '+repr(x1)+' '+repr(y1)) - 
def _estimate_island_dispatch_cost(self, isl, img, opts, peak_size, maxsize):
    """Estimate a stable, cheap dispatch weight for one island.

    Keep this model intentionally conservative: prefer low-cost,
    cross-dataset features that are already available before fitting.

    Parameters
    ----------
    isl : object
        Island to weigh. ``size_active``, ``shape``, ``rms``, ``islmean``
        and ``max_value`` are read with ``getattr``; an attribute that is
        missing or ``None`` is replaced by a neutral default instead of
        raising.
    img : object
        Only ``img.pixel_beamarea()`` is used.
    opts : object
        Only the ``peak_fit`` and ``split_isl`` flags are used.
    peak_size : float
        Island size (``size_active / beamarea * 2``) above which the
        island is treated as a likely iterative-fit case when
        ``opts.peak_fit`` is set.
    maxsize : float
        Same size measure above which the island is treated as a likely
        split case when ``opts.split_isl`` is set.

    Returns
    -------
    int
        Relative scheduling weight, never smaller than 1.
    """
    # Tuning constants. They only have to rank islands relative to each
    # other; they are not a run-time prediction.
    sparse_bbox_penalty = 0.15
    iterative_factor = 4.0
    split_factor = 3.0
    very_large_factor = 1.5
    mild_penalty = 1.1
    aspect_ratio_limit = 4.0
    snr_limit = 50.0

    # ``or 1`` maps both None and 0 to the one-pixel minimum, mirroring the
    # ``or 0.0`` handling of the statistics below.
    size_active = float(max(1, getattr(isl, 'size_active', None) or 1))
    beamarea = img.pixel_beamarea()
    size_beams = size_active/beamarea*2.0 if beamarea else size_active

    # Compare against None explicitly: array-like shapes must not be
    # truth-tested.
    raw_shape = getattr(isl, 'shape', None)
    shape = tuple(raw_shape) if raw_shape is not None else (0, 0)
    if len(shape) < 2:
        shape = (0, 0)
    if shape[0] > 0 and shape[1] > 0:
        bbox_area = int(shape[0] * shape[1])
    else:
        # No usable bounding box: assume it is exactly filled.
        bbox_area = int(size_active)
    shortest = min(shape)
    longest = max(shape)
    short_axis = float(shortest) if shortest > 0 else 1.0
    aspect_ratio = float(longest)/short_axis if longest > 0 else 1.0

    isl_rms = float(getattr(isl, 'rms', 0.0) or 0.0)
    isl_mean = float(getattr(isl, 'islmean', 0.0) or 0.0)
    max_value = float(getattr(isl, 'max_value', 0.0) or 0.0)
    if isl_rms <= 0.0:
        snr_proxy = 0.0
    else:
        snr_proxy = max(0.0, (max_value - isl_mean)/isl_rms)

    predict_iterative = bool(opts.peak_fit and size_beams > peak_size)
    predict_split = bool(opts.split_isl and size_beams > maxsize)

    weight = size_active

    # Sparse islands with a large bounding box tend to be costlier than
    # compact ones with the same number of active pixels.
    weight += max(0.0, bbox_area - size_active) * sparse_bbox_penalty

    # Large-island branches dominate the long tail, so bias the scheduler
    # toward spreading them out without trying to model the exact cost.
    if predict_iterative:
        weight *= iterative_factor
    if predict_split:
        weight *= split_factor
    if predict_iterative and size_beams > 4.0*peak_size:
        weight *= very_large_factor
    if predict_split and size_beams > 2.0*maxsize:
        weight *= very_large_factor

    # Mild shape/signal penalties help separate obviously pathological
    # islands without overfitting the model to a single dataset.
    if aspect_ratio > aspect_ratio_limit:
        weight *= mild_penalty
    if snr_proxy > snr_limit:
        weight *= mild_penalty

    return int(max(1.0, round(weight)))
@classmethod
def _initialize_definitions(cls):
    """Create the output-column definitions on the class, once.

    Every ``<name>_def`` attribute is built from ``Int``, ``String``,
    ``Float`` or ``List(Float(), ...)`` and stored on ``cls``. The
    ``_definitions_initialized`` flag is set last, so later calls return
    immediately.
    """
    if cls._definitions_initialized:
        return

    def float_list(**kwargs):
        # Every list-valued column holds Float entries.
        return List(Float(), **kwargs)

    # (attribute name, constructor, keyword arguments), in creation order.
    column_specs = [
        ('source_id_def', Int,
         dict(doc="Source index", colname='Source_id')),
        ('code_def', String,
         dict(doc='Source code S, C, or M', colname='S_Code')),
        ('gaus_num_def', Int,
         dict(doc="Serial number of the gaussian for the image",
              colname='Gaus_id')),
        ('island_id_def', Int,
         dict(doc="Serial number of the island", colname='Isl_id')),
        ('flag_def', Int,
         dict(doc="Flag associated with gaussian", colname='Flag')),
        ('total_flux_def', Float,
         dict(doc="Total flux density, Jy", colname='Total_flux',
              units='Jy')),
        ('total_fluxE_def', Float,
         dict(doc="Total flux density error, Jy", colname='E_Total_flux',
              units='Jy')),
        ('peak_flux_def', Float,
         dict(doc="Peak flux density/beam, Jy/beam", colname='Peak_flux',
              units='Jy/beam')),
        ('peak_fluxE_def', Float,
         dict(doc="Peak flux density/beam error, Jy/beam",
              colname='E_Peak_flux', units='Jy/beam')),
        ('centre_sky_def', float_list,
         dict(doc="Sky coordinates of gaussian centre",
              colname=['RA', 'DEC'], units=['deg', 'deg'])),
        ('centre_skyE_def', float_list,
         dict(doc="Error on sky coordinates of gaussian centre",
              colname=['E_RA', 'E_DEC'], units=['deg', 'deg'])),
        ('centre_pix_def', float_list,
         dict(doc="Pixel coordinates of gaussian centre",
              colname=['Xposn', 'Yposn'], units=['pix', 'pix'])),
        ('centre_pixE_def', float_list,
         dict(doc="Error on pixel coordinates of gaussian centre",
              colname=['E_Xposn', 'E_Yposn'], units=['pix', 'pix'])),
        ('size_sky_def', float_list,
         dict(doc="Shape of the gaussian FWHM, PA, deg",
              colname=['Maj', 'Min', 'PA'],
              units=['deg', 'deg', 'deg'])),
        ('size_skyE_def', float_list,
         dict(doc="Error on shape of the gaussian FWHM, PA, deg",
              colname=['E_Maj', 'E_Min', 'E_PA'],
              units=['deg', 'deg', 'deg'])),
        ('deconv_size_sky_def', float_list,
         dict(doc="Deconvolved shape of the gaussian FWHM, PA, deg",
              colname=['DC_Maj', 'DC_Min', 'DC_PA'],
              units=['deg', 'deg', 'deg'])),
        ('deconv_size_skyE_def', float_list,
         dict(doc="Error on deconvolved shape of the gaussian FWHM, PA, deg",
              colname=['E_DC_Maj', 'E_DC_Min', 'E_DC_PA'],
              units=['deg', 'deg', 'deg'])),
        ('size_sky_uncorr_def', float_list,
         dict(doc="Shape in image plane of the gaussian FWHM, PA, deg",
              colname=['Maj_img_plane', 'Min_img_plane', 'PA_img_plane'],
              units=['deg', 'deg', 'deg'])),
        ('size_skyE_uncorr_def', float_list,
         dict(doc="Error on shape in image plane of the gaussian FWHM, PA, deg",
              colname=['E_Maj_img_plane', 'E_Min_img_plane',
                       'E_PA_img_plane'],
              units=['deg', 'deg', 'deg'])),
        ('deconv_size_sky_uncorr_def', float_list,
         dict(doc="Deconvolved shape in image plane of the gaussian FWHM, PA, deg",
              colname=['DC_Maj_img_plane', 'DC_Min_img_plane',
                       'DC_PA_img_plane'],
              units=['deg', 'deg', 'deg'])),
        ('deconv_size_skyE_uncorr_def', float_list,
         dict(doc="Error on deconvolved shape in image plane of the gaussian FWHM, PA, deg",
              colname=['E_DC_Maj_img_plane', 'E_DC_Min_img_plane',
                       'E_DC_PA_img_plane'],
              units=['deg', 'deg', 'deg'])),
        ('rms_def', Float,
         dict(doc="Island rms, Jy/beam", colname='Isl_rms',
              units='Jy/beam')),
        ('mean_def', Float,
         dict(doc="Island mean, Jy/beam", colname='Isl_mean',
              units='Jy/beam')),
        ('total_flux_isl_def', Float,
         dict(doc="Island total flux from sum of pixels",
              colname='Isl_Total_flux', units='Jy')),
        ('total_flux_islE_def', Float,
         dict(doc="Error on island total flux from sum of pixels",
              colname='E_Isl_Total_flux', units='Jy')),
        ('gresid_rms_def', Float,
         dict(doc="Island rms in Gaussian residual image",
              colname='Resid_Isl_rms', units='Jy/beam')),
        ('gresid_mean_def', Float,
         dict(doc="Island mean in Gaussian residual image",
              colname='Resid_Isl_mean', units='Jy/beam')),
        ('sresid_rms_def', Float,
         dict(doc="Island rms in Shapelet residual image",
              colname='Resid_Isl_rms', units='Jy/beam')),
        ('sresid_mean_def', Float,
         dict(doc="Island mean in Shapelet residual image",
              colname='Resid_Isl_mean', units='Jy/beam')),
        ('wave_rms_def', Float,
         dict(doc="Island rms in wavelet image, Jy/beam",
              colname='Wave_Isl_rms', units='Jy/beam')),
        ('wave_mean_def', Float,
         dict(doc="Island mean in wavelet image, Jy/beam",
              colname='Wave_Isl_mean', units='Jy/beam')),
        ('jlevel_def', Int,
         dict(doc="Wavelet number to which Gaussian belongs",
              colname='Wave_id')),
        ('spec_indx_def', Float,
         dict(doc="Spectral index", colname='Spec_Indx', units=None)),
        ('e_spec_indx_def', Float,
         dict(doc="Error in spectral index", colname='E_Spec_Indx',
              units=None)),
        ('specin_flux_def', float_list,
         dict(doc="Total flux density per channel, Jy",
              colname=['Total_flux'], units=['Jy'])),
        ('specin_fluxE_def', float_list,
         dict(doc="Error in total flux density per channel, Jy",
              colname=['E_Total_flux'], units=['Jy'])),
        ('specin_freq_def', float_list,
         dict(doc="Frequency per channel, Hz", colname=['Freq'],
              units=['Hz'])),
    ]

    # Build and attach one definition at a time, in table order.
    for attr_name, make_definition, kwargs in column_specs:
        setattr(cls, attr_name, make_definition(**kwargs))

    cls._definitions_initialized = True
import functions as func import numpy as N - # Add attribute definitions needed for output - self.source_id_def = Int(doc="Source index", colname='Source_id') - self.code_def = String(doc='Source code S, C, or M', colname='S_Code') - self.gaus_num_def = Int(doc="Serial number of the gaussian for the image", colname='Gaus_id') - self.island_id_def = Int(doc="Serial number of the island", colname='Isl_id') - self.flag_def = Int(doc="Flag associated with gaussian", colname='Flag') - self.total_flux_def = Float(doc="Total flux density, Jy", colname='Total_flux', units='Jy') - self.total_fluxE_def = Float(doc="Total flux density error, Jy", colname='E_Total_flux', - units='Jy') - self.peak_flux_def = Float(doc="Peak flux density/beam, Jy/beam", colname='Peak_flux', - units='Jy/beam') - self.peak_fluxE_def = Float(doc="Peak flux density/beam error, Jy/beam", - colname='E_Peak_flux', units='Jy/beam') - self.centre_sky_def = List(Float(), doc="Sky coordinates of gaussian centre", - colname=['RA', 'DEC'], units=['deg', 'deg']) - self.centre_skyE_def = List(Float(), doc="Error on sky coordinates of gaussian centre", - colname=['E_RA', 'E_DEC'], units=['deg', 'deg']) - self.centre_pix_def = List(Float(), doc="Pixel coordinates of gaussian centre", - colname=['Xposn', 'Yposn'], units=['pix', 'pix']) - self.centre_pixE_def = List(Float(), doc="Error on pixel coordinates of gaussian centre", - colname=['E_Xposn', 'E_Yposn'], units=['pix', 'pix']) - self.size_sky_def = List(Float(), doc="Shape of the gaussian FWHM, PA, deg", - colname=['Maj', 'Min', 'PA'], units=['deg', 'deg', 'deg']) - self.size_skyE_def = List(Float(), doc="Error on shape of the gaussian FWHM, PA, deg", - colname=['E_Maj', 'E_Min', 'E_PA'], units=['deg', 'deg', 'deg']) - self.deconv_size_sky_def = List(Float(), doc="Deconvolved shape of the gaussian FWHM, PA, deg", - colname=['DC_Maj', 'DC_Min', 'DC_PA'], units=['deg', 'deg', 'deg']) - self.deconv_size_skyE_def = List(Float(), doc="Error on deconvolved shape of 
the gaussian FWHM, PA, deg", - colname=['E_DC_Maj', 'E_DC_Min', 'E_DC_PA'], units=['deg', 'deg', 'deg']) - self.size_sky_uncorr_def = List(Float(), doc="Shape in image plane of the gaussian FWHM, PA, deg", - colname=['Maj_img_plane', 'Min_img_plane', 'PA_img_plane'], - units=['deg', 'deg', 'deg']) - self.size_skyE_uncorr_def = List(Float(), doc="Error on shape in image plane of the gaussian FWHM, PA, deg", - colname=['E_Maj_img_plane', 'E_Min_img_plane', 'E_PA_img_plane'], - units=['deg', 'deg', 'deg']) - self.deconv_size_sky_uncorr_def = List(Float(), doc="Deconvolved shape in image plane of the gaussian FWHM, PA, deg", - colname=['DC_Maj_img_plane', 'DC_Min_img_plane', 'DC_PA_img_plane'], - units=['deg', 'deg', 'deg']) - self.deconv_size_skyE_uncorr_def = List(Float(), doc="Error on deconvolved shape in image plane of the gaussian FWHM, PA, deg", - colname=['E_DC_Maj_img_plane', 'E_DC_Min_img_plane', 'E_DC_PA_img_plane'], - units=['deg', 'deg', 'deg']) - self.rms_def = Float(doc="Island rms, Jy/beam", colname='Isl_rms', units='Jy/beam') - self.mean_def = Float(doc="Island mean, Jy/beam", colname='Isl_mean', units='Jy/beam') - self.total_flux_isl_def = Float(doc="Island total flux from sum of pixels", colname='Isl_Total_flux', units='Jy') - self.total_flux_islE_def = Float(doc="Error on island total flux from sum of pixels", colname='E_Isl_Total_flux', units='Jy') - self.gresid_rms_def = Float(doc="Island rms in Gaussian residual image", colname='Resid_Isl_rms', units='Jy/beam') - self.gresid_mean_def = Float(doc="Island mean in Gaussian residual image", colname='Resid_Isl_mean', units='Jy/beam') - self.sresid_rms_def = Float(doc="Island rms in Shapelet residual image", colname='Resid_Isl_rms', units='Jy/beam') - self.sresid_mean_def = Float(doc="Island mean in Shapelet residual image", colname='Resid_Isl_mean', units='Jy/beam') - self.wave_rms_def = Float(doc="Island rms in wavelet image, Jy/beam", colname='Wave_Isl_rms', units='Jy/beam') - self.wave_mean_def = 
Float(doc="Island mean in wavelet image, Jy/beam", colname='Wave_Isl_mean', units='Jy/beam') - self.jlevel_def = Int(doc="Wavelet number to which Gaussian belongs", colname='Wave_id') - self.spec_indx_def = Float(doc="Spectral index", colname='Spec_Indx', units=None) - self.e_spec_indx_def = Float(doc="Error in spectral index", colname='E_Spec_Indx', units=None) - self.specin_flux_def = List(Float(), doc="Total flux density per channel, Jy", colname=['Total_flux'], units=['Jy']) - self.specin_fluxE_def = List(Float(), doc="Error in total flux density per channel, Jy", colname=['E_Total_flux'], units=['Jy']) - self.specin_freq_def = List(Float(), doc="Frequency per channel, Hz", colname=['Freq'], units=['Hz']) + cls = type(self) + if not cls._definitions_initialized: + cls._initialize_definitions() use_wcs = True self.gaussian_idx = g_idx @@ -1022,6 +1085,7 @@ def __init__(self, img, gaussian, isl_idx, g_idx, flg=0): self.parameters = gaussian p = gaussian + isl = img.islands[isl_idx] self.peak_flux = p[0] self.centre_pix = p[1:3] size = p[3:6] @@ -1050,7 +1114,7 @@ def __init__(self, img, gaussian, isl_idx, g_idx, flg=0): tot = p[0]*size[0]*size[1]/(bm_pix[0]*bm_pix[1]) if flg == 0: # These are good Gaussians - errors = func.get_errors(img, p+[tot], img.islands[isl_idx].rms, fixed_to_beam=img.opts.fix_to_beam) + errors = func.get_errors(img, p+[tot], isl.rms, fixed_to_beam=img.opts.fix_to_beam) self.centre_sky = img.pix2sky(p[1:3]) self.centre_skyE = img.pix2coord(errors[1:3], self.centre_pix, use_wcs=use_wcs) self.size_sky = img.pix2gaus(size, self.centre_pix, use_wcs=use_wcs) # FWHM in degrees and P.A. 
east from north @@ -1081,9 +1145,9 @@ def __init__(self, img, gaussian, isl_idx, g_idx, flg=0): self.total_fluxE = errors[6] self.centre_pixE = errors[1:3] self.size_pixE = errors[3:6] - self.rms = img.islands[isl_idx].rms - self.mean = img.islands[isl_idx].mean + self.rms = isl.rms + self.mean = isl.mean self.wave_rms = 0.0 # set if needed in the wavelet operation self.wave_mean = 0.0 # set if needed in the wavelet operation - self.total_flux_isl = img.islands[isl_idx].total_flux - self.total_flux_islE = img.islands[isl_idx].total_fluxE + self.total_flux_isl = isl.total_flux + self.total_flux_islE = isl.total_fluxE diff --git a/bdsf/multi_proc.py b/bdsf/multi_proc.py index 9d3c010..4e56835 100644 --- a/bdsf/multi_proc.py +++ b/bdsf/multi_proc.py @@ -9,6 +9,7 @@ """ +import heapq import multiprocessing import os import sys @@ -30,7 +31,7 @@ __all__ = ('parallel_map',) -def worker(f, ii, chunk, out_q, err_q, lock, bar, bar_state): +def worker(f, ii, chunk, out_q, err_q, lock, bar, bar_state, preserve_order=False): """ A worker function that maps an input function over a slice of the input iterable. 
@@ -44,11 +45,15 @@ def worker(f, ii, chunk, out_q, err_q, lock, bar, bar_state): ( useful in extending parallel_map() ) :param bar: statusbar to update during fit :param bar_state: statusbar state dictionary + :param preserve_order: whether chunk entries carry their original index """ vals = [] - # iterate over slice - for val in chunk: + for entry in chunk: + if preserve_order: + val_idx, val = entry + else: + val = entry try: result = f(val) except Exception as e: @@ -61,9 +66,11 @@ def worker(f, ii, chunk, out_q, err_q, lock, bar, bar_state): err_q.put(e) return - vals.append(result) + if preserve_order: + vals.append((val_idx, result)) + else: + vals.append(result) - # update statusbar if bar is not None: if bar_state['started']: bar.pos = bar_state['pos'] @@ -76,11 +83,10 @@ def worker(f, ii, chunk, out_q, err_q, lock, bar, bar_state): if bar_state['spin_pos'] >= 4: bar_state['spin_pos'] = 0 - # output the result and task ID to output queue out_q.put((ii, vals)) -def run_tasks(procs, err_q, out_q, num): +def run_tasks(procs, err_q, out_q, num, preserve_order=False, total_items=None): """ A function that executes populated processes and processes the resultant array. Checks error queue for any exceptions. @@ -89,9 +95,10 @@ def run_tasks(procs, err_q, out_q, num): :param out_q: thread-safe output queue :param err_q: thread-safe queue to populate on exception :param num : length of resultant array + :param preserve_order: whether worker outputs carry original item indices + :param total_items: total number of items to reconstruct """ - # function to terminate processes that are still running. die = (lambda vals: [val.terminate() for val in vals if val.exitcode is None]) @@ -107,23 +114,26 @@ def run_tasks(procs, err_q, out_q, num): ) except Exception as e: - # kill all slave processes on ctrl-C die(procs) raise e if not err_q.empty(): - # kill all on any exception from any one slave die(procs) raise err_q.get() - # Processes finish in arbitrary order. 
Process IDs double - # as index in the resultant array. + if preserve_order: + results = [None] * total_items + for i in range(num): + idx, result = out_q.get() + for item_idx, item_result in result: + results[item_idx] = item_result + return results + results = [None] * num for i in range(num): idx, result = out_q.get() results[idx] = result - # Remove extra dimension added by array_split result_list = [] for result in results: result_list += result @@ -154,7 +164,8 @@ def parallel_map(function, sequence, numcores=None, bar=None, weights=None): raise TypeError("input '%s' is not iterable" % repr(sequence)) - sequence = numpy.array(list(sequence), dtype=object) + sequence_list = list(sequence) + sequence = numpy.array(sequence_list, dtype=object) size = len(sequence) if size == 1: @@ -163,7 +174,6 @@ def parallel_map(function, sequence, numcores=None, bar=None, weights=None): bar.stop() return results - # Set number of cores to use. Try to leave one core free for pyplot. if numcores is None: numcores = _ncpus - 1 if numcores > _ncpus - 1: @@ -171,15 +181,8 @@ def parallel_map(function, sequence, numcores=None, bar=None, weights=None): if numcores < 1: numcores = 1 - # Returns a started SyncManager object which can be used for sharing - # objects between processes. The returned manager object corresponds - # to a spawned child process and has methods which will create shared - # objects and return corresponding proxies. manager = multiprocessing.Manager() - # Create FIFO queue and lock shared objects and return proxies to them. - # The managers handles a server process that manages shared objects that - # each slave process has access to. Bottom line -- thread-safe. 
out_q = manager.Queue() err_q = manager.Queue() lock = manager.Lock() @@ -189,40 +192,42 @@ def parallel_map(function, sequence, numcores=None, bar=None, weights=None): bar_state['spin_pos'] = bar.spin_pos bar_state['started'] = bar.started - # if sequence is less than numcores, only use len sequence number of - # processes if size < numcores: numcores = size - # group sequence into numcores-worth of chunks + preserve_order = False if weights is None or numcores == size: - # No grouping specified (or there are as many cores as - # processes), so divide into equal chunks sequence = numpy.array_split(sequence, numcores) else: - # Group so that each group has roughly an equal sum of weights - weight_per_core = numpy.sum(weights)/float(numcores) - cut_values = [] - temp_sum = 0.0 - for indx, weight in enumerate(weights): - temp_sum += weight - if temp_sum > weight_per_core: - cut_values.append(indx+1) - temp_sum = weight - if len(cut_values) > numcores - 1: - cut_values = cut_values[0:numcores-1] - sequence = numpy.array_split(sequence, cut_values) - - # Make sure there are no empty chunks at the end of the sequence + preserve_order = True + weight_array = numpy.asarray(weights, dtype=numpy.float64) + indexed_sequence = list(enumerate(sequence_list)) + bins = [[] for _ in range(numcores)] + heap = [(0.0, idx) for idx in range(numcores)] + heapq.heapify(heap) + + weighted_items = zip(indexed_sequence, weight_array.tolist()) + for (orig_idx, item), weight in sorted(weighted_items, + key=lambda pair: pair[1], + reverse=True): + current_weight, worker_idx = heapq.heappop(heap) + bins[worker_idx].append((orig_idx, item)) + current_weight += float(weight) + heapq.heappush(heap, (current_weight, worker_idx)) + + sequence = bins + while len(sequence[-1]) == 0: sequence.pop() procs = [multiprocessing.Process(target=worker, - args=(function, ii, chunk, out_q, err_q, lock, bar, bar_state)) + args=(function, ii, chunk, out_q, err_q, lock, bar, bar_state, + preserve_order)) for ii, 
chunk in enumerate(sequence)] try: - results = run_tasks(procs, err_q, out_q, len(sequence)) + results = run_tasks(procs, err_q, out_q, len(sequence), + preserve_order=preserve_order, total_items=size) if bar is not None: if bar.started: bar.stop() diff --git a/bdsf/output.py b/bdsf/output.py index b301451..8dfbc9c 100644 --- a/bdsf/output.py +++ b/bdsf/output.py @@ -1090,14 +1090,38 @@ def make_output_columns(obj, fits=False, objtype='gaul', incl_spin=False, cunits = [] cvals = [] skip_next = False + + def get_definition(name): + def_name = name + '_def' + + # Most objects still store column definitions directly on the + # instance, but Gaussian definitions now live on the class as TC + # descriptors. Looking in __dict__ first avoids triggering the TC + # descriptor and getting its default value instead of the + # definition object. + definition = obj.__dict__.get(def_name) + if hasattr(definition, '_colname'): + return definition + + definition = getattr(type(obj), def_name, None) + if hasattr(definition, '_colname'): + return definition + + definition = getattr(obj, def_name) + if hasattr(definition, '_colname'): + return definition + + raise AttributeError("No output definition found for '%s'" % def_name) + for n, name in enumerate(names): if hasattr(obj, name): if name in ['specin_flux', 'specin_fluxE', 'specin_freq']: # As these are variable length lists, they must # (unfortunately) be treated differently. 
val = obj.__getattribute__(name) - colname = obj.__dict__[name+'_def']._colname - units = obj.__dict__[name+'_def']._units + definition = get_definition(name) + colname = definition._colname + units = definition._units for i in range(nchan): if i < len(val): cvals.append(val[i]) @@ -1110,8 +1134,9 @@ def make_output_columns(obj, fits=False, objtype='gaul', incl_spin=False, else: if not skip_next: val = obj.__getattribute__(name) - colname = obj.__dict__[name+'_def']._colname - units = obj.__dict__[name+'_def']._units + definition = get_definition(name) + colname = definition._colname + units = definition._units if units is None: units = ' ' if isinstance(val, list) or isinstance(val, tuple): @@ -1120,8 +1145,9 @@ def make_output_columns(obj, fits=False, objtype='gaul', incl_spin=False, # in the order (val, error). next_name = names[n+1] val_next = obj.__getattribute__(next_name) - colname_next = obj.__dict__[next_name+'_def']._colname - units_next = obj.__dict__[next_name+'_def']._units + next_definition = get_definition(next_name) + colname_next = next_definition._colname + units_next = next_definition._units if units_next is None: units_next = ' ' for i in range(len(val)):