diff --git a/python/lsst/pipe/tasks/desTemplateTask.py b/python/lsst/pipe/tasks/desTemplateTask.py new file mode 100644 index 000000000..d1fb80cdd --- /dev/null +++ b/python/lsst/pipe/tasks/desTemplateTask.py @@ -0,0 +1,711 @@ +# This file is part of pipe_tasks. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +"""Pipeline for generating templates from archival DES data. +""" + +__all__ = ["DesTemplateTask", + "DesTemplateConfig", + "DesTemplateConnections"] + +import numpy as np + +from astropy.table import Table +from astropy.io import fits + +import lsst.pex.config as pexConfig +import lsst.pipe.base as pipeBase +import lsst.pipe.base.connectionTypes as connectionTypes +import lsst.afw.image as afwImage +import lsst.afw.table as afwTable +import lsst.afw.math as afwMath +import lsst.afw.geom as afwGeom +from lsst.geom import Box2D, Point2D +from lsst.sphgeom import ConvexPolygon, UnitVector3d, LonLat +from lsst.afw.image import makePhotoCalibFromCalibZeroPoint +from lsst.meas.algorithms.installGaussianPsf import ( + InstallGaussianPsfTask, + InstallGaussianPsfConfig, +) +from lsst.meas.algorithms import CoaddPsf, CoaddPsfConfig + +from lsst.meas.base import IdGenerator +from lsst.geom import SpherePoint, degrees +import lsst.log + + +class DesTemplateConnections( + pipeBase.PipelineTaskConnections, + dimensions=("instrument", "visit", "detector"), +): + """Connections for DES template creation task.""" + + bbox = connectionTypes.Input( + doc="Bounding box of the exposure", + name="pvi.bbox", + storageClass="Box2I", + dimensions=("instrument", "visit", "detector"), + ) + wcs = connectionTypes.Input( + doc="WCS of the exposure", + name="pvi.wcs", + storageClass="Wcs", + dimensions=("instrument", "visit", "detector"), + ) + template = connectionTypes.Output( + doc="Template exposure created from DES tiles", + name="desTemplate", + storageClass="ExposureF", + dimensions=("instrument", "visit", "detector"), + ) + + +class DesTemplateConfig( + pipeBase.PipelineTaskConfig, pipelineConnections=DesTemplateConnections +): + """Configuration for DES template creation task.""" + + tileFile = pexConfig.Field( + dtype=str, + default="merged_tiles_v3_deduplicated.csv", + doc="Path to merged CSV file containing tile information with columns: " + "tilename, band, survey, ra_cent, dec_cent, rac1-4, decc1-4, filepath", + ) + dataSourcePriority = pexConfig.ListField( + dtype=str, + default=["DELVE", "DES", "DECADE"], + doc="Priority order for surveys when duplicate tiles exist (first = highest priority)", + ) + modelPsf = pexConfig.Field( + dtype=bool, + default=False, + doc="Use modeled PSF instead of Gaussian PSF for templates", + ) + gaussianPsfFwhm = pexConfig.Field( + dtype=float, + default=4.0, + doc="FWHM for Gaussian PSF in pixels (used when modelPsf=False)", + ) + gaussianPsfWidth = pexConfig.Field( + dtype=int, + default=21, + doc="Width for Gaussian PSF in pixels (used when modelPsf=False)", + ) + hasInverseVariancePlane = pexConfig.Field( + dtype=bool, + default=True, + doc="Do the template files have a weight map or inverse variance plane", + ) + + +class DesTemplateTask(pipeBase.PipelineTask): + """Create template exposures from survey tile data.""" + + ConfigClass = DesTemplateConfig + _DefaultName = "desTemplate" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # Suppress FITS reader warnings for survey tiles + lsst.log.setLevel("lsst.afw.image.MaskedImageFitsReader", lsst.log.ERROR) + + if self.config.modelPsf: + from lsst.pipe.tasks.calibrateImage import ( + CalibrateImageTask, + CalibrateImageConfig, + ) + + cal_config = CalibrateImageConfig() + cal_config.psf_repair.doInterpolate = False + cal_config.psf_repair.doCosmicRay = False + cal_config.psf_detection.thresholdValue = 100.0 + self.calibrateTask = CalibrateImageTask(config=cal_config) + else: + install_psf_config = InstallGaussianPsfConfig() + install_psf_config.fwhm = self.config.gaussianPsfFwhm + install_psf_config.width = self.config.gaussianPsfWidth + self.installPsfTask = InstallGaussianPsfTask(config=install_psf_config) + + def runQuantum(self, butlerQC, inputRefs, outputRefs): + # Extract band and physical_filter from dataId + band = butlerQC.quantum.dataId["band"] + physical_filter = butlerQC.quantum.dataId["physical_filter"] + + inputs = butlerQC.get(inputRefs) + outputs = self.run(band=band, physical_filter=physical_filter, **inputs) + butlerQC.put(outputs, outputRefs) + + def run(self, bbox, wcs, band, physical_filter): + """Create template from survey tiles overlapping the target region. + + Parameters + ---------- + bbox : `lsst.geom.Box2I` + Bounding box of the target region + wcs : `lsst.afw.geom.SkyWcs` + WCS of the target region + band : `str` + Photometric band (e.g., 'r', 'g', 'i') + physical_filter : `str` + Physical filter name (e.g., 'r_03', 'g_01') + + Returns + ------- + result : `lsst.pipe.base.Struct` + Results struct with: + - ``template`` : Template exposure created from survey tiles + """ + self.log.info("Creating survey template for target region") + self.log.info(f"Using band '{band}' and physical_filter '{physical_filter}'") + + # Load tile catalog and filter by band + tile_table = Table.read(self.config.tileFile, format="ascii") + self.log.info(f"Loaded {len(tile_table)} total tiles from catalog") + + # Validate required columns exist + required_columns = [ + "tilename", + "band", + "survey", + "ra_cent", + "dec_cent", + "rac1", + "rac2", + "rac3", + "rac4", + "decc1", + "decc2", + "decc3", + "decc4", + "filepath", + ] + missing_columns = [ + col for col in required_columns if col not in tile_table.colnames + ] + if missing_columns: + raise RuntimeError( + f"Required columns missing from tile table: {missing_columns}. " + "Please use the merged tiles CSV format." + ) + + # Filter by band first + band_filtered_tiles = self._filterTilesByBand(tile_table, band) + self.log.info(f"Found {len(band_filtered_tiles)} tiles for band '{band}'") + + if len(band_filtered_tiles) == 0: + raise RuntimeError(f"No tiles found for band '{band}' in catalog") + + # Find overlapping tiles + overlapping_tiles = self._findOverlappingTiles(bbox, wcs, band_filtered_tiles) + self.log.info(f"Found {len(overlapping_tiles)} overlapping tiles") + + if not overlapping_tiles: + center = wcs.pixelToSky(Point2D(bbox.getCenter())) + target_ra = center.getRa().asDegrees() + target_dec = center.getDec().asDegrees() + self.log.info( + f"Target was centered at RA={target_ra:.3f}, Dec={target_dec:.3f}" + ) + raise RuntimeError("No survey tiles found overlapping with target region") + + # Resolve duplicate tiles and load template exposures + resolved_tiles = self._resolveDuplicateTiles( + overlapping_tiles, band_filtered_tiles + ) + tile_exposures, tile_names = self._loadTileTemplates(resolved_tiles) + self.log.info(f"Successfully loaded {len(tile_exposures)} template exposures") + + if not tile_exposures: + raise RuntimeError("No valid template exposures found") + + # Create coadded template + template = self._createCoaddFromTiles( + tile_exposures, tile_names, wcs, bbox, physical_filter + ) + + self.log.info("Successfully created survey template") + return pipeBase.Struct(template=template) + + def _tileOverlapsRegion(self, tile_corners, bbox, wcs): + """Check if a tile overlaps with the target region using spherical geometry. + + Parameters + ---------- + tile_corners : `list` of `tuple` + List of (ra, dec) tuples for tile corners in degrees + bbox : `lsst.geom.Box2I` + Bounding box of target region + wcs : `lsst.afw.geom.SkyWcs` + WCS of target region + + Returns + ------- + overlaps : `bool` + True if tile overlaps with target region, False otherwise + """ + # Get target region corners + region_corners = [wcs.pixelToSky(Point2D(p)) for p in bbox.getCorners()] + + # Convert tile corners to unit vectors + tile_unit_vectors = [ + UnitVector3d(LonLat.fromDegrees(ra, dec)) for ra, dec in tile_corners + ] + + # Convert region corners to unit vectors + region_unit_vectors = [ + UnitVector3d( + LonLat.fromDegrees(s.getRa().asDegrees(), s.getDec().asDegrees()) + ) + for s in region_corners + ] + + # Create convex polygons + tile_polygon = ConvexPolygon(tile_unit_vectors) + region_polygon = ConvexPolygon(region_unit_vectors) + + # Check for overlap + return tile_polygon.overlaps(region_polygon) + + def _findOverlappingTiles(self, bbox, wcs, tile_table): + """Find tiles that overlap with the target region. + + Parameters + ---------- + bbox : `lsst.geom.Box2I` + Bounding box of target region + wcs : `lsst.afw.geom.SkyWcs` + WCS of target region + tile_table : `astropy.table.Table` + Table containing tile information + + Returns + ------- + overlapping_tiles : `list` of `tuple` + List of (tile_name, distance_deg) for overlapping tiles, sorted by distance + """ + # Get target region center for distance calculation + center = wcs.pixelToSky(Point2D(bbox.getCenter())) + target_ra = center.getRa().asDegrees() + target_dec = center.getDec().asDegrees() + p0 = SpherePoint(target_ra * degrees, target_dec * degrees) + overlapping_tiles = [] + total_tiles_checked = 0 + + self.log.info(f"Checking {len(tile_table)} tiles for overlap with exposure") + + for row in tile_table: + tile_name = row["tilename"] + total_tiles_checked += 1 + + # Extract tile corners from table + required_corner_columns = [ + "rac1", + "decc1", + "rac2", + "decc2", + "rac3", + "decc3", + "rac4", + "decc4", + ] + + if not all(col in tile_table.colnames for col in required_corner_columns): + missing_corners = [ + col + for col in required_corner_columns + if col not in tile_table.colnames + ] + raise RuntimeError( + f"Required corner coordinate columns missing from tile table: {missing_corners}" + ) + + # Use actual corner coordinates + tile_corners = [ + (row["rac1"], row["decc1"]), + (row["rac2"], row["decc2"]), + (row["rac3"], row["decc3"]), + (row["rac4"], row["decc4"]), + ] + + # Check if tile overlaps with target region + has_overlap = self._tileOverlapsRegion(tile_corners, bbox, wcs) + + if has_overlap: + # Calculate distance to tile center for sorting + tile_ra = row["ra_cent"] + tile_dec = row["dec_cent"] + + p1 = SpherePoint(tile_ra * degrees, tile_dec * degrees) + distance = p0.separation(p1).asDegrees() + + self.log.info(f" OVERLAP: {tile_name} - distance={distance:.3f}°") + + overlapping_tiles.append((tile_name, distance)) + else: + # Log nearby tiles that don't overlap for debugging + tile_ra = row["ra_cent"] + tile_dec = row["dec_cent"] + p1 = SpherePoint(tile_ra * degrees, tile_dec * degrees) + distance = p0.separation(p1).asDegrees() + + if distance < 2.0: # Only log tiles within 2 degrees + self.log.info( + f" No overlap: {tile_name} - distance={distance:.3f}° (RA={tile_ra:.3f}, Dec={tile_dec:.3f})" + ) + + self.log.info( + f"Checked {total_tiles_checked} tiles, found {len(overlapping_tiles)} with overlap" + ) + + # Sort by distance + overlapping_tiles.sort(key=lambda x: x[1]) + + # Return the overlapping tiles + return overlapping_tiles + + def _filterTilesByBand(self, tile_table, band): + """Filter tiles by the requested band and validate survey values. + + Parameters + ---------- + tile_table : `astropy.table.Table` + Table containing all tiles + band : `str` + Photometric band to filter for + + Returns + ------- + filtered_table : `astropy.table.Table` + Table containing only tiles for the specified band with valid surveys + """ + if "band" not in tile_table.colnames: + raise RuntimeError( + "Required 'band' column not found in tile table. " + "Please use the merged tiles CSV format." + ) + + if "survey" not in tile_table.colnames: + raise RuntimeError( + "Required 'survey' column not found in tile table. " + "Please use the merged tiles CSV format." + ) + + # First filter by band + band_mask = tile_table["band"] == band + band_filtered = tile_table[band_mask] + + # Then validate survey values + valid_surveys = set(self.config.dataSourcePriority) + survey_column = band_filtered["survey"] + + # Vectorized membership check (case-sensitive) + valid_mask = np.isin(survey_column, list(valid_surveys)) + + invalid_count = len(band_filtered) - np.sum(valid_mask) + if invalid_count > 0: + self.log.warn( + f"Found {invalid_count} tiles with unrecognized survey values, ignoring them" + ) + + filtered_table = band_filtered[valid_mask] + + self.log.info( + f"Filtered {len(tile_table)} tiles to {len(filtered_table)} for band '{band}' with valid surveys" + ) + return filtered_table + + def _resolveDuplicateTiles(self, overlapping_tiles, band_filtered_tiles): + """Resolve duplicate tiles based on survey priority. + + Parameters + ---------- + overlapping_tiles : `list` of `tuple` + List of (tile_name, distance) tuples for overlapping tiles + band_filtered_tiles : `astropy.table.Table` + Band-filtered tile table with filepath and survey columns + + Returns + ------- + resolved_tiles : `list` of `dict` + List of resolved tile dictionaries with tilename, distance, filepath, survey + """ + # Get priority order from config + priority_order = self.config.dataSourcePriority + + resolved_tiles = [] + + for tile_name, distance in overlapping_tiles: + # Find all entries for this tilename in the band-filtered table + tile_mask = band_filtered_tiles["tilename"] == tile_name + tile_entries = band_filtered_tiles[tile_mask] + + if len(tile_entries) == 0: + self.log.warn(f"No entries found for overlapping tile {tile_name}") + continue + elif len(tile_entries) == 1: + # No duplicates, use the single entry + entry = tile_entries[0] + resolved_tiles.append( + { + "tilename": tile_name, + "distance": distance, + "filepath": entry["filepath"], + "survey": entry["survey"], + } + ) + else: + # Multiple entries, resolve based on priority order + selected_entry = None + + # Check for each survey in priority order + for preferred_survey in priority_order: + survey_entries = tile_entries[ + tile_entries["survey"] == preferred_survey + ] + if len(survey_entries) > 0: + selected_entry = survey_entries[0] + self.log.info( + f"Using {preferred_survey} tile for {tile_name} (priority)" + ) + break + + if selected_entry is None: + # No recognized survey, use first available + selected_entry = tile_entries[0] + self.log.info( + f"Using {selected_entry['survey']} tile for {tile_name} (fallback)" + ) + + resolved_tiles.append( + { + "tilename": tile_name, + "distance": distance, + "filepath": selected_entry["filepath"], + "survey": selected_entry["survey"], + } + ) + + self.log.info( + f"Resolved {len(resolved_tiles)} tiles from {len(overlapping_tiles)} overlapping tiles" + ) + return resolved_tiles + + def _loadTileTemplates(self, resolved_tiles): + """Load template exposures from resolved tiles. + + Parameters + ---------- + resolved_tiles : `list` of `dict` + List of resolved tile dictionaries with tilename, distance, filepath, survey + + Returns + ------- + exposures : `list` of `lsst.afw.image.ExposureF` + List of loaded template exposures + tile_names : `list` of `str` + List of tile names corresponding to exposures + """ + exposures = [] + tile_names = [] + + for tile_info in resolved_tiles: + tile_name = tile_info["tilename"] + filepath = tile_info["filepath"] + survey = tile_info["survey"] + + try: + self.log.info(f"Loading {survey} tile {tile_name} from {filepath}") + exp = afwImage.ExposureF(filepath) + + if self.config.hasInverseVariancePlane: + # Need convert DES weight map to variance map for LSST exposure + mi = exp.maskedImage + bad_weight = mi.variance.array <= 0.0 + mi.variance.array[~bad_weight] = ( + 1.0 / mi.variance.array[~bad_weight] + ) + mi.variance.array[bad_weight] = np.nan + mi.mask.array[bad_weight] |= mi.mask.getPlaneBitMask("NO_DATA") + + # Set PSF + if not self.config.modelPsf: + # Install Gaussian PSF + self.installPsfTask.run(exp) + else: + # Use modeled PSF + cat = self.calibrateTask._compute_psf(exp, IdGenerator()) + + # Set photometric calibration + metadata = exp.getInfo().getMetadata() + if "MAGZERO" in metadata: + zp = metadata["MAGZERO"] + flux0 = 10 ** (0.4 * zp) + calib = makePhotoCalibFromCalibZeroPoint(flux0, 0.0) + exp.setPhotoCalib(calib) + else: + self.log.warn( + f"No MAGZERO found in {tile_name}, using default calibration" + ) + + exposures.append(exp) + tile_names.append(tile_name) + + except Exception as e: + self.log.warn(f"Failed to load tile {tile_name}: {e}") + continue + + return exposures, tile_names + + def _createCoaddFromTiles( + self, tile_exposures, tile_names, target_wcs, target_bbox, physical_filter + ): + """Create coadded template from tile exposures. + + Parameters + ---------- + tile_exposures : `list` of `lsst.afw.image.ExposureF` + List of tile exposures + tile_names : `list` of `str` + List of tile names + target_wcs : `lsst.afw.geom.SkyWcs` + Target WCS for template + target_bbox : `lsst.geom.Box2I` + Target bounding box for template + physical_filter : `str` + Physical filter name for the template + + Returns + ------- + template : `lsst.afw.image.ExposureF` + Coadded template exposure + """ + warper = afwMath.Warper.fromConfig(afwMath.Warper.ConfigClass()) + + # Create tile catalog for PSF computation + tile_schema = afwTable.ExposureTable.makeMinimalSchema() + tile_key = tile_schema.addField("tile", type="String", size=12) + weight_key = tile_schema.addField("weight", type=float) + + coadd_config = CoaddPsfConfig() + + # Statistics configuration + stats_flags = afwMath.stringToStatisticsProperty("MEAN") + stats_ctrl = afwMath.StatisticsControl() + stats_ctrl.setNanSafe(True) + stats_ctrl.setWeighted(True) + stats_ctrl.setCalcErrorMosaicMode(True) + + tile_catalog = afwTable.ExposureCatalog(tile_schema) + masked_images = [] + weights = [] + + for i, (exp, tile) in enumerate(zip(tile_exposures, tile_names)): + self.log.debug(f"Processing tile {tile} ({i+1}/{len(tile_exposures)})") + + # Warp template to target coordinate system + warped = warper.warpExposure(target_wcs, exp, maxBBox=target_bbox) + if warped.getBBox().getArea() == 0: + self.log.debug(f"Skipping tile {tile}: no overlap after warping") + continue + + # Create properly initialized exposure + aligned_exp = afwImage.ExposureF(target_bbox, target_wcs) + aligned_exp.maskedImage.set( + np.nan, afwImage.Mask.getPlaneBitMask("NO_DATA"), np.nan + ) + aligned_exp.maskedImage.assign(warped.maskedImage, warped.getBBox()) + masked_images.append(aligned_exp.maskedImage) + + # Calculate weight (inverse of mean variance) + var_array = aligned_exp.variance.array + finite_var = var_array[np.isfinite(var_array)] + if len(finite_var) > 0: + mean_var = np.mean(finite_var) + weight = 1.0 / mean_var if mean_var > 0 else 1.0 + else: + weight = 1.0 + weights.append(weight) + + # Add to tile catalog for PSF computation + record = tile_catalog.addNew() + record.set(tile_key, tile) + record.set(weight_key, weight) + record.setPsf(exp.getPsf()) + record.setWcs(exp.getWcs()) + record.setPhotoCalib(exp.getPhotoCalib()) + record.setBBox(exp.getBBox()) + + polygon = afwGeom.Polygon(Box2D(exp.getBBox()).getCorners()) + record.setValidPolygon(polygon) + + if not masked_images: + raise RuntimeError("No valid warped images for coadd creation") + + # Create coadd exposure + coadd = afwImage.ExposureF(target_bbox, target_wcs) + coadd.maskedImage.set(np.nan, afwImage.Mask.getPlaneBitMask("NO_DATA"), np.nan) + xy0 = coadd.getXY0() + + # Perform statistical stacking + coadd.maskedImage = afwMath.statisticsStack( + masked_images, stats_flags, stats_ctrl, weights, clipped=0, maskMap=[] + ) + coadd.maskedImage.setXY0(xy0) + + # Create and set PSF + if len(tile_catalog) > 0: + valid_mask = ( + coadd.maskedImage.mask.array + & coadd.maskedImage.mask.getPlaneBitMask("NO_DATA") + ) == 0 + if np.any(valid_mask): + mask_for_centroid = afwImage.makeMaskFromArray( + valid_mask.astype(afwImage.MaskPixel) + ) + psf_center = afwGeom.SpanSet.fromMask( + mask_for_centroid, 1 + ).computeCentroid() + + ctrl = coadd_config.makeControl() + coadd_psf = CoaddPsf( + tile_catalog, + target_wcs, + psf_center, + ctrl.warpingKernelName, + ctrl.cacheSize, + ) + coadd.setPsf(coadd_psf) + + # Set calibration and filter from first tile exposure and physical_filter + if tile_exposures: + coadd.setPhotoCalib(tile_exposures[0].getPhotoCalib()) + # Create filter label from physical_filter + # Extract band from physical_filter (e.g., 'r_03' -> 'r') + band = ( + physical_filter.split("_")[0] + if "_" in physical_filter + else physical_filter + ) + filter_label = afwImage.FilterLabel(band=band, physical=physical_filter) + coadd.setFilter(filter_label) + + return coadd diff --git a/python/lsst/pipe/tasks/maskReferenceSources.py b/python/lsst/pipe/tasks/maskReferenceSources.py new file mode 100644 index 000000000..a6b458c03 --- /dev/null +++ b/python/lsst/pipe/tasks/maskReferenceSources.py @@ -0,0 +1,424 @@ +# This file is part of pipe_tasks. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +"""Pipeline for masking DIA sources based on a reference catalog +""" + +__all__ = ["MaskReferenceSourcesTask", + "MaskReferenceSourcesConfig", + "MaskReferenceSourcesConnections"] + +import numpy as np +import logging + +from astropy.coordinates import SkyCoord +import astropy.units as u + +import lsst.pex.config as pexConfig +import lsst.pipe.base as pipeBase +import lsst.pipe.base.connectionTypes as connectionTypes +import lsst.afw.image as afwImage +import lsst.afw.table as afwTable +import lsst.meas.algorithms + + +class MaskReferenceSourcesConnections( + pipeBase.PipelineTaskConnections, + dimensions=("instrument", "visit", "detector"), +): + """Connections for MaskReferenceSources task.""" + + # Input reference catalog + astrometry_ref_cat = connectionTypes.PrerequisiteInput( + doc="Reference catalog to use for source matching", + name="gaia_dr3_20230707", + storageClass="SimpleCatalog", + dimensions=("skypix",), + deferLoad=True, + multiple=True, + ) + + # Input image for matching (default: science image) + matching_image = connectionTypes.Input( + doc="Image to use for reference source matching", + name="calexp", + storageClass="ExposureF", + dimensions=("instrument", "visit", "detector"), + ) + + # Input sources for matching (default: science sources) + matching_sources = connectionTypes.Input( + doc="Sources detected in matching image", + name="src", + storageClass="SourceCatalog", + dimensions=("instrument", "visit", "detector"), + ) + + # Input difference image to mask + difference_image = connectionTypes.Input( + doc="Difference image to apply masks to", + name="difference_image", + storageClass="ExposureF", + dimensions=("instrument", "visit", "detector"), + ) + + # Output masked difference image + masked_difference_image = connectionTypes.Output( + doc="Difference image with reference source mask", + name="difference_image_masked", + storageClass="ExposureF", + dimensions=("instrument", "visit", "detector"), + ) + + +class MaskReferenceSourcesConfig( + pipeBase.PipelineTaskConfig, pipelineConnections=MaskReferenceSourcesConnections +): + """Configuration for MaskReferenceSources task.""" + + matching_radius = pexConfig.Field( + dtype=float, + default=1.0, + doc="Maximum separation for source matching in arcseconds", + ) + reference_mag_column = pexConfig.Field( + dtype=str, + default="r_flux", + doc="Column name for reference catalog flux (in nJy)", + ) + reference_ra_column = pexConfig.Field( + dtype=str, + default="coord_ra", + doc="Column name for reference catalog RA in degrees", + ) + reference_dec_column = pexConfig.Field( + dtype=str, + default="coord_dec", + doc="Column name for reference catalog Dec in degrees", + ) + mask_plane_name = pexConfig.Field( + dtype=str, + default="REFERENCE", + doc="Name of mask plane to create for reference source regions", + ) + reference_buffer = pexConfig.Field( + dtype=int, + default=100, + doc="Buffer in pixels to add around image bounds when loading reference catalog", + ) + astrometry_ref_loader = pexConfig.ConfigField( + dtype=lsst.meas.algorithms.LoadReferenceObjectsConfig, + doc="Configuration of reference object loader for source matching", + ) + + def setDefaults(self): + super().setDefaults() + # Set to Gaii, but we don't use photometry currently + self.astrometry_ref_loader.filterMap = { + "u": "phot_g_mean", + "g": "phot_g_mean", + "r": "phot_g_mean", + "i": "phot_g_mean", + "z": "phot_g_mean", + "y": "phot_g_mean", + } + + +class MaskReferenceSourcesTask(pipeBase.PipelineTask): + """Mask regions around reference sources in difference images.""" + + ConfigClass = MaskReferenceSourcesConfig + _DefaultName = "maskReferenceSources" + + def runQuantum(self, butlerQC, inputRefs, outputRefs): + """Run the task on quantum data.""" + inputs = butlerQC.get(inputRefs) + + # Create reference object loader following the standard pattern + astrometry_loader = lsst.meas.algorithms.ReferenceObjectLoader( + dataIds=[ref.datasetRef.dataId for ref in inputRefs.astrometry_ref_cat], + refCats=inputs.pop("astrometry_ref_cat"), + name=self.config.connections.astrometry_ref_cat, + config=self.config.astrometry_ref_loader, + log=self.log, + ) + + # Load reference catalog using loadPixelBox with buffer for edge sources + original_bbox = inputs["matching_image"].getBBox() + buffered_bbox = original_bbox.dilatedBy(self.config.reference_buffer) + + self.log.info( + f"Loading reference catalog for bbox: {original_bbox} (buffered to {buffered_bbox})" + ) + + ref_result = astrometry_loader.loadPixelBox( + bbox=buffered_bbox, + wcs=inputs["matching_image"].getWcs(), + filterName=inputs["matching_image"].getFilter().bandLabel, + ) + + self.log.info(f"Loaded {len(ref_result.refCat)} reference sources") + + outputs = self.run( + ref_catalog=ref_result.refCat, + difference_image=inputs["difference_image"], + matching_sources=inputs["matching_sources"], + matching_image=inputs["matching_image"], + ) + + butlerQC.put(outputs, outputRefs) + + def run(self, ref_catalog, difference_image, matching_sources, matching_image): + """Mask reference sources in difference image. + + Parameters + ---------- + ref_catalog : `lsst.afw.table.SourceCatalog` + Reference catalog with sources to mask + difference_image : `lsst.afw.image.ExposureF` + Difference image to apply masks to + matching_sources : `lsst.afw.table.SourceCatalog` + Sources to match against reference catalog + matching_image : `lsst.afw.image.ExposureF` + Image where matching sources were detected + + Returns + ------- + result : `lsst.pipe.base.Struct` + Results struct with: + - ``masked_difference_image`` : Modified difference image with reference mask + """ + self.log.info(f"Masking reference sources in difference image") + self.log.info(f"Reference catalog has {len(ref_catalog)} sources") + self.log.info(f"Matching against {len(matching_sources)} detected sources") + + # Create copy of difference image to modify + masked_diff = difference_image.clone() + + # Add mask plane if it doesn't exist + mask = masked_diff.mask + try: + mask_bit = mask.addMaskPlane(self.config.mask_plane_name) + self.log.info( + f"Added mask plane '{self.config.mask_plane_name}' with bit {mask_bit}" + ) + except Exception: + # Mask plane already exists + mask_bit = mask.getPlaneBitMask(self.config.mask_plane_name) + self.log.info( + f"Using existing mask plane '{self.config.mask_plane_name}' with bit {mask_bit}" + ) + + # Match reference sources to the matching sources + matches = self._matchSources( + ref_catalog, matching_sources, matching_image.getWcs() + ) + self.log.info(f"Found {len(matches)} matched sources") + + if len(matches) == 0: + self.log.warn(f"No reference sources matched to detected sources") + return pipeBase.Struct(masked_difference_image=masked_diff) + + # Apply masks for matched sources + n_masked = self._applyReferenceMasks(matches, masked_diff) + self.log.info(f"Masked {n_masked} reference source regions") + + return pipeBase.Struct(masked_difference_image=masked_diff) + + def _matchSources(self, ref_catalog, matching_sources, wcs): + """Match reference catalog sources to detected sources using Astropy SkyCoord. + + Parameters + ---------- + ref_catalog : `lsst.afw.table.SourceCatalog` + Reference catalog sources + matching_sources : `lsst.afw.table.SourceCatalog` + Detected sources to match against + wcs : `lsst.afw.geom.SkyWcs` + WCS for coordinate conversion + + Returns + ------- + matches : `list` of `tuple` + List of (ref_source, matching_source) pairs that match + """ + if len(ref_catalog) == 0 or len(matching_sources) == 0: + return [] + + # Extract reference source coordinates + ref_ras = [] + ref_decs = [] + ref_sources = [] + + for ref_src in ref_catalog: + if ( + self.config.reference_ra_column in ref_src.schema + and self.config.reference_dec_column in ref_src.schema + ): + # Convert to degrees - these might be Angle objects + ra = ref_src[self.config.reference_ra_column].asDegrees() + dec = ref_src[self.config.reference_dec_column].asDegrees() + else: + # Fallback to coord if specific columns not available + coord = ref_src.getCoord() + ra = coord.getRa().asDegrees() + dec = coord.getDec().asDegrees() + + ref_ras.append(ra) + ref_decs.append(dec) + ref_sources.append(ref_src) + + # Extract detected source coordinates + match_ras = [] + match_decs = [] + match_sources = [] + + for src in matching_sources: + coord = wcs.pixelToSky(Point2D(src.getX(), src.getY())) + match_ras.append(coord.getRa().asDegrees()) + match_decs.append(coord.getDec().asDegrees()) + match_sources.append(src) + + # Create SkyCoord objects + ref_coords = SkyCoord(ref_ras * u.deg, ref_decs * u.deg) + match_coords = SkyCoord(match_ras * u.deg, match_decs * u.deg) + + # Perform matching with maximum separation + max_sep = self.config.matching_radius * u.arcsec + idx, d2d, d3d = match_coords.match_to_catalog_sky(ref_coords) + + # Build matches list for sources within matching radius + matches = [] + n_too_far = 0 + for i, (match_idx, sep) in enumerate(zip(idx, d2d)): + if sep <= max_sep: + matches.append((ref_sources[match_idx], match_sources[i])) + else: + n_too_far += 1 + + self.log.debug( + f"Matching: {len(matches)} matched, {n_too_far} beyond {self.config.matching_radius} arcsec" + ) + return matches + + def _applyReferenceMasks(self, matches, masked_exposure): + """Apply reference masks to matched source regions. + + Parameters + ---------- + matches : `list` of `tuple` + List of (ref_source, matching_source) matched pairs + masked_exposure : `lsst.afw.image.ExposureF` + Difference image exposure to apply masks to + + Returns + ------- + n_masked : `int` + Number of regions masked + """ + mask = masked_exposure.mask + ref_mask_bit = mask.getPlaneBitMask(self.config.mask_plane_name) + + n_masked = 0 + + for ref_src, matching_src in matches: + # Use the source footprint from the matching source to define the mask region + footprint = matching_src.getFootprint() + if footprint is not None: + self._maskSourceFootprint(footprint, mask, ref_mask_bit) + n_masked += 1 + + return n_masked + + def _maskSourceFootprint(self, footprint, mask, mask_bit): + """Mask pixels in the source footprint. + + Parameters + ---------- + footprint : `lsst.afw.detection.Footprint` + Footprint of the source + mask : `lsst.afw.image.Mask` + Difference image mask to modify + mask_bit : `int` + Mask plane bit to set + """ + # Get the footprint pixels - since matching and difference images have same WCS, + # coordinates can be used directly + bbox = mask.getBBox() + n_clipped = 0 + n_masked = 0 + + for span in footprint.getSpans(): + y = span.getY() + for x in range(span.getMinX(), span.getMaxX() + 1): + if bbox.contains(x, y): + mask.array[y, x] |= mask_bit + n_masked += 1 + else: + n_clipped += 1 + + if n_clipped > 0: + self.log.debug( + f"Footprint clipped: {n_clipped} pixels outside bounds, {n_masked} pixels masked" + ) + + def _maskConnectedRegion( + self, mask, seed_x, seed_y, detected_bit, detected_neg_bit, ref_mask_bit + ): + """Mask a connected region of DETECTED or DETECTED_NEGATIVE pixels. + + Parameters + ---------- + mask : `lsst.afw.image.Mask` + Mask to modify + seed_x, seed_y : `int` + Seed pixel coordinates + detected_bit, detected_neg_bit : `int` + Mask plane bits for DETECTED and DETECTED_NEGATIVE + ref_mask_bit : `int` + Mask plane bit for REFERENCE_MASK + """ + height, width = mask.array.shape + visited = np.zeros((height, width), dtype=bool) + + # Stack for flood fill algorithm + stack = [(seed_x, seed_y)] + target_bits = detected_bit | detected_neg_bit + + while stack: + x, y = stack.pop() + + # Check bounds and if already visited + if x < 0 or x >= width or y < 0 or y >= height or visited[y, x]: + continue + + visited[y, x] = True + mask_value = mask.array[y, x] + + # If this pixel has DETECTED or DETECTED_NEGATIVE, mark it and add neighbors + if mask_value & target_bits: + mask.array[y, x] |= ref_mask_bit + + # Add 8-connected neighbors to stack + for dx in [-1, 0, 1]: + for dy in [-1, 0, 1]: + if dx != 0 or dy != 0: # Skip center pixel + stack.append((x + dx, y + dy)) diff --git a/tests/test_desTemplate.py b/tests/test_desTemplate.py new file mode 100644 index 000000000..8d83cba80 --- /dev/null +++ b/tests/test_desTemplate.py @@ -0,0 +1,314 @@ +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import collections +import itertools +import os +import tempfile +import unittest + +import numpy as np +from astropy.table import Table + +import lsst.afw.geom +import lsst.afw.image +import lsst.afw.math +import lsst.geom +import lsst.ip.diffim +import lsst.meas.algorithms +import lsst.meas.base.tests +import lsst.pipe.base as pipeBase +import lsst.utils.tests + +from utils import generate_data_id + +# Change this to True, `setup display_ds9`, and open ds9 (or use another afw +# display backend) to show the tract/patch layouts on the image. +debug = False +if debug: + import lsst.afw.display + + display = lsst.afw.display.Display() + display.frame = 1 + + +class GetTemplateTaskTestCase(lsst.utils.tests.TestCase): + """Test that GetTemplateTask works on both one tract and multiple tract + input coadd exposures. + + Makes a synthetic exposure large enough to fit four small tracts with 2x2 + (300x300 pixel) patches each, extracts pixels for those patches by warping, + and tests GetTemplateTask's output against boxes that overlap various + combinations of one or multiple tracts. + """ + + def setUp(self): + self.scale = 0.2 # arcsec/pixel + # DES pixel scale is approximately 0.263 arcsec/pixel + self.template_scale = 0.263 + self.exposure = self._makeExposure() + + # Track temporary files for cleanup + self.temp_files = [] + self.temp_csv_file = None + + if debug: + display.image(self.exposure, "base exposure") + + def _makeExposure(self): + """Create a large image to break up into tracts and patches. + + The image will have a source every 100 pixels in x and y, and a WCS + that results in the tracts all fitting in the image, with tract=0 + in the lower left, tract=1 to the right, tract=2 above, and tract=3 + to the upper right. + """ + box = lsst.geom.Box2I( + lsst.geom.Point2I(-200, -200), lsst.geom.Point2I(800, 800) + ) + # This WCS was constructed so that tract 0 mostly fills the lower left + # quadrant of the image, and the other tracts fill the rest; slight + # extra rotation as a check on the final warp layout, scaled by 5% + # from the patch pixel scale. + cd_matrix = lsst.afw.geom.makeCdMatrix( + 1.05 * self.scale * lsst.geom.arcseconds, 93 * lsst.geom.degrees + ) + wcs = lsst.afw.geom.makeSkyWcs( + lsst.geom.Point2D(120, 150), + lsst.geom.SpherePoint(0, 0, lsst.geom.radians), + cd_matrix, + ) + dataset = lsst.meas.base.tests.TestDataset(box, wcs=wcs) + for x, y in itertools.product(np.arange(0, 500, 100), np.arange(0, 500, 100)): + dataset.addSource(1e5, lsst.geom.Point2D(x, y)) + exposure, _ = dataset.realize(2, dataset.makeMinimalSchema()) + exposure.setFilter(lsst.afw.image.FilterLabel("r", "r_03")) + return exposure + + def _checkMetadata(self, template, config, box, wcs, nPsfs): + """Check that the various metadata components were set correctly.""" + expectedBox = lsst.geom.Box2I(box) + self.assertEqual(template.getBBox(), expectedBox) + # WCS should match our exposure, not any of the coadd tracts. + self.assertEqual(template.wcs, self.exposure.wcs) + self.assertEqual(template.getXY0(), expectedBox.getMin()) + self.assertEqual(template.filter.bandLabel, "r") + self.assertEqual(template.filter.physicalLabel, "r_03") + self.assertEqual(template.psf.getComponentCount(), nPsfs) + + def _checkPixels(self, template, config, box): + """Check that the pixel values in the template are close to the + original image. + """ + # All pixels should have real values! + expectedBox = lsst.geom.Box2I(box) + + # Check that we fully filled the template + self.assertTrue(np.all(np.isfinite(template.image.array))) + + def _makeSyntheticTileExposure( + self, ra_center=0.0, dec_center=0.0, size_pixels=4096, band="r" + ): + """Create a synthetic tile exposure similar to a DES coadd tile. + + Parameters + ---------- + ra_center : float + RA center of the tile in degrees + dec_center : float + DEC center of the tile in degrees + size_pixels : int + Size of the tile in pixels (DES tiles are typically 10000x10000, + but we use smaller for testing) + band : str + Photometric band + + Returns + ------- + exposure : lsst.afw.image.ExposureF + Synthetic tile exposure with WCS, PSF, and photometric calibration + """ + # Create bounding box for the tile + bbox = lsst.geom.Box2I( + lsst.geom.Point2I(0, 0), lsst.geom.Extent2I(size_pixels, size_pixels) + ) + + # Create WCS centered on the tile center + cd_matrix = lsst.afw.geom.makeCdMatrix( + self.template_scale * lsst.geom.arcseconds, + 0 * lsst.geom.degrees, # No rotation + ) + wcs = lsst.afw.geom.makeSkyWcs( + lsst.geom.Point2D(size_pixels / 2, size_pixels / 2), + lsst.geom.SpherePoint(ra_center, dec_center, lsst.geom.degrees), + cd_matrix, + ) + + # Create dataset with synthetic sources + dataset = lsst.meas.base.tests.TestDataset(bbox, wcs=wcs) + + # Add a grid of sources across the tile (every 500 pixels) + for x, y in itertools.product( + np.arange(500, size_pixels - 500, 500), + np.arange(500, size_pixels - 500, 500), + ): + dataset.addSource(1e5, lsst.geom.Point2D(x, y)) + + # Realize the exposure with noise + exposure, _ = dataset.realize(5.0, dataset.makeMinimalSchema()) + + # Set filter + filter_label = lsst.afw.image.FilterLabel(band=band, physical=f"{band}_03") + exposure.setFilter(filter_label) + + # Add MAGZERO to metadata (typical for DES) + metadata = exposure.getInfo().getMetadata() + metadata.set("MAGZERO", 30.0) # Typical DES zero point + + return exposure + + def _createSyntheticTileCatalog(self, bbox, wcs, band="r", tile_radius=0.02): + """Create a synthetic tile catalog CSV file for testing. + + The tile catalog will be centered on the box, with tile_radius on each side + + Parameters + ---------- + bbox : lsst.geom.Box2I + Bounding box of the region to overlap + wcs : lsst.afw.geom.SkyWcs + WCS of the region + band : str + Photometric band + tile_radius : float + Size of tile on one side + + Returns + ------- + csv_path : str + Path to the temporary CSV file + """ + # Get the center of the bbox in sky coordinates + center = wcs.pixelToSky(lsst.geom.Point2D(bbox.getCenter())) + ra_center = center.getRa().asDegrees() + dec_center = center.getDec().asDegrees() + + # Calculate tile corners (simplified rectangular approximation) + rac1 = ra_center - tile_radius + rac2 = ra_center + tile_radius + rac3 = ra_center + tile_radius + rac4 = ra_center - tile_radius + decc1 = dec_center - tile_radius + decc2 = dec_center - tile_radius + decc3 = dec_center + tile_radius + decc4 = dec_center + tile_radius + + size_pixels = int(2 * tile_radius * 3600 / self.template_scale) + # Create the synthetic tile exposure and save to FITS + tile_exposure = self._makeSyntheticTileExposure( + ra_center, dec_center, size_pixels=size_pixels, band=band + ) + + # Write to temporary FITS file + temp_fits = tempfile.NamedTemporaryFile( + suffix=".fits", delete=False, prefix="test_tile_" + ) + tile_exposure.writeFits(temp_fits.name) + temp_fits.close() + self.temp_files.append(temp_fits.name) + # Create tile catalog data + tile_data = { + "tilename": [f"TES{int(ra_center):04d}{int(dec_center):+05d}"], + "band": [band], + "survey": ["DES"], + "ra_cent": [ra_center], + "dec_cent": [dec_center], + "rac1": [rac1], + "rac2": [rac2], + "rac3": [rac3], + "rac4": [rac4], + "decc1": [decc1], + "decc2": [decc2], + "decc3": [decc3], + "decc4": [decc4], + "filepath": [temp_fits.name], + } + + # Create astropy table + table = Table(tile_data) + + # Write to temporary CSV file + temp_csv = tempfile.NamedTemporaryFile( + mode="w", suffix=".csv", delete=False, prefix="test_tiles_" + ) + table.write(temp_csv.name, format="ascii.csv", overwrite=True) + temp_csv.close() + self.temp_csv_file = temp_csv.name + + return temp_csv.name + + def tearDown(self): + """Clean up temporary files created during testing.""" + # Remove temporary FITS files + for temp_file in self.temp_files: + if os.path.exists(temp_file): + try: + os.remove(temp_file) + except Exception: + pass # Ignore cleanup errors + + # Remove temporary CSV file + if self.temp_csv_file and os.path.exists(self.temp_csv_file): + try: + os.remove(self.temp_csv_file) + except Exception: + pass # Ignore cleanup errors + + def testRunOneTractInput(self): + """Test a bounding box that fits inside single DES tract""" + + box = lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.Point2I(180, 180)) + + # Create synthetic tile catalog with tiles that overlap the test bbox + synthetic_csv = self._createSyntheticTileCatalog( + box, self.exposure.wcs, band="r" + ) + + config = lsst.ip.diffim.DesTemplateConfig() + config.tileFile = synthetic_csv + task = lsst.ip.diffim.DesTemplateTask(config=config) + result = task.run( + bbox=box, wcs=self.exposure.wcs, band="r", physical_filter="r_03" + ) + + self._checkMetadata(result.template, task.config, box, self.exposure.wcs, 1) + self._checkPixels(result.template, task.config, box) + + +def setup_module(module): + lsst.utils.tests.init() + + +class MemoryTestCase(lsst.utils.tests.MemoryTestCase): + pass + + +if __name__ == "__main__": + lsst.utils.tests.init() + unittest.main() diff --git a/tests/test_maskReferenceSources.py b/tests/test_maskReferenceSources.py new file mode 100644 index 000000000..9e398da2a --- /dev/null +++ b/tests/test_maskReferenceSources.py @@ -0,0 +1,223 @@ +import unittest +import numpy as np + +import lsst.utils.tests as utilsTests +import lsst.afw.image as afwImage +import lsst.afw.table as afwTable +import lsst.afw.geom as afwGeom +import lsst.afw.detection as afwDetect +from lsst.afw.geom import makeSkyWcs +from lsst.geom import Point2D, Angle, degrees, SpherePoint + +from lsst.ip.diffim.maskReferenceSources import ( + MaskReferenceSourcesTask, + MaskReferenceSourcesConfig, +) + +# ----------------- helpers ----------------- + + +def make_exposure( + w=300, h=300, scale_arcsec=0.2, crpix=(150.0, 150.0), crval_deg=(10.0, -5.0) +): + exp = afwImage.ExposureF(w, h) + scale_deg = scale_arcsec / 3600.0 + cd = np.array([[scale_deg, 0.0], [0.0, scale_deg]], dtype=float) + wcs = makeSkyWcs( + crpix=Point2D(*crpix), + crval=SpherePoint(Angle(crval_deg[0], degrees), Angle(crval_deg[1], degrees)), + cdMatrix=cd, + ) + exp.setWcs(wcs) + return exp + + +def circular_footprint(xc, yc, r=4): + spans = [] + for y in range(yc - r, yc + r + 1): + dy = y - yc + dx = int(np.floor(np.sqrt(max(0, r * r - dy * dy)))) + spans.append(afwGeom.Span(y, xc - dx, xc + dx)) + ss = afwGeom.SpanSet(spans) + return afwDetect.Footprint(ss) + + +def make_src_cat(wcs, xy_list, radius=4): + + schema = afwTable.SourceTable.makeMinimalSchema() + + xKey = schema.addField(afwTable.FieldD("centroid_x", "centroid x", "pixel")) + yKey = schema.addField(afwTable.FieldD("centroid_y", "centroid y", "pixel")) + + schema.getAliasMap().set("slot_Centroid", "centroid") + table = afwTable.SourceTable.make(schema) + + cat = afwTable.SourceCatalog(table) + for x, y in xy_list: + rec = cat.addNew() + rec.set(xKey, float(x)) + rec.set(yKey, float(y)) + rec.setCoord(wcs.pixelToSky(Point2D(float(x), float(y)))) + rec.setFootprint(circular_footprint(int(x), int(y), radius)) + return cat + + +def make_ref_cat_from_sources(wcs, src_xy, arcsec_offset, offset_axis="ra"): + """Make a reference catalog by offsetting each source's sky coord.""" + schema = afwTable.SourceTable.makeMinimalSchema() + table = afwTable.SourceTable.make(schema) + cat = afwTable.SourceCatalog(table) + deg_off = arcsec_offset / 3600.0 + for x, y in src_xy: + sp = wcs.pixelToSky(Point2D(float(x), float(y))) + ra = sp.getRa().asDegrees() + dec = sp.getDec().asDegrees() + if offset_axis == "ra": + ra += deg_off + else: + dec += deg_off + rec = cat.addNew() + rec.setCoord(SpherePoint(Angle(ra, degrees), Angle(dec, degrees))) + return cat + + +def mask_bit(exp, plane): + try: + return exp.mask.addMaskPlane(plane) + except Exception: + return exp.mask.getPlaneBitMask(plane) + + +def ensure_plane_bitmask(mask, name): + try: + bit_index = mask.addMaskPlane(name) + return 1 << bit_index + except Exception: + return mask.getPlaneBitMask(name) + + +def paint_footprint(mask, footprint, bit): + """OR the given bit into mask pixels inside the footprint (clip to bounds).""" + bbox = mask.getBBox() + for span in footprint.getSpans(): + y = span.getY() + for x in range(span.getMinX(), span.getMaxX() + 1): + if bbox.contains(x, y): + mask.array[y, x] |= bit + + +# ----------------- tests ----------------- + + +class TestMaskReferenceSourcesMany(utilsTests.TestCase): + """Build 300x300 exposure with 8 DETECTED sources. + Make 5 reference objects close enough to match + Verify REFERENCE is set exactly over DETECTED for matches only. + """ + + def setUp(self): + self.exp = make_exposure(300, 300, scale_arcsec=0.2) + self.cfg = MaskReferenceSourcesConfig() + self.cfg.mask_plane_name = "REFERENCE" + self.cfg.matching_radius = 1.2 # arcsec + self.task = MaskReferenceSourcesTask(config=self.cfg) + + # 8 sources, well-separated (no overlapping footprints) + self.src_xy = [ + (40, 40), + (90, 60), + (140, 80), + (190, 100), + (240, 120), + (60, 200), + (160, 220), + (260, 240), + ] + self.src_cat = make_src_cat(self.exp.getWcs(), self.src_xy, radius=4) + + self.detected_bit = ensure_plane_bitmask(self.exp.mask, "DETECTED") + for rec in self.src_cat: + paint_footprint(self.exp.mask, rec.getFootprint(), self.detected_bit) + + # Build reference catalog: + # - First 5: +0.8" in RA (=> should match within 1.2") + # - Last 3: not in the reference catalog + refs = make_ref_cat_from_sources( + self.exp.getWcs(), self.src_xy[:5], arcsec_offset=0.8, offset_axis="ra" + ) + + # Concatenate (order shouldn't matter) + self.ref_cat = afwTable.SourceCatalog(refs.getTable()) + for rec in refs: + self.ref_cat.append(rec) + + self.reference_bit = ensure_plane_bitmask( + self.exp.mask, self.cfg.mask_plane_name + ) + + def test_masking(self): + self.assertEqual(int((self.exp.mask.array & self.reference_bit).sum()), 0) + + # Run the algorithm + result = self.task.run( + ref_catalog=self.ref_cat, + difference_image=self.exp, + matching_sources=self.src_cat, + matching_image=self.exp, + ) + out = result.masked_difference_image + mask_arr = out.mask.array + + # 1) REFERENCE must be a subset of DETECTED everywhere + ref_only = (mask_arr & self.reference_bit) & ~( + (mask_arr & self.detected_bit) > 0 + ) + self.assertEqual( + int(ref_only.sum()), 0, "REFERENCE set outside DETECTED footprint(s)" + ) + + # 2) Exactly 5 sources matched; REFERENCE pixels == sum of their footprint areas + matches = self.task._matchSources(self.ref_cat, self.src_cat, self.exp.getWcs()) + self.assertEqual(len(matches), 5, "Expected 5 matches within 1.2 arcsec") + matched_src_ids = {id(src) for _, src in matches} + + expected_pixels = 0 + for src in self.src_cat: + if id(src) in matched_src_ids: + expected_pixels += src.getFootprint().getArea() + + actual_pixels = ((mask_arr & self.reference_bit) == self.reference_bit).sum() + self.assertEqual( + actual_pixels, + expected_pixels, + "REFERENCE mask pixels should equal sum of matched footprint areas", + ) + + # 3) Unmatched sources: DETECTED-only, no REFERENCE bits inside their footprints + for src in self.src_cat: + if id(src) not in matched_src_ids: + # Collect REFERENCE pixels inside this footprint + in_fp = 0 + bbox = self.exp.mask.getBBox() + for span in src.getFootprint().getSpans(): + y = span.getY() + for x in range(span.getMinX(), span.getMaxX() + 1): + if bbox.contains(x, y) and ( + mask_arr[y, x] & self.reference_bit + ): + in_fp += 1 + self.assertEqual( + in_fp, 0, "Unmatched source has REFERENCE bits in its footprint" + ) + + +# --------------- boilerplate --------------- + + +def setup_module(module): + utilsTests.init() + + +if __name__ == "__main__": + utilsTests.init() + unittest.main()