From 2436a4eadcb4c38e79bbc10bbf27db188135391e Mon Sep 17 00:00:00 2001 From: Bob Armstrong Date: Fri, 21 Nov 2025 11:55:24 -0800 Subject: [PATCH] Add tasks for ToO pipeline This includes one task to build templates using archival DECam data. The data is expected to be on disk and does not use standard butler methods for input. A text file stores tile boundaries and filepaths. It also includes a task to mask DETECTED and DETECTED_NEGATIVE pixels on an image that are connected to sources from a reference catalog. --- python/lsst/pipe/tasks/desTemplateTask.py | 711 ++++++++++++++++++ .../lsst/pipe/tasks/maskReferenceSources.py | 424 +++++++++++ tests/test_desTemplate.py | 314 ++++++++ tests/test_maskReferenceSources.py | 223 ++++++ 4 files changed, 1672 insertions(+) create mode 100644 python/lsst/pipe/tasks/desTemplateTask.py create mode 100644 python/lsst/pipe/tasks/maskReferenceSources.py create mode 100644 tests/test_desTemplate.py create mode 100644 tests/test_maskReferenceSources.py diff --git a/python/lsst/pipe/tasks/desTemplateTask.py b/python/lsst/pipe/tasks/desTemplateTask.py new file mode 100644 index 000000000..d1fb80cdd --- /dev/null +++ b/python/lsst/pipe/tasks/desTemplateTask.py @@ -0,0 +1,711 @@ +# This file is part of pipe_tasks. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +"""Pipeline for generating templates from archival DES data. +""" + +__all__ = ["DesTemplateTask", + "DesTemplateConfig", + "DesTemplateConnections"] + +import numpy as np + +from astropy.table import Table +from astropy.io import fits + +import lsst.pex.config as pexConfig +import lsst.pipe.base as pipeBase +import lsst.pipe.base.connectionTypes as connectionTypes +import lsst.afw.image as afwImage +import lsst.afw.table as afwTable +import lsst.afw.math as afwMath +import lsst.afw.geom as afwGeom +from lsst.geom import Box2D, Point2D +from lsst.sphgeom import ConvexPolygon, UnitVector3d, LonLat +from lsst.afw.image import makePhotoCalibFromCalibZeroPoint +from lsst.meas.algorithms.installGaussianPsf import ( + InstallGaussianPsfTask, + InstallGaussianPsfConfig, +) +from lsst.meas.algorithms import CoaddPsf, CoaddPsfConfig + +from lsst.meas.base import IdGenerator +from lsst.geom import SpherePoint, degrees +import lsst.log + + +class DesTemplateConnections( + pipeBase.PipelineTaskConnections, + dimensions=("instrument", "visit", "detector"), +): + """Connections for DES template creation task.""" + + bbox = connectionTypes.Input( + doc="Bounding box of the exposure", + name="pvi.bbox", + storageClass="Box2I", + dimensions=("instrument", "visit", "detector"), + ) + wcs = connectionTypes.Input( + doc="WCS of the exposure", + name="pvi.wcs", + storageClass="Wcs", + dimensions=("instrument", "visit", "detector"), + ) + template = connectionTypes.Output( + doc="Template exposure created from DES tiles", + name="desTemplate", + storageClass="ExposureF", + dimensions=("instrument", "visit", "detector"), + ) + + +class DesTemplateConfig( + pipeBase.PipelineTaskConfig, pipelineConnections=DesTemplateConnections +): + """Configuration for DES template creation task.""" + + tileFile = pexConfig.Field( + dtype=str, + default="merged_tiles_v3_deduplicated.csv", + doc="Path to merged CSV file containing tile information with columns: " + "tilename, band, survey, ra_cent, dec_cent, rac1-4, decc1-4, filepath", + ) + dataSourcePriority = pexConfig.ListField( + dtype=str, + default=["DELVE", "DES", "DECADE"], + doc="Priority order for surveys when duplicate tiles exist (first = highest priority)", + ) + modelPsf = pexConfig.Field( + dtype=bool, + default=False, + doc="Use modeled PSF instead of Gaussian PSF for templates", + ) + gaussianPsfFwhm = pexConfig.Field( + dtype=float, + default=4.0, + doc="FWHM for Gaussian PSF in pixels (used when modelPsf=False)", + ) + gaussianPsfWidth = pexConfig.Field( + dtype=int, + default=21, + doc="Width for Gaussian PSF in pixels (used when modelPsf=False)", + ) + hasInverseVariancePlane = pexConfig.Field( + dtype=bool, + default=True, + doc="Do the template files have a weight map or inverse variance plane", + ) + + +class DesTemplateTask(pipeBase.PipelineTask): + """Create template exposures from survey tile data.""" + + ConfigClass = DesTemplateConfig + _DefaultName = "desTemplate" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # Suppress FITS reader warnings for survey tiles + lsst.log.setLevel("lsst.afw.image.MaskedImageFitsReader", lsst.log.ERROR) + + if self.config.modelPsf: + from lsst.pipe.tasks.calibrateImage import ( + CalibrateImageTask, + CalibrateImageConfig, + ) + + cal_config = CalibrateImageConfig() + cal_config.psf_repair.doInterpolate = False + cal_config.psf_repair.doCosmicRay = False + cal_config.psf_detection.thresholdValue = 100.0 + self.calibrateTask = CalibrateImageTask(config=cal_config) + else: + install_psf_config = InstallGaussianPsfConfig() + install_psf_config.fwhm = self.config.gaussianPsfFwhm + install_psf_config.width = self.config.gaussianPsfWidth + self.installPsfTask = InstallGaussianPsfTask(config=install_psf_config) + + def runQuantum(self, butlerQC, inputRefs, outputRefs): + # Extract band and physical_filter from dataId + band = butlerQC.quantum.dataId["band"] + physical_filter = butlerQC.quantum.dataId["physical_filter"] + + inputs = butlerQC.get(inputRefs) + outputs = self.run(band=band, physical_filter=physical_filter, **inputs) + butlerQC.put(outputs, outputRefs) + + def run(self, bbox, wcs, band, physical_filter): + """Create template from survey tiles overlapping the target region. + + Parameters + ---------- + bbox : `lsst.geom.Box2I` + Bounding box of the target region + wcs : `lsst.afw.geom.SkyWcs` + WCS of the target region + band : `str` + Photometric band (e.g., 'r', 'g', 'i') + physical_filter : `str` + Physical filter name (e.g., 'r_03', 'g_01') + + Returns + ------- + result : `lsst.pipe.base.Struct` + Results struct with: + - ``template`` : Template exposure created from survey tiles + """ + self.log.info("Creating survey template for target region") + self.log.info(f"Using band '{band}' and physical_filter '{physical_filter}'") + + # Load tile catalog and filter by band + tile_table = Table.read(self.config.tileFile, format="ascii") + self.log.info(f"Loaded {len(tile_table)} total tiles from catalog") + + # Validate required columns exist + required_columns = [ + "tilename", + "band", + "survey", + "ra_cent", + "dec_cent", + "rac1", + "rac2", + "rac3", + "rac4", + "decc1", + "decc2", + "decc3", + "decc4", + "filepath", + ] + missing_columns = [ + col for col in required_columns if col not in tile_table.colnames + ] + if missing_columns: + raise RuntimeError( + f"Required columns missing from tile table: {missing_columns}. " + "Please use the merged tiles CSV format." + ) + + # Filter by band first + band_filtered_tiles = self._filterTilesByBand(tile_table, band) + self.log.info(f"Found {len(band_filtered_tiles)} tiles for band '{band}'") + + if len(band_filtered_tiles) == 0: + raise RuntimeError(f"No tiles found for band '{band}' in catalog") + + # Find overlapping tiles + overlapping_tiles = self._findOverlappingTiles(bbox, wcs, band_filtered_tiles) + self.log.info(f"Found {len(overlapping_tiles)} overlapping tiles") + + if not overlapping_tiles: + center = wcs.pixelToSky(Point2D(bbox.getCenter())) + target_ra = center.getRa().asDegrees() + target_dec = center.getDec().asDegrees() + self.log.info( + f"Target was centered at RA={target_ra:.3f}, Dec={target_dec:.3f}" + ) + raise RuntimeError("No survey tiles found overlapping with target region") + + # Resolve duplicate tiles and load template exposures + resolved_tiles = self._resolveDuplicateTiles( + overlapping_tiles, band_filtered_tiles + ) + tile_exposures, tile_names = self._loadTileTemplates(resolved_tiles) + self.log.info(f"Successfully loaded {len(tile_exposures)} template exposures") + + if not tile_exposures: + raise RuntimeError("No valid template exposures found") + + # Create coadded template + template = self._createCoaddFromTiles( + tile_exposures, tile_names, wcs, bbox, physical_filter + ) + + self.log.info("Successfully created survey template") + return pipeBase.Struct(template=template) + + def _tileOverlapsRegion(self, tile_corners, bbox, wcs): + """Check if a tile overlaps with the target region using spherical geometry. + + Parameters + ---------- + tile_corners : `list` of `tuple` + List of (ra, dec) tuples for tile corners in degrees + bbox : `lsst.geom.Box2I` + Bounding box of target region + wcs : `lsst.afw.geom.SkyWcs` + WCS of target region + + Returns + ------- + overlaps : `bool` + True if tile overlaps with target region, False otherwise + """ + # Get target region corners + region_corners = [wcs.pixelToSky(Point2D(p)) for p in bbox.getCorners()] + + # Convert tile corners to unit vectors + tile_unit_vectors = [ + UnitVector3d(LonLat.fromDegrees(ra, dec)) for ra, dec in tile_corners + ] + + # Convert region corners to unit vectors + region_unit_vectors = [ + UnitVector3d( + LonLat.fromDegrees(s.getRa().asDegrees(), s.getDec().asDegrees()) + ) + for s in region_corners + ] + + # Create convex polygons + tile_polygon = ConvexPolygon(tile_unit_vectors) + region_polygon = ConvexPolygon(region_unit_vectors) + + # Check for overlap + return tile_polygon.overlaps(region_polygon) + + def _findOverlappingTiles(self, bbox, wcs, tile_table): + """Find tiles that overlap with the target region. + + Parameters + ---------- + bbox : `lsst.geom.Box2I` + Bounding box of target region + wcs : `lsst.afw.geom.SkyWcs` + WCS of target region + tile_table : `astropy.table.Table` + Table containing tile information + + Returns + ------- + overlapping_tiles : `list` of `tuple` + List of (tile_name, distance_deg) for overlapping tiles, sorted by distance + """ + # Get target region center for distance calculation + center = wcs.pixelToSky(Point2D(bbox.getCenter())) + target_ra = center.getRa().asDegrees() + target_dec = center.getDec().asDegrees() + p0 = SpherePoint(target_ra * degrees, target_dec * degrees) + overlapping_tiles = [] + total_tiles_checked = 0 + + self.log.info(f"Checking {len(tile_table)} tiles for overlap with exposure") + + for row in tile_table: + tile_name = row["tilename"] + total_tiles_checked += 1 + + # Extract tile corners from table + required_corner_columns = [ + "rac1", + "decc1", + "rac2", + "decc2", + "rac3", + "decc3", + "rac4", + "decc4", + ] + + if not all(col in tile_table.colnames for col in required_corner_columns): + missing_corners = [ + col + for col in required_corner_columns + if col not in tile_table.colnames + ] + raise RuntimeError( + f"Required corner coordinate columns missing from tile table: {missing_corners}" + ) + + # Use actual corner coordinates + tile_corners = [ + (row["rac1"], row["decc1"]), + (row["rac2"], row["decc2"]), + (row["rac3"], row["decc3"]), + (row["rac4"], row["decc4"]), + ] + + # Check if tile overlaps with target region + has_overlap = self._tileOverlapsRegion(tile_corners, bbox, wcs) + + if has_overlap: + # Calculate distance to tile center for sorting + tile_ra = row["ra_cent"] + tile_dec = row["dec_cent"] + + p1 = SpherePoint(tile_ra * degrees, tile_dec * degrees) + distance = p0.separation(p1).asDegrees() + + self.log.info(f" OVERLAP: {tile_name} - distance={distance:.3f}°") + + overlapping_tiles.append((tile_name, distance)) + else: + # Log nearby tiles that don't overlap for debugging + tile_ra = row["ra_cent"] + tile_dec = row["dec_cent"] + p1 = SpherePoint(tile_ra * degrees, tile_dec * degrees) + distance = p0.separation(p1).asDegrees() + + if distance < 2.0: # Only log tiles within 2 degrees + self.log.info( + f" No overlap: {tile_name} - distance={distance:.3f}° (RA={tile_ra:.3f}, Dec={tile_dec:.3f})" + ) + + self.log.info( + f"Checked {total_tiles_checked} tiles, found {len(overlapping_tiles)} with overlap" + ) + + # Sort by distance + overlapping_tiles.sort(key=lambda x: x[1]) + + # Return the overlapping tiles + return overlapping_tiles + + def _filterTilesByBand(self, tile_table, band): + """Filter tiles by the requested band and validate survey values. + + Parameters + ---------- + tile_table : `astropy.table.Table` + Table containing all tiles + band : `str` + Photometric band to filter for + + Returns + ------- + filtered_table : `astropy.table.Table` + Table containing only tiles for the specified band with valid surveys + """ + if "band" not in tile_table.colnames: + raise RuntimeError( + "Required 'band' column not found in tile table. " + "Please use the merged tiles CSV format." + ) + + if "survey" not in tile_table.colnames: + raise RuntimeError( + "Required 'survey' column not found in tile table. " + "Please use the merged tiles CSV format." + ) + + # First filter by band + band_mask = tile_table["band"] == band + band_filtered = tile_table[band_mask] + + # Then validate survey values + valid_surveys = set(self.config.dataSourcePriority) + survey_column = band_filtered["survey"] + + # Vectorized membership check (case-sensitive) + valid_mask = np.isin(survey_column, list(valid_surveys)) + + invalid_count = len(band_filtered) - np.sum(valid_mask) + if invalid_count > 0: + self.log.warn( + f"Found {invalid_count} tiles with unrecognized survey values, ignoring them" + ) + + filtered_table = band_filtered[valid_mask] + + self.log.info( + f"Filtered {len(tile_table)} tiles to {len(filtered_table)} for band '{band}' with valid surveys" + ) + return filtered_table + + def _resolveDuplicateTiles(self, overlapping_tiles, band_filtered_tiles): + """Resolve duplicate tiles based on survey priority. + + Parameters + ---------- + overlapping_tiles : `list` of `tuple` + List of (tile_name, distance) tuples for overlapping tiles + band_filtered_tiles : `astropy.table.Table` + Band-filtered tile table with filepath and survey columns + + Returns + ------- + resolved_tiles : `list` of `dict` + List of resolved tile dictionaries with tilename, distance, filepath, survey + """ + # Get priority order from config + priority_order = self.config.dataSourcePriority + + resolved_tiles = [] + + for tile_name, distance in overlapping_tiles: + # Find all entries for this tilename in the band-filtered table + tile_mask = band_filtered_tiles["tilename"] == tile_name + tile_entries = band_filtered_tiles[tile_mask] + + if len(tile_entries) == 0: + self.log.warn(f"No entries found for overlapping tile {tile_name}") + continue + elif len(tile_entries) == 1: + # No duplicates, use the single entry + entry = tile_entries[0] + resolved_tiles.append( + { + "tilename": tile_name, + "distance": distance, + "filepath": entry["filepath"], + "survey": entry["survey"], + } + ) + else: + # Multiple entries, resolve based on priority order + selected_entry = None + + # Check for each survey in priority order + for preferred_survey in priority_order: + survey_entries = tile_entries[ + tile_entries["survey"] == preferred_survey + ] + if len(survey_entries) > 0: + selected_entry = survey_entries[0] + self.log.info( + f"Using {preferred_survey} tile for {tile_name} (priority)" + ) + break + + if selected_entry is None: + # No recognized survey, use first available + selected_entry = tile_entries[0] + self.log.info( + f"Using {selected_entry['survey']} tile for {tile_name} (fallback)" + ) + + resolved_tiles.append( + { + "tilename": tile_name, + "distance": distance, + "filepath": selected_entry["filepath"], + "survey": selected_entry["survey"], + } + ) + + self.log.info( + f"Resolved {len(resolved_tiles)} tiles from {len(overlapping_tiles)} overlapping tiles" + ) + return resolved_tiles + + def _loadTileTemplates(self, resolved_tiles): + """Load template exposures from resolved tiles. + + Parameters + ---------- + resolved_tiles : `list` of `dict` + List of resolved tile dictionaries with tilename, distance, filepath, survey + + Returns + ------- + exposures : `list` of `lsst.afw.image.ExposureF` + List of loaded template exposures + tile_names : `list` of `str` + List of tile names corresponding to exposures + """ + exposures = [] + tile_names = [] + + for tile_info in resolved_tiles: + tile_name = tile_info["tilename"] + filepath = tile_info["filepath"] + survey = tile_info["survey"] + + try: + self.log.info(f"Loading {survey} tile {tile_name} from {filepath}") + exp = afwImage.ExposureF(filepath) + + if self.config.hasInverseVariancePlane: + # Need convert DES weight map to variance map for LSST exposure + mi = exp.maskedImage + bad_weight = mi.variance.array <= 0.0 + mi.variance.array[~bad_weight] = ( + 1.0 / mi.variance.array[~bad_weight] + ) + mi.variance.array[bad_weight] = np.nan + mi.mask.array[bad_weight] |= mi.mask.getPlaneBitMask("NO_DATA") + + # Set PSF + if not self.config.modelPsf: + # Install Gaussian PSF + self.installPsfTask.run(exp) + else: + # Use modeled PSF + cat = self.calibrateTask._compute_psf(exp, IdGenerator()) + + # Set photometric calibration + metadata = exp.getInfo().getMetadata() + if "MAGZERO" in metadata: + zp = metadata["MAGZERO"] + flux0 = 10 ** (0.4 * zp) + calib = makePhotoCalibFromCalibZeroPoint(flux0, 0.0) + exp.setPhotoCalib(calib) + else: + self.log.warn( + f"No MAGZERO found in {tile_name}, using default calibration" + ) + + exposures.append(exp) + tile_names.append(tile_name) + + except Exception as e: + self.log.warn(f"Failed to load tile {tile_name}: {e}") + continue + + return exposures, tile_names + + def _createCoaddFromTiles( + self, tile_exposures, tile_names, target_wcs, target_bbox, physical_filter + ): + """Create coadded template from tile exposures. + + Parameters + ---------- + tile_exposures : `list` of `lsst.afw.image.ExposureF` + List of tile exposures + tile_names : `list` of `str` + List of tile names + target_wcs : `lsst.afw.geom.SkyWcs` + Target WCS for template + target_bbox : `lsst.geom.Box2I` + Target bounding box for template + physical_filter : `str` + Physical filter name for the template + + Returns + ------- + template : `lsst.afw.image.ExposureF` + Coadded template exposure + """ + warper = afwMath.Warper.fromConfig(afwMath.Warper.ConfigClass()) + + # Create tile catalog for PSF computation + tile_schema = afwTable.ExposureTable.makeMinimalSchema() + tile_key = tile_schema.addField("tile", type="String", size=12) + weight_key = tile_schema.addField("weight", type=float) + + coadd_config = CoaddPsfConfig() + + # Statistics configuration + stats_flags = afwMath.stringToStatisticsProperty("MEAN") + stats_ctrl = afwMath.StatisticsControl() + stats_ctrl.setNanSafe(True) + stats_ctrl.setWeighted(True) + stats_ctrl.setCalcErrorMosaicMode(True) + + tile_catalog = afwTable.ExposureCatalog(tile_schema) + masked_images = [] + weights = [] + + for i, (exp, tile) in enumerate(zip(tile_exposures, tile_names)): + self.log.debug(f"Processing tile {tile} ({i+1}/{len(tile_exposures)})") + + # Warp template to target coordinate system + warped = warper.warpExposure(target_wcs, exp, maxBBox=target_bbox) + if warped.getBBox().getArea() == 0: + self.log.debug(f"Skipping tile {tile}: no overlap after warping") + continue + + # Create properly initialized exposure + aligned_exp = afwImage.ExposureF(target_bbox, target_wcs) + aligned_exp.maskedImage.set( + np.nan, afwImage.Mask.getPlaneBitMask("NO_DATA"), np.nan + ) + aligned_exp.maskedImage.assign(warped.maskedImage, warped.getBBox()) + masked_images.append(aligned_exp.maskedImage) + + # Calculate weight (inverse of mean variance) + var_array = aligned_exp.variance.array + finite_var = var_array[np.isfinite(var_array)] + if len(finite_var) > 0: + mean_var = np.mean(finite_var) + weight = 1.0 / mean_var if mean_var > 0 else 1.0 + else: + weight = 1.0 + weights.append(weight) + + # Add to tile catalog for PSF computation + record = tile_catalog.addNew() + record.set(tile_key, tile) + record.set(weight_key, weight) + record.setPsf(exp.getPsf()) + record.setWcs(exp.getWcs()) + record.setPhotoCalib(exp.getPhotoCalib()) + record.setBBox(exp.getBBox()) + + polygon = afwGeom.Polygon(Box2D(exp.getBBox()).getCorners()) + record.setValidPolygon(polygon) + + if not masked_images: + raise RuntimeError("No valid warped images for coadd creation") + + # Create coadd exposure + coadd = afwImage.ExposureF(target_bbox, target_wcs) + coadd.maskedImage.set(np.nan, afwImage.Mask.getPlaneBitMask("NO_DATA"), np.nan) + xy0 = coadd.getXY0() + + # Perform statistical stacking + coadd.maskedImage = afwMath.statisticsStack( + masked_images, stats_flags, stats_ctrl, weights, clipped=0, maskMap=[] + ) + coadd.maskedImage.setXY0(xy0) + + # Create and set PSF + if len(tile_catalog) > 0: + valid_mask = ( + coadd.maskedImage.mask.array + & coadd.maskedImage.mask.getPlaneBitMask("NO_DATA") + ) == 0 + if np.any(valid_mask): + mask_for_centroid = afwImage.makeMaskFromArray( + valid_mask.astype(afwImage.MaskPixel) + ) + psf_center = afwGeom.SpanSet.fromMask( + mask_for_centroid, 1 + ).computeCentroid() + + ctrl = coadd_config.makeControl() + coadd_psf = CoaddPsf( + tile_catalog, + target_wcs, + psf_center, + ctrl.warpingKernelName, + ctrl.cacheSize, + ) + coadd.setPsf(coadd_psf) + + # Set calibration and filter from first tile exposure and physical_filter + if tile_exposures: + coadd.setPhotoCalib(tile_exposures[0].getPhotoCalib()) + # Create filter label from physical_filter + # Extract band from physical_filter (e.g., 'r_03' -> 'r') + band = ( + physical_filter.split("_")[0] + if "_" in physical_filter + else physical_filter + ) + filter_label = afwImage.FilterLabel(band=band, physical=physical_filter) + coadd.setFilter(filter_label) + + return coadd diff --git a/python/lsst/pipe/tasks/maskReferenceSources.py b/python/lsst/pipe/tasks/maskReferenceSources.py new file mode 100644 index 000000000..a6b458c03 --- /dev/null +++ b/python/lsst/pipe/tasks/maskReferenceSources.py @@ -0,0 +1,424 @@ +# This file is part of pipe_tasks. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +"""Pipeline for masking DIA sources based on a reference catalog +""" + +__all__ = ["MaskReferenceSourcesTask", + "MaskReferenceSourcesConfig", + "MaskReferenceSourcesConnections"] + +import numpy as np +import logging + +from astropy.coordinates import SkyCoord +import astropy.units as u + +import lsst.pex.config as pexConfig +import lsst.pipe.base as pipeBase +import lsst.pipe.base.connectionTypes as connectionTypes +import lsst.afw.image as afwImage +import lsst.afw.table as afwTable +import lsst.meas.algorithms + + +class MaskReferenceSourcesConnections( + pipeBase.PipelineTaskConnections, + dimensions=("instrument", "visit", "detector"), +): + """Connections for MaskReferenceSources task.""" + + # Input reference catalog + astrometry_ref_cat = connectionTypes.PrerequisiteInput( + doc="Reference catalog to use for source matching", + name="gaia_dr3_20230707", + storageClass="SimpleCatalog", + dimensions=("skypix",), + deferLoad=True, + multiple=True, + ) + + # Input image for matching (default: science image) + matching_image = connectionTypes.Input( + doc="Image to use for reference source matching", + name="calexp", + storageClass="ExposureF", + dimensions=("instrument", "visit", "detector"), + ) + + # Input sources for matching (default: science sources) + matching_sources = connectionTypes.Input( + doc="Sources detected in matching image", + name="src", + storageClass="SourceCatalog", + dimensions=("instrument", "visit", "detector"), + ) + + # Input difference image to mask + difference_image = connectionTypes.Input( + doc="Difference image to apply masks to", + name="difference_image", + storageClass="ExposureF", + dimensions=("instrument", "visit", "detector"), + ) + + # Output masked difference image + masked_difference_image = connectionTypes.Output( + doc="Difference image with reference source mask", + name="difference_image_masked", + storageClass="ExposureF", + dimensions=("instrument", "visit", "detector"), + ) + + +class MaskReferenceSourcesConfig( + pipeBase.PipelineTaskConfig, pipelineConnections=MaskReferenceSourcesConnections +): + """Configuration for MaskReferenceSources task.""" + + matching_radius = pexConfig.Field( + dtype=float, + default=1.0, + doc="Maximum separation for source matching in arcseconds", + ) + reference_mag_column = pexConfig.Field( + dtype=str, + default="r_flux", + doc="Column name for reference catalog flux (in nJy)", + ) + reference_ra_column = pexConfig.Field( + dtype=str, + default="coord_ra", + doc="Column name for reference catalog RA in degrees", + ) + reference_dec_column = pexConfig.Field( + dtype=str, + default="coord_dec", + doc="Column name for reference catalog Dec in degrees", + ) + mask_plane_name = pexConfig.Field( + dtype=str, + default="REFERENCE", + doc="Name of mask plane to create for reference source regions", + ) + reference_buffer = pexConfig.Field( + dtype=int, + default=100, + doc="Buffer in pixels to add around image bounds when loading reference catalog", + ) + astrometry_ref_loader = pexConfig.ConfigField( + dtype=lsst.meas.algorithms.LoadReferenceObjectsConfig, + doc="Configuration of reference object loader for source matching", + ) + + def setDefaults(self): + super().setDefaults() + # Set to Gaii, but we don't use photometry currently + self.astrometry_ref_loader.filterMap = { + "u": "phot_g_mean", + "g": "phot_g_mean", + "r": "phot_g_mean", + "i": "phot_g_mean", + "z": "phot_g_mean", + "y": "phot_g_mean", + } + + +class MaskReferenceSourcesTask(pipeBase.PipelineTask): + """Mask regions around reference sources in difference images.""" + + ConfigClass = MaskReferenceSourcesConfig + _DefaultName = "maskReferenceSources" + + def runQuantum(self, butlerQC, inputRefs, outputRefs): + """Run the task on quantum data.""" + inputs = butlerQC.get(inputRefs) + + # Create reference object loader following the standard pattern + astrometry_loader = lsst.meas.algorithms.ReferenceObjectLoader( + dataIds=[ref.datasetRef.dataId for ref in inputRefs.astrometry_ref_cat], + refCats=inputs.pop("astrometry_ref_cat"), + name=self.config.connections.astrometry_ref_cat, + config=self.config.astrometry_ref_loader, + log=self.log, + ) + + # Load reference catalog using loadPixelBox with buffer for edge sources + original_bbox = inputs["matching_image"].getBBox() + buffered_bbox = original_bbox.dilatedBy(self.config.reference_buffer) + + self.log.info( + f"Loading reference catalog for bbox: {original_bbox} (buffered to {buffered_bbox})" + ) + + ref_result = astrometry_loader.loadPixelBox( + bbox=buffered_bbox, + wcs=inputs["matching_image"].getWcs(), + filterName=inputs["matching_image"].getFilter().bandLabel, + ) + + self.log.info(f"Loaded {len(ref_result.refCat)} reference sources") + + outputs = self.run( + ref_catalog=ref_result.refCat, + difference_image=inputs["difference_image"], + matching_sources=inputs["matching_sources"], + matching_image=inputs["matching_image"], + ) + + butlerQC.put(outputs, outputRefs) + + def run(self, ref_catalog, difference_image, matching_sources, matching_image): + """Mask reference sources in difference image. + + Parameters + ---------- + ref_catalog : `lsst.afw.table.SourceCatalog` + Reference catalog with sources to mask + difference_image : `lsst.afw.image.ExposureF` + Difference image to apply masks to + matching_sources : `lsst.afw.table.SourceCatalog` + Sources to match against reference catalog + matching_image : `lsst.afw.image.ExposureF` + Image where matching sources were detected + + Returns + ------- + result : `lsst.pipe.base.Struct` + Results struct with: + - ``masked_difference_image`` : Modified difference image with reference mask + """ + self.log.info(f"Masking reference sources in difference image") + self.log.info(f"Reference catalog has {len(ref_catalog)} sources") + self.log.info(f"Matching against {len(matching_sources)} detected sources") + + # Create copy of difference image to modify + masked_diff = difference_image.clone() + + # Add mask plane if it doesn't exist + mask = masked_diff.mask + try: + mask_bit = mask.addMaskPlane(self.config.mask_plane_name) + self.log.info( + f"Added mask plane '{self.config.mask_plane_name}' with bit {mask_bit}" + ) + except Exception: + # Mask plane already exists + mask_bit = mask.getPlaneBitMask(self.config.mask_plane_name) + self.log.info( + f"Using existing mask plane '{self.config.mask_plane_name}' with bit {mask_bit}" + ) + + # Match reference sources to the matching sources + matches = self._matchSources( + ref_catalog, matching_sources, matching_image.getWcs() + ) + self.log.info(f"Found {len(matches)} matched sources") + + if len(matches) == 0: + self.log.warn(f"No reference sources matched to detected sources") + return pipeBase.Struct(masked_difference_image=masked_diff) + + # Apply masks for matched sources + n_masked = self._applyReferenceMasks(matches, masked_diff) + self.log.info(f"Masked {n_masked} reference source regions") + + return pipeBase.Struct(masked_difference_image=masked_diff) + + def _matchSources(self, ref_catalog, matching_sources, wcs): + """Match reference catalog sources to detected sources using Astropy SkyCoord. + + Parameters + ---------- + ref_catalog : `lsst.afw.table.SourceCatalog` + Reference catalog sources + matching_sources : `lsst.afw.table.SourceCatalog` + Detected sources to match against + wcs : `lsst.afw.geom.SkyWcs` + WCS for coordinate conversion + + Returns + ------- + matches : `list` of `tuple` + List of (ref_source, matching_source) pairs that match + """ + if len(ref_catalog) == 0 or len(matching_sources) == 0: + return [] + + # Extract reference source coordinates + ref_ras = [] + ref_decs = [] + ref_sources = [] + + for ref_src in ref_catalog: + if ( + self.config.reference_ra_column in ref_src.schema + and self.config.reference_dec_column in ref_src.schema + ): + # Convert to degrees - these might be Angle objects + ra = ref_src[self.config.reference_ra_column].asDegrees() + dec = ref_src[self.config.reference_dec_column].asDegrees() + else: + # Fallback to coord if specific columns not available + coord = ref_src.getCoord() + ra = coord.getRa().asDegrees() + dec = coord.getDec().asDegrees() + + ref_ras.append(ra) + ref_decs.append(dec) + ref_sources.append(ref_src) + + # Extract detected source coordinates + match_ras = [] + match_decs = [] + match_sources = [] + + for src in matching_sources: + coord = wcs.pixelToSky(Point2D(src.getX(), src.getY())) + match_ras.append(coord.getRa().asDegrees()) + match_decs.append(coord.getDec().asDegrees()) + match_sources.append(src) + + # Create SkyCoord objects + ref_coords = SkyCoord(ref_ras * u.deg, ref_decs * u.deg) + match_coords = SkyCoord(match_ras * u.deg, match_decs * u.deg) + + # Perform matching with maximum separation + max_sep = self.config.matching_radius * u.arcsec + idx, d2d, d3d = match_coords.match_to_catalog_sky(ref_coords) + + # Build matches list for sources within matching radius + matches = [] + n_too_far = 0 + for i, (match_idx, sep) in enumerate(zip(idx, d2d)): + if sep <= max_sep: + matches.append((ref_sources[match_idx], match_sources[i])) + else: + n_too_far += 1 + + self.log.debug( + f"Matching: {len(matches)} matched, {n_too_far} beyond {self.config.matching_radius} arcsec" + ) + return matches + + def _applyReferenceMasks(self, matches, masked_exposure): + """Apply reference masks to matched source regions. + + Parameters + ---------- + matches : `list` of `tuple` + List of (ref_source, matching_source) matched pairs + masked_exposure : `lsst.afw.image.ExposureF` + Difference image exposure to apply masks to + + Returns + ------- + n_masked : `int` + Number of regions masked + """ + mask = masked_exposure.mask + ref_mask_bit = mask.getPlaneBitMask(self.config.mask_plane_name) + + n_masked = 0 + + for ref_src, matching_src in matches: + # Use the source footprint from the matching source to define the mask region + footprint = matching_src.getFootprint() + if footprint is not None: + self._maskSourceFootprint(footprint, mask, ref_mask_bit) + n_masked += 1 + + return n_masked + + def _maskSourceFootprint(self, footprint, mask, mask_bit): + """Mask pixels in the source footprint. + + Parameters + ---------- + footprint : `lsst.afw.detection.Footprint` + Footprint of the source + mask : `lsst.afw.image.Mask` + Difference image mask to modify + mask_bit : `int` + Mask plane bit to set + """ + # Get the footprint pixels - since matching and difference images have same WCS, + # coordinates can be used directly + bbox = mask.getBBox() + n_clipped = 0 + n_masked = 0 + + for span in footprint.getSpans(): + y = span.getY() + for x in range(span.getMinX(), span.getMaxX() + 1): + if bbox.contains(x, y): + mask.array[y, x] |= mask_bit + n_masked += 1 + else: + n_clipped += 1 + + if n_clipped > 0: + self.log.debug( + f"Footprint clipped: {n_clipped} pixels outside bounds, {n_masked} pixels masked" + ) + + def _maskConnectedRegion( + self, mask, seed_x, seed_y, detected_bit, detected_neg_bit, ref_mask_bit + ): + """Mask a connected region of DETECTED or DETECTED_NEGATIVE pixels. + + Parameters + ---------- + mask : `lsst.afw.image.Mask` + Mask to modify + seed_x, seed_y : `int` + Seed pixel coordinates + detected_bit, detected_neg_bit : `int` + Mask plane bits for DETECTED and DETECTED_NEGATIVE + ref_mask_bit : `int` + Mask plane bit for REFERENCE_MASK + """ + height, width = mask.array.shape + visited = np.zeros((height, width), dtype=bool) + + # Stack for flood fill algorithm + stack = [(seed_x, seed_y)] + target_bits = detected_bit | detected_neg_bit + + while stack: + x, y = stack.pop() + + # Check bounds and if already visited + if x < 0 or x >= width or y < 0 or y >= height or visited[y, x]: + continue + + visited[y, x] = True + mask_value = mask.array[y, x] + + # If this pixel has DETECTED or DETECTED_NEGATIVE, mark it and add neighbors + if mask_value & target_bits: + mask.array[y, x] |= ref_mask_bit + + # Add 8-connected neighbors to stack + for dx in [-1, 0, 1]: + for dy in [-1, 0, 1]: + if dx != 0 or dy != 0: # Skip center pixel + stack.append((x + dx, y + dy)) diff --git a/tests/test_desTemplate.py b/tests/test_desTemplate.py new file mode 100644 index 000000000..8d83cba80 --- /dev/null +++ b/tests/test_desTemplate.py @@ -0,0 +1,314 @@ +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import collections +import itertools +import os +import tempfile +import unittest + +import numpy as np +from astropy.table import Table + +import lsst.afw.geom +import lsst.afw.image +import lsst.afw.math +import lsst.geom +import lsst.ip.diffim +import lsst.meas.algorithms +import lsst.meas.base.tests +import lsst.pipe.base as pipeBase +import lsst.utils.tests + +from utils import generate_data_id + +# Change this to True, `setup display_ds9`, and open ds9 (or use another afw +# display backend) to show the tract/patch layouts on the image. +debug = False +if debug: + import lsst.afw.display + + display = lsst.afw.display.Display() + display.frame = 1 + + +class GetTemplateTaskTestCase(lsst.utils.tests.TestCase): + """Test that GetTemplateTask works on both one tract and multiple tract + input coadd exposures. + + Makes a synthetic exposure large enough to fit four small tracts with 2x2 + (300x300 pixel) patches each, extracts pixels for those patches by warping, + and tests GetTemplateTask's output against boxes that overlap various + combinations of one or multiple tracts. + """ + + def setUp(self): + self.scale = 0.2 # arcsec/pixel + # DES pixel scale is approximately 0.263 arcsec/pixel + self.template_scale = 0.263 + self.exposure = self._makeExposure() + + # Track temporary files for cleanup + self.temp_files = [] + self.temp_csv_file = None + + if debug: + display.image(self.exposure, "base exposure") + + def _makeExposure(self): + """Create a large image to break up into tracts and patches. + + The image will have a source every 100 pixels in x and y, and a WCS + that results in the tracts all fitting in the image, with tract=0 + in the lower left, tract=1 to the right, tract=2 above, and tract=3 + to the upper right. + """ + box = lsst.geom.Box2I( + lsst.geom.Point2I(-200, -200), lsst.geom.Point2I(800, 800) + ) + # This WCS was constructed so that tract 0 mostly fills the lower left + # quadrant of the image, and the other tracts fill the rest; slight + # extra rotation as a check on the final warp layout, scaled by 5% + # from the patch pixel scale. + cd_matrix = lsst.afw.geom.makeCdMatrix( + 1.05 * self.scale * lsst.geom.arcseconds, 93 * lsst.geom.degrees + ) + wcs = lsst.afw.geom.makeSkyWcs( + lsst.geom.Point2D(120, 150), + lsst.geom.SpherePoint(0, 0, lsst.geom.radians), + cd_matrix, + ) + dataset = lsst.meas.base.tests.TestDataset(box, wcs=wcs) + for x, y in itertools.product(np.arange(0, 500, 100), np.arange(0, 500, 100)): + dataset.addSource(1e5, lsst.geom.Point2D(x, y)) + exposure, _ = dataset.realize(2, dataset.makeMinimalSchema()) + exposure.setFilter(lsst.afw.image.FilterLabel("r", "r_03")) + return exposure + + def _checkMetadata(self, template, config, box, wcs, nPsfs): + """Check that the various metadata components were set correctly.""" + expectedBox = lsst.geom.Box2I(box) + self.assertEqual(template.getBBox(), expectedBox) + # WCS should match our exposure, not any of the coadd tracts. + self.assertEqual(template.wcs, self.exposure.wcs) + self.assertEqual(template.getXY0(), expectedBox.getMin()) + self.assertEqual(template.filter.bandLabel, "r") + self.assertEqual(template.filter.physicalLabel, "r_03") + self.assertEqual(template.psf.getComponentCount(), nPsfs) + + def _checkPixels(self, template, config, box): + """Check that the pixel values in the template are close to the + original image. + """ + # All pixels should have real values! + expectedBox = lsst.geom.Box2I(box) + + # Check that we fully filled the template + self.assertTrue(np.all(np.isfinite(template.image.array))) + + def _makeSyntheticTileExposure( + self, ra_center=0.0, dec_center=0.0, size_pixels=4096, band="r" + ): + """Create a synthetic tile exposure similar to a DES coadd tile. + + Parameters + ---------- + ra_center : float + RA center of the tile in degrees + dec_center : float + DEC center of the tile in degrees + size_pixels : int + Size of the tile in pixels (DES tiles are typically 10000x10000, + but we use smaller for testing) + band : str + Photometric band + + Returns + ------- + exposure : lsst.afw.image.ExposureF + Synthetic tile exposure with WCS, PSF, and photometric calibration + """ + # Create bounding box for the tile + bbox = lsst.geom.Box2I( + lsst.geom.Point2I(0, 0), lsst.geom.Extent2I(size_pixels, size_pixels) + ) + + # Create WCS centered on the tile center + cd_matrix = lsst.afw.geom.makeCdMatrix( + self.template_scale * lsst.geom.arcseconds, + 0 * lsst.geom.degrees, # No rotation + ) + wcs = lsst.afw.geom.makeSkyWcs( + lsst.geom.Point2D(size_pixels / 2, size_pixels / 2), + lsst.geom.SpherePoint(ra_center, dec_center, lsst.geom.degrees), + cd_matrix, + ) + + # Create dataset with synthetic sources + dataset = lsst.meas.base.tests.TestDataset(bbox, wcs=wcs) + + # Add a grid of sources across the tile (every 500 pixels) + for x, y in itertools.product( + np.arange(500, size_pixels - 500, 500), + np.arange(500, size_pixels - 500, 500), + ): + dataset.addSource(1e5, lsst.geom.Point2D(x, y)) + + # Realize the exposure with noise + exposure, _ = dataset.realize(5.0, dataset.makeMinimalSchema()) + + # Set filter + filter_label = lsst.afw.image.FilterLabel(band=band, physical=f"{band}_03") + exposure.setFilter(filter_label) + + # Add MAGZERO to metadata (typical for DES) + metadata = exposure.getInfo().getMetadata() + metadata.set("MAGZERO", 30.0) # Typical DES zero point + + return exposure + + def _createSyntheticTileCatalog(self, bbox, wcs, band="r", tile_radius=0.02): + """Create a synthetic tile catalog CSV file for testing. + + The tile catalog will be centered on the box, with tile_radius on each side + + Parameters + ---------- + bbox : lsst.geom.Box2I + Bounding box of the region to overlap + wcs : lsst.afw.geom.SkyWcs + WCS of the region + band : str + Photometric band + tile_radius : float + Size of tile on one side + + Returns + ------- + csv_path : str + Path to the temporary CSV file + """ + # Get the center of the bbox in sky coordinates + center = wcs.pixelToSky(lsst.geom.Point2D(bbox.getCenter())) + ra_center = center.getRa().asDegrees() + dec_center = center.getDec().asDegrees() + + # Calculate tile corners (simplified rectangular approximation) + rac1 = ra_center - tile_radius + rac2 = ra_center + tile_radius + rac3 = ra_center + tile_radius + rac4 = ra_center - tile_radius + decc1 = dec_center - tile_radius + decc2 = dec_center - tile_radius + decc3 = dec_center + tile_radius + decc4 = dec_center + tile_radius + + size_pixels = int(2 * tile_radius * 3600 / self.template_scale) + # Create the synthetic tile exposure and save to FITS + tile_exposure = self._makeSyntheticTileExposure( + ra_center, dec_center, size_pixels=size_pixels, band=band + ) + + # Write to temporary FITS file + temp_fits = tempfile.NamedTemporaryFile( + suffix=".fits", delete=False, prefix="test_tile_" + ) + tile_exposure.writeFits(temp_fits.name) + temp_fits.close() + self.temp_files.append(temp_fits.name) + # Create tile catalog data + tile_data = { + "tilename": [f"TES{int(ra_center):04d}{int(dec_center):+05d}"], + "band": [band], + "survey": ["DES"], + "ra_cent": [ra_center], + "dec_cent": [dec_center], + "rac1": [rac1], + "rac2": [rac2], + "rac3": [rac3], + "rac4": [rac4], + "decc1": [decc1], + "decc2": [decc2], + "decc3": [decc3], + "decc4": [decc4], + "filepath": [temp_fits.name], + } + + # Create astropy table + table = Table(tile_data) + + # Write to temporary CSV file + temp_csv = tempfile.NamedTemporaryFile( + mode="w", suffix=".csv", delete=False, prefix="test_tiles_" + ) + table.write(temp_csv.name, format="ascii.csv", overwrite=True) + temp_csv.close() + self.temp_csv_file = temp_csv.name + + return temp_csv.name + + def tearDown(self): + """Clean up temporary files created during testing.""" + # Remove temporary FITS files + for temp_file in self.temp_files: + if os.path.exists(temp_file): + try: + os.remove(temp_file) + except Exception: + pass # Ignore cleanup errors + + # Remove temporary CSV file + if self.temp_csv_file and os.path.exists(self.temp_csv_file): + try: + os.remove(self.temp_csv_file) + except Exception: + pass # Ignore cleanup errors + + def testRunOneTractInput(self): + """Test a bounding box that fits inside single DES tract""" + + box = lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.Point2I(180, 180)) + + # Create synthetic tile catalog with tiles that overlap the test bbox + synthetic_csv = self._createSyntheticTileCatalog( + box, self.exposure.wcs, band="r" + ) + + config = lsst.ip.diffim.DesTemplateConfig() + config.tileFile = synthetic_csv + task = lsst.ip.diffim.DesTemplateTask(config=config) + result = task.run( + bbox=box, wcs=self.exposure.wcs, band="r", physical_filter="r_03" + ) + + self._checkMetadata(result.template, task.config, box, self.exposure.wcs, 1) + self._checkPixels(result.template, task.config, box) + + +def setup_module(module): + lsst.utils.tests.init() + + +class MemoryTestCase(lsst.utils.tests.MemoryTestCase): + pass + + +if __name__ == "__main__": + lsst.utils.tests.init() + unittest.main() diff --git a/tests/test_maskReferenceSources.py b/tests/test_maskReferenceSources.py new file mode 100644 index 000000000..9e398da2a --- /dev/null +++ b/tests/test_maskReferenceSources.py @@ -0,0 +1,223 @@ +import unittest +import numpy as np + +import lsst.utils.tests as utilsTests +import lsst.afw.image as afwImage +import lsst.afw.table as afwTable +import lsst.afw.geom as afwGeom +import lsst.afw.detection as afwDetect +from lsst.afw.geom import makeSkyWcs +from lsst.geom import Point2D, Angle, degrees, SpherePoint + +from lsst.ip.diffim.maskReferenceSources import ( + MaskReferenceSourcesTask, + MaskReferenceSourcesConfig, +) + +# ----------------- helpers ----------------- + + +def make_exposure( + w=300, h=300, scale_arcsec=0.2, crpix=(150.0, 150.0), crval_deg=(10.0, -5.0) +): + exp = afwImage.ExposureF(w, h) + scale_deg = scale_arcsec / 3600.0 + cd = np.array([[scale_deg, 0.0], [0.0, scale_deg]], dtype=float) + wcs = makeSkyWcs( + crpix=Point2D(*crpix), + crval=SpherePoint(Angle(crval_deg[0], degrees), Angle(crval_deg[1], degrees)), + cdMatrix=cd, + ) + exp.setWcs(wcs) + return exp + + +def circular_footprint(xc, yc, r=4): + spans = [] + for y in range(yc - r, yc + r + 1): + dy = y - yc + dx = int(np.floor(np.sqrt(max(0, r * r - dy * dy)))) + spans.append(afwGeom.Span(y, xc - dx, xc + dx)) + ss = afwGeom.SpanSet(spans) + return afwDetect.Footprint(ss) + + +def make_src_cat(wcs, xy_list, radius=4): + + schema = afwTable.SourceTable.makeMinimalSchema() + + xKey = schema.addField(afwTable.FieldD("centroid_x", "centroid x", "pixel")) + yKey = schema.addField(afwTable.FieldD("centroid_y", "centroid y", "pixel")) + + schema.getAliasMap().set("slot_Centroid", "centroid") + table = afwTable.SourceTable.make(schema) + + cat = afwTable.SourceCatalog(table) + for x, y in xy_list: + rec = cat.addNew() + rec.set(xKey, float(x)) + rec.set(yKey, float(y)) + rec.setCoord(wcs.pixelToSky(Point2D(float(x), float(y)))) + rec.setFootprint(circular_footprint(int(x), int(y), radius)) + return cat + + +def make_ref_cat_from_sources(wcs, src_xy, arcsec_offset, offset_axis="ra"): + """Make a reference catalog by offsetting each source's sky coord.""" + schema = afwTable.SourceTable.makeMinimalSchema() + table = afwTable.SourceTable.make(schema) + cat = afwTable.SourceCatalog(table) + deg_off = arcsec_offset / 3600.0 + for x, y in src_xy: + sp = wcs.pixelToSky(Point2D(float(x), float(y))) + ra = sp.getRa().asDegrees() + dec = sp.getDec().asDegrees() + if offset_axis == "ra": + ra += deg_off + else: + dec += deg_off + rec = cat.addNew() + rec.setCoord(SpherePoint(Angle(ra, degrees), Angle(dec, degrees))) + return cat + + +def mask_bit(exp, plane): + try: + return exp.mask.addMaskPlane(plane) + except Exception: + return exp.mask.getPlaneBitMask(plane) + + +def ensure_plane_bitmask(mask, name): + try: + bit_index = mask.addMaskPlane(name) + return 1 << bit_index + except Exception: + return mask.getPlaneBitMask(name) + + +def paint_footprint(mask, footprint, bit): + """OR the given bit into mask pixels inside the footprint (clip to bounds).""" + bbox = mask.getBBox() + for span in footprint.getSpans(): + y = span.getY() + for x in range(span.getMinX(), span.getMaxX() + 1): + if bbox.contains(x, y): + mask.array[y, x] |= bit + + +# ----------------- tests ----------------- + + +class TestMaskReferenceSourcesMany(utilsTests.TestCase): + """Build 300x300 exposure with 8 DETECTED sources. + Make 5 reference objects close enough to match + Verify REFERENCE is set exactly over DETECTED for matches only. + """ + + def setUp(self): + self.exp = make_exposure(300, 300, scale_arcsec=0.2) + self.cfg = MaskReferenceSourcesConfig() + self.cfg.mask_plane_name = "REFERENCE" + self.cfg.matching_radius = 1.2 # arcsec + self.task = MaskReferenceSourcesTask(config=self.cfg) + + # 8 sources, well-separated (no overlapping footprints) + self.src_xy = [ + (40, 40), + (90, 60), + (140, 80), + (190, 100), + (240, 120), + (60, 200), + (160, 220), + (260, 240), + ] + self.src_cat = make_src_cat(self.exp.getWcs(), self.src_xy, radius=4) + + self.detected_bit = ensure_plane_bitmask(self.exp.mask, "DETECTED") + for rec in self.src_cat: + paint_footprint(self.exp.mask, rec.getFootprint(), self.detected_bit) + + # Build reference catalog: + # - First 5: +0.8" in RA (=> should match within 1.2") + # - Last 3: not in the reference catalog + refs = make_ref_cat_from_sources( + self.exp.getWcs(), self.src_xy[:5], arcsec_offset=0.8, offset_axis="ra" + ) + + # Concatenate (order shouldn't matter) + self.ref_cat = afwTable.SourceCatalog(refs.getTable()) + for rec in refs: + self.ref_cat.append(rec) + + self.reference_bit = ensure_plane_bitmask( + self.exp.mask, self.cfg.mask_plane_name + ) + + def test_masking(self): + self.assertEqual(int((self.exp.mask.array & self.reference_bit).sum()), 0) + + # Run the algorithm + result = self.task.run( + ref_catalog=self.ref_cat, + difference_image=self.exp, + matching_sources=self.src_cat, + matching_image=self.exp, + ) + out = result.masked_difference_image + mask_arr = out.mask.array + + # 1) REFERENCE must be a subset of DETECTED everywhere + ref_only = (mask_arr & self.reference_bit) & ~( + (mask_arr & self.detected_bit) > 0 + ) + self.assertEqual( + int(ref_only.sum()), 0, "REFERENCE set outside DETECTED footprint(s)" + ) + + # 2) Exactly 5 sources matched; REFERENCE pixels == sum of their footprint areas + matches = self.task._matchSources(self.ref_cat, self.src_cat, self.exp.getWcs()) + self.assertEqual(len(matches), 5, "Expected 5 matches within 1.2 arcsec") + matched_src_ids = {id(src) for _, src in matches} + + expected_pixels = 0 + for src in self.src_cat: + if id(src) in matched_src_ids: + expected_pixels += src.getFootprint().getArea() + + actual_pixels = ((mask_arr & self.reference_bit) == self.reference_bit).sum() + self.assertEqual( + actual_pixels, + expected_pixels, + "REFERENCE mask pixels should equal sum of matched footprint areas", + ) + + # 3) Unmatched sources: DETECTED-only, no REFERENCE bits inside their footprints + for src in self.src_cat: + if id(src) not in matched_src_ids: + # Collect REFERENCE pixels inside this footprint + in_fp = 0 + bbox = self.exp.mask.getBBox() + for span in src.getFootprint().getSpans(): + y = span.getY() + for x in range(span.getMinX(), span.getMaxX() + 1): + if bbox.contains(x, y) and ( + mask_arr[y, x] & self.reference_bit + ): + in_fp += 1 + self.assertEqual( + in_fp, 0, "Unmatched source has REFERENCE bits in its footprint" + ) + + +# --------------- boilerplate --------------- + + +def setup_module(module): + utilsTests.init() + + +if __name__ == "__main__": + utilsTests.init() + unittest.main()