diff --git a/python/lsst/pipe/tasks/desTemplateTask.py b/python/lsst/pipe/tasks/desTemplateTask.py
new file mode 100644
index 000000000..d1fb80cdd
--- /dev/null
+++ b/python/lsst/pipe/tasks/desTemplateTask.py
@@ -0,0 +1,711 @@
+# This file is part of pipe_tasks.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+"""Pipeline for generating templates from archival DES data.
+"""
+
+__all__ = ["DesTemplateTask",
+ "DesTemplateConfig",
+ "DesTemplateConnections"]
+
+import numpy as np
+
+from astropy.table import Table
+from astropy.io import fits
+
+import lsst.pex.config as pexConfig
+import lsst.pipe.base as pipeBase
+import lsst.pipe.base.connectionTypes as connectionTypes
+import lsst.afw.image as afwImage
+import lsst.afw.table as afwTable
+import lsst.afw.math as afwMath
+import lsst.afw.geom as afwGeom
+from lsst.geom import Box2D, Point2D
+from lsst.sphgeom import ConvexPolygon, UnitVector3d, LonLat
+from lsst.afw.image import makePhotoCalibFromCalibZeroPoint
+from lsst.meas.algorithms.installGaussianPsf import (
+ InstallGaussianPsfTask,
+ InstallGaussianPsfConfig,
+)
+from lsst.meas.algorithms import CoaddPsf, CoaddPsfConfig
+
+from lsst.meas.base import IdGenerator
+from lsst.geom import SpherePoint, degrees
+import lsst.log
+
+
+class DesTemplateConnections(
+ pipeBase.PipelineTaskConnections,
+ dimensions=("instrument", "visit", "detector"),
+):
+ """Connections for DES template creation task."""
+
+ bbox = connectionTypes.Input(
+ doc="Bounding box of the exposure",
+ name="pvi.bbox",
+ storageClass="Box2I",
+ dimensions=("instrument", "visit", "detector"),
+ )
+ wcs = connectionTypes.Input(
+ doc="WCS of the exposure",
+ name="pvi.wcs",
+ storageClass="Wcs",
+ dimensions=("instrument", "visit", "detector"),
+ )
+ template = connectionTypes.Output(
+ doc="Template exposure created from DES tiles",
+ name="desTemplate",
+ storageClass="ExposureF",
+ dimensions=("instrument", "visit", "detector"),
+ )
+
+
+class DesTemplateConfig(
+ pipeBase.PipelineTaskConfig, pipelineConnections=DesTemplateConnections
+):
+ """Configuration for DES template creation task."""
+
+ tileFile = pexConfig.Field(
+ dtype=str,
+ default="merged_tiles_v3_deduplicated.csv",
+ doc="Path to merged CSV file containing tile information with columns: "
+ "tilename, band, survey, ra_cent, dec_cent, rac1-4, decc1-4, filepath",
+ )
+ dataSourcePriority = pexConfig.ListField(
+ dtype=str,
+ default=["DELVE", "DES", "DECADE"],
+ doc="Priority order for surveys when duplicate tiles exist (first = highest priority)",
+ )
+ modelPsf = pexConfig.Field(
+ dtype=bool,
+ default=False,
+ doc="Use modeled PSF instead of Gaussian PSF for templates",
+ )
+ gaussianPsfFwhm = pexConfig.Field(
+ dtype=float,
+ default=4.0,
+ doc="FWHM for Gaussian PSF in pixels (used when modelPsf=False)",
+ )
+ gaussianPsfWidth = pexConfig.Field(
+ dtype=int,
+ default=21,
+ doc="Width for Gaussian PSF in pixels (used when modelPsf=False)",
+ )
+ hasInverseVariancePlane = pexConfig.Field(
+ dtype=bool,
+ default=True,
+ doc="Do the template files have a weight map or inverse variance plane",
+ )
+
+
+class DesTemplateTask(pipeBase.PipelineTask):
+ """Create template exposures from survey tile data."""
+
+ ConfigClass = DesTemplateConfig
+ _DefaultName = "desTemplate"
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ # Suppress FITS reader warnings for survey tiles
+ lsst.log.setLevel("lsst.afw.image.MaskedImageFitsReader", lsst.log.ERROR)
+
+ if self.config.modelPsf:
+ from lsst.pipe.tasks.calibrateImage import (
+ CalibrateImageTask,
+ CalibrateImageConfig,
+ )
+
+ cal_config = CalibrateImageConfig()
+ cal_config.psf_repair.doInterpolate = False
+ cal_config.psf_repair.doCosmicRay = False
+ cal_config.psf_detection.thresholdValue = 100.0
+ self.calibrateTask = CalibrateImageTask(config=cal_config)
+ else:
+ install_psf_config = InstallGaussianPsfConfig()
+ install_psf_config.fwhm = self.config.gaussianPsfFwhm
+ install_psf_config.width = self.config.gaussianPsfWidth
+ self.installPsfTask = InstallGaussianPsfTask(config=install_psf_config)
+
+ def runQuantum(self, butlerQC, inputRefs, outputRefs):
+ # Extract band and physical_filter from dataId
+ band = butlerQC.quantum.dataId["band"]
+ physical_filter = butlerQC.quantum.dataId["physical_filter"]
+
+ inputs = butlerQC.get(inputRefs)
+ outputs = self.run(band=band, physical_filter=physical_filter, **inputs)
+ butlerQC.put(outputs, outputRefs)
+
+ def run(self, bbox, wcs, band, physical_filter):
+ """Create template from survey tiles overlapping the target region.
+
+ Parameters
+ ----------
+ bbox : `lsst.geom.Box2I`
+ Bounding box of the target region
+ wcs : `lsst.afw.geom.SkyWcs`
+ WCS of the target region
+ band : `str`
+ Photometric band (e.g., 'r', 'g', 'i')
+ physical_filter : `str`
+ Physical filter name (e.g., 'r_03', 'g_01')
+
+ Returns
+ -------
+ result : `lsst.pipe.base.Struct`
+ Results struct with:
+ - ``template`` : Template exposure created from survey tiles
+ """
+ self.log.info("Creating survey template for target region")
+ self.log.info(f"Using band '{band}' and physical_filter '{physical_filter}'")
+
+ # Load tile catalog and filter by band
+ tile_table = Table.read(self.config.tileFile, format="ascii")
+ self.log.info(f"Loaded {len(tile_table)} total tiles from catalog")
+
+ # Validate required columns exist
+ required_columns = [
+ "tilename",
+ "band",
+ "survey",
+ "ra_cent",
+ "dec_cent",
+ "rac1",
+ "rac2",
+ "rac3",
+ "rac4",
+ "decc1",
+ "decc2",
+ "decc3",
+ "decc4",
+ "filepath",
+ ]
+ missing_columns = [
+ col for col in required_columns if col not in tile_table.colnames
+ ]
+ if missing_columns:
+ raise RuntimeError(
+ f"Required columns missing from tile table: {missing_columns}. "
+ "Please use the merged tiles CSV format."
+ )
+
+ # Filter by band first
+ band_filtered_tiles = self._filterTilesByBand(tile_table, band)
+ self.log.info(f"Found {len(band_filtered_tiles)} tiles for band '{band}'")
+
+ if len(band_filtered_tiles) == 0:
+ raise RuntimeError(f"No tiles found for band '{band}' in catalog")
+
+ # Find overlapping tiles
+ overlapping_tiles = self._findOverlappingTiles(bbox, wcs, band_filtered_tiles)
+ self.log.info(f"Found {len(overlapping_tiles)} overlapping tiles")
+
+ if not overlapping_tiles:
+ center = wcs.pixelToSky(Point2D(bbox.getCenter()))
+ target_ra = center.getRa().asDegrees()
+ target_dec = center.getDec().asDegrees()
+ self.log.info(
+ f"Target was centered at RA={target_ra:.3f}, Dec={target_dec:.3f}"
+ )
+ raise RuntimeError("No survey tiles found overlapping with target region")
+
+ # Resolve duplicate tiles and load template exposures
+ resolved_tiles = self._resolveDuplicateTiles(
+ overlapping_tiles, band_filtered_tiles
+ )
+ tile_exposures, tile_names = self._loadTileTemplates(resolved_tiles)
+ self.log.info(f"Successfully loaded {len(tile_exposures)} template exposures")
+
+ if not tile_exposures:
+ raise RuntimeError("No valid template exposures found")
+
+ # Create coadded template
+ template = self._createCoaddFromTiles(
+ tile_exposures, tile_names, wcs, bbox, physical_filter
+ )
+
+ self.log.info("Successfully created survey template")
+ return pipeBase.Struct(template=template)
+
+ def _tileOverlapsRegion(self, tile_corners, bbox, wcs):
+ """Check if a tile overlaps with the target region using spherical geometry.
+
+ Parameters
+ ----------
+ tile_corners : `list` of `tuple`
+ List of (ra, dec) tuples for tile corners in degrees
+ bbox : `lsst.geom.Box2I`
+ Bounding box of target region
+ wcs : `lsst.afw.geom.SkyWcs`
+ WCS of target region
+
+ Returns
+ -------
+ overlaps : `bool`
+ True if tile overlaps with target region, False otherwise
+ """
+ # Get target region corners
+ region_corners = [wcs.pixelToSky(Point2D(p)) for p in bbox.getCorners()]
+
+ # Convert tile corners to unit vectors
+ tile_unit_vectors = [
+ UnitVector3d(LonLat.fromDegrees(ra, dec)) for ra, dec in tile_corners
+ ]
+
+ # Convert region corners to unit vectors
+ region_unit_vectors = [
+ UnitVector3d(
+ LonLat.fromDegrees(s.getRa().asDegrees(), s.getDec().asDegrees())
+ )
+ for s in region_corners
+ ]
+
+ # Create convex polygons
+ tile_polygon = ConvexPolygon(tile_unit_vectors)
+ region_polygon = ConvexPolygon(region_unit_vectors)
+
+ # Check for overlap
+ return tile_polygon.overlaps(region_polygon)
+
+ def _findOverlappingTiles(self, bbox, wcs, tile_table):
+ """Find tiles that overlap with the target region.
+
+ Parameters
+ ----------
+ bbox : `lsst.geom.Box2I`
+ Bounding box of target region
+ wcs : `lsst.afw.geom.SkyWcs`
+ WCS of target region
+ tile_table : `astropy.table.Table`
+ Table containing tile information
+
+ Returns
+ -------
+ overlapping_tiles : `list` of `tuple`
+ List of (tile_name, distance_deg) for overlapping tiles, sorted by distance
+ """
+ # Get target region center for distance calculation
+ center = wcs.pixelToSky(Point2D(bbox.getCenter()))
+ target_ra = center.getRa().asDegrees()
+ target_dec = center.getDec().asDegrees()
+ p0 = SpherePoint(target_ra * degrees, target_dec * degrees)
+ overlapping_tiles = []
+ total_tiles_checked = 0
+
+ self.log.info(f"Checking {len(tile_table)} tiles for overlap with exposure")
+
+ for row in tile_table:
+ tile_name = row["tilename"]
+ total_tiles_checked += 1
+
+ # Extract tile corners from table
+ required_corner_columns = [
+ "rac1",
+ "decc1",
+ "rac2",
+ "decc2",
+ "rac3",
+ "decc3",
+ "rac4",
+ "decc4",
+ ]
+
+ if not all(col in tile_table.colnames for col in required_corner_columns):
+ missing_corners = [
+ col
+ for col in required_corner_columns
+ if col not in tile_table.colnames
+ ]
+ raise RuntimeError(
+ f"Required corner coordinate columns missing from tile table: {missing_corners}"
+ )
+
+ # Use actual corner coordinates
+ tile_corners = [
+ (row["rac1"], row["decc1"]),
+ (row["rac2"], row["decc2"]),
+ (row["rac3"], row["decc3"]),
+ (row["rac4"], row["decc4"]),
+ ]
+
+ # Check if tile overlaps with target region
+ has_overlap = self._tileOverlapsRegion(tile_corners, bbox, wcs)
+
+ if has_overlap:
+ # Calculate distance to tile center for sorting
+ tile_ra = row["ra_cent"]
+ tile_dec = row["dec_cent"]
+
+ p1 = SpherePoint(tile_ra * degrees, tile_dec * degrees)
+ distance = p0.separation(p1).asDegrees()
+
+ self.log.info(f" OVERLAP: {tile_name} - distance={distance:.3f}°")
+
+ overlapping_tiles.append((tile_name, distance))
+ else:
+ # Log nearby tiles that don't overlap for debugging
+ tile_ra = row["ra_cent"]
+ tile_dec = row["dec_cent"]
+ p1 = SpherePoint(tile_ra * degrees, tile_dec * degrees)
+ distance = p0.separation(p1).asDegrees()
+
+ if distance < 2.0: # Only log tiles within 2 degrees
+ self.log.info(
+ f" No overlap: {tile_name} - distance={distance:.3f}° (RA={tile_ra:.3f}, Dec={tile_dec:.3f})"
+ )
+
+ self.log.info(
+ f"Checked {total_tiles_checked} tiles, found {len(overlapping_tiles)} with overlap"
+ )
+
+ # Sort by distance
+ overlapping_tiles.sort(key=lambda x: x[1])
+
+ # Return the overlapping tiles
+ return overlapping_tiles
+
+ def _filterTilesByBand(self, tile_table, band):
+ """Filter tiles by the requested band and validate survey values.
+
+ Parameters
+ ----------
+ tile_table : `astropy.table.Table`
+ Table containing all tiles
+ band : `str`
+ Photometric band to filter for
+
+ Returns
+ -------
+ filtered_table : `astropy.table.Table`
+ Table containing only tiles for the specified band with valid surveys
+ """
+ if "band" not in tile_table.colnames:
+ raise RuntimeError(
+ "Required 'band' column not found in tile table. "
+ "Please use the merged tiles CSV format."
+ )
+
+ if "survey" not in tile_table.colnames:
+ raise RuntimeError(
+ "Required 'survey' column not found in tile table. "
+ "Please use the merged tiles CSV format."
+ )
+
+ # First filter by band
+ band_mask = tile_table["band"] == band
+ band_filtered = tile_table[band_mask]
+
+ # Then validate survey values
+ valid_surveys = set(self.config.dataSourcePriority)
+ survey_column = band_filtered["survey"]
+
+ # Vectorized membership check (case-sensitive)
+ valid_mask = np.isin(survey_column, list(valid_surveys))
+
+ invalid_count = len(band_filtered) - np.sum(valid_mask)
+ if invalid_count > 0:
+ self.log.warn(
+ f"Found {invalid_count} tiles with unrecognized survey values, ignoring them"
+ )
+
+ filtered_table = band_filtered[valid_mask]
+
+ self.log.info(
+ f"Filtered {len(tile_table)} tiles to {len(filtered_table)} for band '{band}' with valid surveys"
+ )
+ return filtered_table
+
+ def _resolveDuplicateTiles(self, overlapping_tiles, band_filtered_tiles):
+ """Resolve duplicate tiles based on survey priority.
+
+ Parameters
+ ----------
+ overlapping_tiles : `list` of `tuple`
+ List of (tile_name, distance) tuples for overlapping tiles
+ band_filtered_tiles : `astropy.table.Table`
+ Band-filtered tile table with filepath and survey columns
+
+ Returns
+ -------
+ resolved_tiles : `list` of `dict`
+ List of resolved tile dictionaries with tilename, distance, filepath, survey
+ """
+ # Get priority order from config
+ priority_order = self.config.dataSourcePriority
+
+ resolved_tiles = []
+
+ for tile_name, distance in overlapping_tiles:
+ # Find all entries for this tilename in the band-filtered table
+ tile_mask = band_filtered_tiles["tilename"] == tile_name
+ tile_entries = band_filtered_tiles[tile_mask]
+
+ if len(tile_entries) == 0:
+ self.log.warn(f"No entries found for overlapping tile {tile_name}")
+ continue
+ elif len(tile_entries) == 1:
+ # No duplicates, use the single entry
+ entry = tile_entries[0]
+ resolved_tiles.append(
+ {
+ "tilename": tile_name,
+ "distance": distance,
+ "filepath": entry["filepath"],
+ "survey": entry["survey"],
+ }
+ )
+ else:
+ # Multiple entries, resolve based on priority order
+ selected_entry = None
+
+ # Check for each survey in priority order
+ for preferred_survey in priority_order:
+ survey_entries = tile_entries[
+ tile_entries["survey"] == preferred_survey
+ ]
+ if len(survey_entries) > 0:
+ selected_entry = survey_entries[0]
+ self.log.info(
+ f"Using {preferred_survey} tile for {tile_name} (priority)"
+ )
+ break
+
+ if selected_entry is None:
+ # No recognized survey, use first available
+ selected_entry = tile_entries[0]
+ self.log.info(
+ f"Using {selected_entry['survey']} tile for {tile_name} (fallback)"
+ )
+
+ resolved_tiles.append(
+ {
+ "tilename": tile_name,
+ "distance": distance,
+ "filepath": selected_entry["filepath"],
+ "survey": selected_entry["survey"],
+ }
+ )
+
+ self.log.info(
+ f"Resolved {len(resolved_tiles)} tiles from {len(overlapping_tiles)} overlapping tiles"
+ )
+ return resolved_tiles
+
+ def _loadTileTemplates(self, resolved_tiles):
+ """Load template exposures from resolved tiles.
+
+ Parameters
+ ----------
+ resolved_tiles : `list` of `dict`
+ List of resolved tile dictionaries with tilename, distance, filepath, survey
+
+ Returns
+ -------
+ exposures : `list` of `lsst.afw.image.ExposureF`
+ List of loaded template exposures
+ tile_names : `list` of `str`
+ List of tile names corresponding to exposures
+ """
+ exposures = []
+ tile_names = []
+
+ for tile_info in resolved_tiles:
+ tile_name = tile_info["tilename"]
+ filepath = tile_info["filepath"]
+ survey = tile_info["survey"]
+
+ try:
+ self.log.info(f"Loading {survey} tile {tile_name} from {filepath}")
+ exp = afwImage.ExposureF(filepath)
+
+ if self.config.hasInverseVariancePlane:
+ # Need convert DES weight map to variance map for LSST exposure
+ mi = exp.maskedImage
+ bad_weight = mi.variance.array <= 0.0
+ mi.variance.array[~bad_weight] = (
+ 1.0 / mi.variance.array[~bad_weight]
+ )
+ mi.variance.array[bad_weight] = np.nan
+ mi.mask.array[bad_weight] |= mi.mask.getPlaneBitMask("NO_DATA")
+
+ # Set PSF
+ if not self.config.modelPsf:
+ # Install Gaussian PSF
+ self.installPsfTask.run(exp)
+ else:
+ # Use modeled PSF
+ cat = self.calibrateTask._compute_psf(exp, IdGenerator())
+
+ # Set photometric calibration
+ metadata = exp.getInfo().getMetadata()
+ if "MAGZERO" in metadata:
+ zp = metadata["MAGZERO"]
+ flux0 = 10 ** (0.4 * zp)
+ calib = makePhotoCalibFromCalibZeroPoint(flux0, 0.0)
+ exp.setPhotoCalib(calib)
+ else:
+ self.log.warn(
+ f"No MAGZERO found in {tile_name}, using default calibration"
+ )
+
+ exposures.append(exp)
+ tile_names.append(tile_name)
+
+ except Exception as e:
+ self.log.warn(f"Failed to load tile {tile_name}: {e}")
+ continue
+
+ return exposures, tile_names
+
+ def _createCoaddFromTiles(
+ self, tile_exposures, tile_names, target_wcs, target_bbox, physical_filter
+ ):
+ """Create coadded template from tile exposures.
+
+ Parameters
+ ----------
+ tile_exposures : `list` of `lsst.afw.image.ExposureF`
+ List of tile exposures
+ tile_names : `list` of `str`
+ List of tile names
+ target_wcs : `lsst.afw.geom.SkyWcs`
+ Target WCS for template
+ target_bbox : `lsst.geom.Box2I`
+ Target bounding box for template
+ physical_filter : `str`
+ Physical filter name for the template
+
+ Returns
+ -------
+ template : `lsst.afw.image.ExposureF`
+ Coadded template exposure
+ """
+ warper = afwMath.Warper.fromConfig(afwMath.Warper.ConfigClass())
+
+ # Create tile catalog for PSF computation
+ tile_schema = afwTable.ExposureTable.makeMinimalSchema()
+ tile_key = tile_schema.addField("tile", type="String", size=12)
+ weight_key = tile_schema.addField("weight", type=float)
+
+ coadd_config = CoaddPsfConfig()
+
+ # Statistics configuration
+ stats_flags = afwMath.stringToStatisticsProperty("MEAN")
+ stats_ctrl = afwMath.StatisticsControl()
+ stats_ctrl.setNanSafe(True)
+ stats_ctrl.setWeighted(True)
+ stats_ctrl.setCalcErrorMosaicMode(True)
+
+ tile_catalog = afwTable.ExposureCatalog(tile_schema)
+ masked_images = []
+ weights = []
+
+ for i, (exp, tile) in enumerate(zip(tile_exposures, tile_names)):
+ self.log.debug(f"Processing tile {tile} ({i+1}/{len(tile_exposures)})")
+
+ # Warp template to target coordinate system
+ warped = warper.warpExposure(target_wcs, exp, maxBBox=target_bbox)
+ if warped.getBBox().getArea() == 0:
+ self.log.debug(f"Skipping tile {tile}: no overlap after warping")
+ continue
+
+ # Create properly initialized exposure
+ aligned_exp = afwImage.ExposureF(target_bbox, target_wcs)
+ aligned_exp.maskedImage.set(
+ np.nan, afwImage.Mask.getPlaneBitMask("NO_DATA"), np.nan
+ )
+ aligned_exp.maskedImage.assign(warped.maskedImage, warped.getBBox())
+ masked_images.append(aligned_exp.maskedImage)
+
+ # Calculate weight (inverse of mean variance)
+ var_array = aligned_exp.variance.array
+ finite_var = var_array[np.isfinite(var_array)]
+ if len(finite_var) > 0:
+ mean_var = np.mean(finite_var)
+ weight = 1.0 / mean_var if mean_var > 0 else 1.0
+ else:
+ weight = 1.0
+ weights.append(weight)
+
+ # Add to tile catalog for PSF computation
+ record = tile_catalog.addNew()
+ record.set(tile_key, tile)
+ record.set(weight_key, weight)
+ record.setPsf(exp.getPsf())
+ record.setWcs(exp.getWcs())
+ record.setPhotoCalib(exp.getPhotoCalib())
+ record.setBBox(exp.getBBox())
+
+ polygon = afwGeom.Polygon(Box2D(exp.getBBox()).getCorners())
+ record.setValidPolygon(polygon)
+
+ if not masked_images:
+ raise RuntimeError("No valid warped images for coadd creation")
+
+ # Create coadd exposure
+ coadd = afwImage.ExposureF(target_bbox, target_wcs)
+ coadd.maskedImage.set(np.nan, afwImage.Mask.getPlaneBitMask("NO_DATA"), np.nan)
+ xy0 = coadd.getXY0()
+
+ # Perform statistical stacking
+ coadd.maskedImage = afwMath.statisticsStack(
+ masked_images, stats_flags, stats_ctrl, weights, clipped=0, maskMap=[]
+ )
+ coadd.maskedImage.setXY0(xy0)
+
+ # Create and set PSF
+ if len(tile_catalog) > 0:
+ valid_mask = (
+ coadd.maskedImage.mask.array
+ & coadd.maskedImage.mask.getPlaneBitMask("NO_DATA")
+ ) == 0
+ if np.any(valid_mask):
+ mask_for_centroid = afwImage.makeMaskFromArray(
+ valid_mask.astype(afwImage.MaskPixel)
+ )
+ psf_center = afwGeom.SpanSet.fromMask(
+ mask_for_centroid, 1
+ ).computeCentroid()
+
+ ctrl = coadd_config.makeControl()
+ coadd_psf = CoaddPsf(
+ tile_catalog,
+ target_wcs,
+ psf_center,
+ ctrl.warpingKernelName,
+ ctrl.cacheSize,
+ )
+ coadd.setPsf(coadd_psf)
+
+ # Set calibration and filter from first tile exposure and physical_filter
+ if tile_exposures:
+ coadd.setPhotoCalib(tile_exposures[0].getPhotoCalib())
+ # Create filter label from physical_filter
+ # Extract band from physical_filter (e.g., 'r_03' -> 'r')
+ band = (
+ physical_filter.split("_")[0]
+ if "_" in physical_filter
+ else physical_filter
+ )
+ filter_label = afwImage.FilterLabel(band=band, physical=physical_filter)
+ coadd.setFilter(filter_label)
+
+ return coadd
diff --git a/python/lsst/pipe/tasks/maskReferenceSources.py b/python/lsst/pipe/tasks/maskReferenceSources.py
new file mode 100644
index 000000000..a6b458c03
--- /dev/null
+++ b/python/lsst/pipe/tasks/maskReferenceSources.py
@@ -0,0 +1,424 @@
+# This file is part of pipe_tasks.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+"""Pipeline for masking DIA sources based on a reference catalog
+"""
+
+__all__ = ["MaskReferenceSourcesTask",
+ "MaskReferenceSourcesConfig",
+ "MaskReferenceSourcesConnections"]
+
+import numpy as np
+import logging
+
+from astropy.coordinates import SkyCoord
+import astropy.units as u
+
+import lsst.pex.config as pexConfig
+import lsst.pipe.base as pipeBase
+import lsst.pipe.base.connectionTypes as connectionTypes
+import lsst.afw.image as afwImage
+import lsst.afw.table as afwTable
+import lsst.meas.algorithms
+
+
+class MaskReferenceSourcesConnections(
+ pipeBase.PipelineTaskConnections,
+ dimensions=("instrument", "visit", "detector"),
+):
+ """Connections for MaskReferenceSources task."""
+
+ # Input reference catalog
+ astrometry_ref_cat = connectionTypes.PrerequisiteInput(
+ doc="Reference catalog to use for source matching",
+ name="gaia_dr3_20230707",
+ storageClass="SimpleCatalog",
+ dimensions=("skypix",),
+ deferLoad=True,
+ multiple=True,
+ )
+
+ # Input image for matching (default: science image)
+ matching_image = connectionTypes.Input(
+ doc="Image to use for reference source matching",
+ name="calexp",
+ storageClass="ExposureF",
+ dimensions=("instrument", "visit", "detector"),
+ )
+
+ # Input sources for matching (default: science sources)
+ matching_sources = connectionTypes.Input(
+ doc="Sources detected in matching image",
+ name="src",
+ storageClass="SourceCatalog",
+ dimensions=("instrument", "visit", "detector"),
+ )
+
+ # Input difference image to mask
+ difference_image = connectionTypes.Input(
+ doc="Difference image to apply masks to",
+ name="difference_image",
+ storageClass="ExposureF",
+ dimensions=("instrument", "visit", "detector"),
+ )
+
+ # Output masked difference image
+ masked_difference_image = connectionTypes.Output(
+ doc="Difference image with reference source mask",
+ name="difference_image_masked",
+ storageClass="ExposureF",
+ dimensions=("instrument", "visit", "detector"),
+ )
+
+
+class MaskReferenceSourcesConfig(
+ pipeBase.PipelineTaskConfig, pipelineConnections=MaskReferenceSourcesConnections
+):
+ """Configuration for MaskReferenceSources task."""
+
+ matching_radius = pexConfig.Field(
+ dtype=float,
+ default=1.0,
+ doc="Maximum separation for source matching in arcseconds",
+ )
+ reference_mag_column = pexConfig.Field(
+ dtype=str,
+ default="r_flux",
+ doc="Column name for reference catalog flux (in nJy)",
+ )
+ reference_ra_column = pexConfig.Field(
+ dtype=str,
+ default="coord_ra",
+ doc="Column name for reference catalog RA in degrees",
+ )
+ reference_dec_column = pexConfig.Field(
+ dtype=str,
+ default="coord_dec",
+ doc="Column name for reference catalog Dec in degrees",
+ )
+ mask_plane_name = pexConfig.Field(
+ dtype=str,
+ default="REFERENCE",
+ doc="Name of mask plane to create for reference source regions",
+ )
+ reference_buffer = pexConfig.Field(
+ dtype=int,
+ default=100,
+ doc="Buffer in pixels to add around image bounds when loading reference catalog",
+ )
+ astrometry_ref_loader = pexConfig.ConfigField(
+ dtype=lsst.meas.algorithms.LoadReferenceObjectsConfig,
+ doc="Configuration of reference object loader for source matching",
+ )
+
+ def setDefaults(self):
+ super().setDefaults()
+ # Set to Gaii, but we don't use photometry currently
+ self.astrometry_ref_loader.filterMap = {
+ "u": "phot_g_mean",
+ "g": "phot_g_mean",
+ "r": "phot_g_mean",
+ "i": "phot_g_mean",
+ "z": "phot_g_mean",
+ "y": "phot_g_mean",
+ }
+
+
+class MaskReferenceSourcesTask(pipeBase.PipelineTask):
+ """Mask regions around reference sources in difference images."""
+
+ ConfigClass = MaskReferenceSourcesConfig
+ _DefaultName = "maskReferenceSources"
+
+ def runQuantum(self, butlerQC, inputRefs, outputRefs):
+ """Run the task on quantum data."""
+ inputs = butlerQC.get(inputRefs)
+
+ # Create reference object loader following the standard pattern
+ astrometry_loader = lsst.meas.algorithms.ReferenceObjectLoader(
+ dataIds=[ref.datasetRef.dataId for ref in inputRefs.astrometry_ref_cat],
+ refCats=inputs.pop("astrometry_ref_cat"),
+ name=self.config.connections.astrometry_ref_cat,
+ config=self.config.astrometry_ref_loader,
+ log=self.log,
+ )
+
+ # Load reference catalog using loadPixelBox with buffer for edge sources
+ original_bbox = inputs["matching_image"].getBBox()
+ buffered_bbox = original_bbox.dilatedBy(self.config.reference_buffer)
+
+ self.log.info(
+ f"Loading reference catalog for bbox: {original_bbox} (buffered to {buffered_bbox})"
+ )
+
+ ref_result = astrometry_loader.loadPixelBox(
+ bbox=buffered_bbox,
+ wcs=inputs["matching_image"].getWcs(),
+ filterName=inputs["matching_image"].getFilter().bandLabel,
+ )
+
+ self.log.info(f"Loaded {len(ref_result.refCat)} reference sources")
+
+ outputs = self.run(
+ ref_catalog=ref_result.refCat,
+ difference_image=inputs["difference_image"],
+ matching_sources=inputs["matching_sources"],
+ matching_image=inputs["matching_image"],
+ )
+
+ butlerQC.put(outputs, outputRefs)
+
+ def run(self, ref_catalog, difference_image, matching_sources, matching_image):
+ """Mask reference sources in difference image.
+
+ Parameters
+ ----------
+ ref_catalog : `lsst.afw.table.SourceCatalog`
+ Reference catalog with sources to mask
+ difference_image : `lsst.afw.image.ExposureF`
+ Difference image to apply masks to
+ matching_sources : `lsst.afw.table.SourceCatalog`
+ Sources to match against reference catalog
+ matching_image : `lsst.afw.image.ExposureF`
+ Image where matching sources were detected
+
+ Returns
+ -------
+ result : `lsst.pipe.base.Struct`
+ Results struct with:
+ - ``masked_difference_image`` : Modified difference image with reference mask
+ """
+ self.log.info(f"Masking reference sources in difference image")
+ self.log.info(f"Reference catalog has {len(ref_catalog)} sources")
+ self.log.info(f"Matching against {len(matching_sources)} detected sources")
+
+ # Create copy of difference image to modify
+ masked_diff = difference_image.clone()
+
+ # Add mask plane if it doesn't exist
+ mask = masked_diff.mask
+ try:
+ mask_bit = mask.addMaskPlane(self.config.mask_plane_name)
+ self.log.info(
+ f"Added mask plane '{self.config.mask_plane_name}' with bit {mask_bit}"
+ )
+ except Exception:
+ # Mask plane already exists
+ mask_bit = mask.getPlaneBitMask(self.config.mask_plane_name)
+ self.log.info(
+ f"Using existing mask plane '{self.config.mask_plane_name}' with bit {mask_bit}"
+ )
+
+ # Match reference sources to the matching sources
+ matches = self._matchSources(
+ ref_catalog, matching_sources, matching_image.getWcs()
+ )
+ self.log.info(f"Found {len(matches)} matched sources")
+
+ if len(matches) == 0:
+ self.log.warn(f"No reference sources matched to detected sources")
+ return pipeBase.Struct(masked_difference_image=masked_diff)
+
+ # Apply masks for matched sources
+ n_masked = self._applyReferenceMasks(matches, masked_diff)
+ self.log.info(f"Masked {n_masked} reference source regions")
+
+ return pipeBase.Struct(masked_difference_image=masked_diff)
+
+ def _matchSources(self, ref_catalog, matching_sources, wcs):
+ """Match reference catalog sources to detected sources using Astropy SkyCoord.
+
+ Parameters
+ ----------
+ ref_catalog : `lsst.afw.table.SourceCatalog`
+ Reference catalog sources
+ matching_sources : `lsst.afw.table.SourceCatalog`
+ Detected sources to match against
+ wcs : `lsst.afw.geom.SkyWcs`
+ WCS for coordinate conversion
+
+ Returns
+ -------
+ matches : `list` of `tuple`
+ List of (ref_source, matching_source) pairs that match
+ """
+ if len(ref_catalog) == 0 or len(matching_sources) == 0:
+ return []
+
+ # Extract reference source coordinates
+ ref_ras = []
+ ref_decs = []
+ ref_sources = []
+
+ for ref_src in ref_catalog:
+ if (
+ self.config.reference_ra_column in ref_src.schema
+ and self.config.reference_dec_column in ref_src.schema
+ ):
+ # Convert to degrees - these might be Angle objects
+ ra = ref_src[self.config.reference_ra_column].asDegrees()
+ dec = ref_src[self.config.reference_dec_column].asDegrees()
+ else:
+ # Fallback to coord if specific columns not available
+ coord = ref_src.getCoord()
+ ra = coord.getRa().asDegrees()
+ dec = coord.getDec().asDegrees()
+
+ ref_ras.append(ra)
+ ref_decs.append(dec)
+ ref_sources.append(ref_src)
+
+ # Extract detected source coordinates
+ match_ras = []
+ match_decs = []
+ match_sources = []
+
+ for src in matching_sources:
+ coord = wcs.pixelToSky(Point2D(src.getX(), src.getY()))
+ match_ras.append(coord.getRa().asDegrees())
+ match_decs.append(coord.getDec().asDegrees())
+ match_sources.append(src)
+
+ # Create SkyCoord objects
+ ref_coords = SkyCoord(ref_ras * u.deg, ref_decs * u.deg)
+ match_coords = SkyCoord(match_ras * u.deg, match_decs * u.deg)
+
+ # Perform matching with maximum separation
+ max_sep = self.config.matching_radius * u.arcsec
+ idx, d2d, d3d = match_coords.match_to_catalog_sky(ref_coords)
+
+ # Build matches list for sources within matching radius
+ matches = []
+ n_too_far = 0
+ for i, (match_idx, sep) in enumerate(zip(idx, d2d)):
+ if sep <= max_sep:
+ matches.append((ref_sources[match_idx], match_sources[i]))
+ else:
+ n_too_far += 1
+
+ self.log.debug(
+ f"Matching: {len(matches)} matched, {n_too_far} beyond {self.config.matching_radius} arcsec"
+ )
+ return matches
+
+ def _applyReferenceMasks(self, matches, masked_exposure):
+ """Apply reference masks to matched source regions.
+
+ Parameters
+ ----------
+ matches : `list` of `tuple`
+ List of (ref_source, matching_source) matched pairs
+ masked_exposure : `lsst.afw.image.ExposureF`
+ Difference image exposure to apply masks to
+
+ Returns
+ -------
+ n_masked : `int`
+ Number of regions masked
+ """
+ mask = masked_exposure.mask
+ ref_mask_bit = mask.getPlaneBitMask(self.config.mask_plane_name)
+
+ n_masked = 0
+
+ for ref_src, matching_src in matches:
+ # Use the source footprint from the matching source to define the mask region
+ footprint = matching_src.getFootprint()
+ if footprint is not None:
+ self._maskSourceFootprint(footprint, mask, ref_mask_bit)
+ n_masked += 1
+
+ return n_masked
+
+ def _maskSourceFootprint(self, footprint, mask, mask_bit):
+ """Mask pixels in the source footprint.
+
+ Parameters
+ ----------
+ footprint : `lsst.afw.detection.Footprint`
+ Footprint of the source
+ mask : `lsst.afw.image.Mask`
+ Difference image mask to modify
+ mask_bit : `int`
+ Mask plane bit to set
+ """
+ # Get the footprint pixels - since matching and difference images have same WCS,
+ # coordinates can be used directly
+ bbox = mask.getBBox()
+ n_clipped = 0
+ n_masked = 0
+
+ for span in footprint.getSpans():
+ y = span.getY()
+ for x in range(span.getMinX(), span.getMaxX() + 1):
+ if bbox.contains(x, y):
+ mask.array[y, x] |= mask_bit
+ n_masked += 1
+ else:
+ n_clipped += 1
+
+ if n_clipped > 0:
+ self.log.debug(
+ f"Footprint clipped: {n_clipped} pixels outside bounds, {n_masked} pixels masked"
+ )
+
+ def _maskConnectedRegion(
+ self, mask, seed_x, seed_y, detected_bit, detected_neg_bit, ref_mask_bit
+ ):
+ """Mask a connected region of DETECTED or DETECTED_NEGATIVE pixels.
+
+ Parameters
+ ----------
+ mask : `lsst.afw.image.Mask`
+ Mask to modify
+ seed_x, seed_y : `int`
+ Seed pixel coordinates
+ detected_bit, detected_neg_bit : `int`
+ Mask plane bits for DETECTED and DETECTED_NEGATIVE
+ ref_mask_bit : `int`
+ Mask plane bit for REFERENCE_MASK
+ """
+ height, width = mask.array.shape
+ visited = np.zeros((height, width), dtype=bool)
+
+ # Stack for flood fill algorithm
+ stack = [(seed_x, seed_y)]
+ target_bits = detected_bit | detected_neg_bit
+
+ while stack:
+ x, y = stack.pop()
+
+ # Check bounds and if already visited
+ if x < 0 or x >= width or y < 0 or y >= height or visited[y, x]:
+ continue
+
+ visited[y, x] = True
+ mask_value = mask.array[y, x]
+
+ # If this pixel has DETECTED or DETECTED_NEGATIVE, mark it and add neighbors
+ if mask_value & target_bits:
+ mask.array[y, x] |= ref_mask_bit
+
+ # Add 8-connected neighbors to stack
+ for dx in [-1, 0, 1]:
+ for dy in [-1, 0, 1]:
+ if dx != 0 or dy != 0: # Skip center pixel
+ stack.append((x + dx, y + dy))
diff --git a/tests/test_desTemplate.py b/tests/test_desTemplate.py
new file mode 100644
index 000000000..8d83cba80
--- /dev/null
+++ b/tests/test_desTemplate.py
@@ -0,0 +1,314 @@
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+import collections
+import itertools
+import os
+import tempfile
+import unittest
+
+import numpy as np
+from astropy.table import Table
+
+import lsst.afw.geom
+import lsst.afw.image
+import lsst.afw.math
+import lsst.geom
+import lsst.ip.diffim
+import lsst.meas.algorithms
+import lsst.meas.base.tests
+import lsst.pipe.base as pipeBase
+import lsst.utils.tests
+
+from utils import generate_data_id
+
+# Change this to True, `setup display_ds9`, and open ds9 (or use another afw
+# display backend) to show the tract/patch layouts on the image.
+debug = False
+if debug:
+ import lsst.afw.display
+
+ display = lsst.afw.display.Display()
+ display.frame = 1
+
+
+class GetTemplateTaskTestCase(lsst.utils.tests.TestCase):
+ """Test that GetTemplateTask works on both one tract and multiple tract
+ input coadd exposures.
+
+ Makes a synthetic exposure large enough to fit four small tracts with 2x2
+ (300x300 pixel) patches each, extracts pixels for those patches by warping,
+ and tests GetTemplateTask's output against boxes that overlap various
+ combinations of one or multiple tracts.
+ """
+
+ def setUp(self):
+ self.scale = 0.2 # arcsec/pixel
+ # DES pixel scale is approximately 0.263 arcsec/pixel
+ self.template_scale = 0.263
+ self.exposure = self._makeExposure()
+
+ # Track temporary files for cleanup
+ self.temp_files = []
+ self.temp_csv_file = None
+
+ if debug:
+ display.image(self.exposure, "base exposure")
+
+ def _makeExposure(self):
+ """Create a large image to break up into tracts and patches.
+
+ The image will have a source every 100 pixels in x and y, and a WCS
+ that results in the tracts all fitting in the image, with tract=0
+ in the lower left, tract=1 to the right, tract=2 above, and tract=3
+ to the upper right.
+ """
+ box = lsst.geom.Box2I(
+ lsst.geom.Point2I(-200, -200), lsst.geom.Point2I(800, 800)
+ )
+ # This WCS was constructed so that tract 0 mostly fills the lower left
+ # quadrant of the image, and the other tracts fill the rest; slight
+ # extra rotation as a check on the final warp layout, scaled by 5%
+ # from the patch pixel scale.
+ cd_matrix = lsst.afw.geom.makeCdMatrix(
+ 1.05 * self.scale * lsst.geom.arcseconds, 93 * lsst.geom.degrees
+ )
+ wcs = lsst.afw.geom.makeSkyWcs(
+ lsst.geom.Point2D(120, 150),
+ lsst.geom.SpherePoint(0, 0, lsst.geom.radians),
+ cd_matrix,
+ )
+ dataset = lsst.meas.base.tests.TestDataset(box, wcs=wcs)
+ for x, y in itertools.product(np.arange(0, 500, 100), np.arange(0, 500, 100)):
+ dataset.addSource(1e5, lsst.geom.Point2D(x, y))
+ exposure, _ = dataset.realize(2, dataset.makeMinimalSchema())
+ exposure.setFilter(lsst.afw.image.FilterLabel("r", "r_03"))
+ return exposure
+
+ def _checkMetadata(self, template, config, box, wcs, nPsfs):
+ """Check that the various metadata components were set correctly."""
+ expectedBox = lsst.geom.Box2I(box)
+ self.assertEqual(template.getBBox(), expectedBox)
+ # WCS should match our exposure, not any of the coadd tracts.
+ self.assertEqual(template.wcs, self.exposure.wcs)
+ self.assertEqual(template.getXY0(), expectedBox.getMin())
+ self.assertEqual(template.filter.bandLabel, "r")
+ self.assertEqual(template.filter.physicalLabel, "r_03")
+ self.assertEqual(template.psf.getComponentCount(), nPsfs)
+
+ def _checkPixels(self, template, config, box):
+ """Check that the pixel values in the template are close to the
+ original image.
+ """
+ # All pixels should have real values!
+ expectedBox = lsst.geom.Box2I(box)
+
+ # Check that we fully filled the template
+ self.assertTrue(np.all(np.isfinite(template.image.array)))
+
+ def _makeSyntheticTileExposure(
+ self, ra_center=0.0, dec_center=0.0, size_pixels=4096, band="r"
+ ):
+ """Create a synthetic tile exposure similar to a DES coadd tile.
+
+ Parameters
+ ----------
+ ra_center : float
+ RA center of the tile in degrees
+ dec_center : float
+ DEC center of the tile in degrees
+ size_pixels : int
+ Size of the tile in pixels (DES tiles are typically 10000x10000,
+ but we use smaller for testing)
+ band : str
+ Photometric band
+
+ Returns
+ -------
+ exposure : lsst.afw.image.ExposureF
+ Synthetic tile exposure with WCS, PSF, and photometric calibration
+ """
+ # Create bounding box for the tile
+ bbox = lsst.geom.Box2I(
+ lsst.geom.Point2I(0, 0), lsst.geom.Extent2I(size_pixels, size_pixels)
+ )
+
+ # Create WCS centered on the tile center
+ cd_matrix = lsst.afw.geom.makeCdMatrix(
+ self.template_scale * lsst.geom.arcseconds,
+ 0 * lsst.geom.degrees, # No rotation
+ )
+ wcs = lsst.afw.geom.makeSkyWcs(
+ lsst.geom.Point2D(size_pixels / 2, size_pixels / 2),
+ lsst.geom.SpherePoint(ra_center, dec_center, lsst.geom.degrees),
+ cd_matrix,
+ )
+
+ # Create dataset with synthetic sources
+ dataset = lsst.meas.base.tests.TestDataset(bbox, wcs=wcs)
+
+ # Add a grid of sources across the tile (every 500 pixels)
+ for x, y in itertools.product(
+ np.arange(500, size_pixels - 500, 500),
+ np.arange(500, size_pixels - 500, 500),
+ ):
+ dataset.addSource(1e5, lsst.geom.Point2D(x, y))
+
+ # Realize the exposure with noise
+ exposure, _ = dataset.realize(5.0, dataset.makeMinimalSchema())
+
+ # Set filter
+ filter_label = lsst.afw.image.FilterLabel(band=band, physical=f"{band}_03")
+ exposure.setFilter(filter_label)
+
+ # Add MAGZERO to metadata (typical for DES)
+ metadata = exposure.getInfo().getMetadata()
+ metadata.set("MAGZERO", 30.0) # Typical DES zero point
+
+ return exposure
+
+ def _createSyntheticTileCatalog(self, bbox, wcs, band="r", tile_radius=0.02):
+ """Create a synthetic tile catalog CSV file for testing.
+
+ The tile catalog will be centered on the box, with tile_radius on each side
+
+ Parameters
+ ----------
+ bbox : lsst.geom.Box2I
+ Bounding box of the region to overlap
+ wcs : lsst.afw.geom.SkyWcs
+ WCS of the region
+ band : str
+ Photometric band
+ tile_radius : float
+ Size of tile on one side
+
+ Returns
+ -------
+ csv_path : str
+ Path to the temporary CSV file
+ """
+ # Get the center of the bbox in sky coordinates
+ center = wcs.pixelToSky(lsst.geom.Point2D(bbox.getCenter()))
+ ra_center = center.getRa().asDegrees()
+ dec_center = center.getDec().asDegrees()
+
+ # Calculate tile corners (simplified rectangular approximation)
+ rac1 = ra_center - tile_radius
+ rac2 = ra_center + tile_radius
+ rac3 = ra_center + tile_radius
+ rac4 = ra_center - tile_radius
+ decc1 = dec_center - tile_radius
+ decc2 = dec_center - tile_radius
+ decc3 = dec_center + tile_radius
+ decc4 = dec_center + tile_radius
+
+ size_pixels = int(2 * tile_radius * 3600 / self.template_scale)
+ # Create the synthetic tile exposure and save to FITS
+ tile_exposure = self._makeSyntheticTileExposure(
+ ra_center, dec_center, size_pixels=size_pixels, band=band
+ )
+
+ # Write to temporary FITS file
+ temp_fits = tempfile.NamedTemporaryFile(
+ suffix=".fits", delete=False, prefix="test_tile_"
+ )
+ tile_exposure.writeFits(temp_fits.name)
+ temp_fits.close()
+ self.temp_files.append(temp_fits.name)
+ # Create tile catalog data
+ tile_data = {
+ "tilename": [f"TES{int(ra_center):04d}{int(dec_center):+05d}"],
+ "band": [band],
+ "survey": ["DES"],
+ "ra_cent": [ra_center],
+ "dec_cent": [dec_center],
+ "rac1": [rac1],
+ "rac2": [rac2],
+ "rac3": [rac3],
+ "rac4": [rac4],
+ "decc1": [decc1],
+ "decc2": [decc2],
+ "decc3": [decc3],
+ "decc4": [decc4],
+ "filepath": [temp_fits.name],
+ }
+
+ # Create astropy table
+ table = Table(tile_data)
+
+ # Write to temporary CSV file
+ temp_csv = tempfile.NamedTemporaryFile(
+ mode="w", suffix=".csv", delete=False, prefix="test_tiles_"
+ )
+ table.write(temp_csv.name, format="ascii.csv", overwrite=True)
+ temp_csv.close()
+ self.temp_csv_file = temp_csv.name
+
+ return temp_csv.name
+
+ def tearDown(self):
+ """Clean up temporary files created during testing."""
+ # Remove temporary FITS files
+ for temp_file in self.temp_files:
+ if os.path.exists(temp_file):
+ try:
+ os.remove(temp_file)
+ except Exception:
+ pass # Ignore cleanup errors
+
+ # Remove temporary CSV file
+ if self.temp_csv_file and os.path.exists(self.temp_csv_file):
+ try:
+ os.remove(self.temp_csv_file)
+ except Exception:
+ pass # Ignore cleanup errors
+
+ def testRunOneTractInput(self):
+ """Test a bounding box that fits inside single DES tract"""
+
+ box = lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.Point2I(180, 180))
+
+ # Create synthetic tile catalog with tiles that overlap the test bbox
+ synthetic_csv = self._createSyntheticTileCatalog(
+ box, self.exposure.wcs, band="r"
+ )
+
+ config = lsst.ip.diffim.DesTemplateConfig()
+ config.tileFile = synthetic_csv
+ task = lsst.ip.diffim.DesTemplateTask(config=config)
+ result = task.run(
+ bbox=box, wcs=self.exposure.wcs, band="r", physical_filter="r_03"
+ )
+
+ self._checkMetadata(result.template, task.config, box, self.exposure.wcs, 1)
+ self._checkPixels(result.template, task.config, box)
+
+
+def setup_module(module):
+ lsst.utils.tests.init()
+
+
+class MemoryTestCase(lsst.utils.tests.MemoryTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ lsst.utils.tests.init()
+ unittest.main()
diff --git a/tests/test_maskReferenceSources.py b/tests/test_maskReferenceSources.py
new file mode 100644
index 000000000..9e398da2a
--- /dev/null
+++ b/tests/test_maskReferenceSources.py
@@ -0,0 +1,223 @@
+import unittest
+import numpy as np
+
+import lsst.utils.tests as utilsTests
+import lsst.afw.image as afwImage
+import lsst.afw.table as afwTable
+import lsst.afw.geom as afwGeom
+import lsst.afw.detection as afwDetect
+from lsst.afw.geom import makeSkyWcs
+from lsst.geom import Point2D, Angle, degrees, SpherePoint
+
+from lsst.ip.diffim.maskReferenceSources import (
+ MaskReferenceSourcesTask,
+ MaskReferenceSourcesConfig,
+)
+
+# ----------------- helpers -----------------
+
+
+def make_exposure(
+ w=300, h=300, scale_arcsec=0.2, crpix=(150.0, 150.0), crval_deg=(10.0, -5.0)
+):
+ exp = afwImage.ExposureF(w, h)
+ scale_deg = scale_arcsec / 3600.0
+ cd = np.array([[scale_deg, 0.0], [0.0, scale_deg]], dtype=float)
+ wcs = makeSkyWcs(
+ crpix=Point2D(*crpix),
+ crval=SpherePoint(Angle(crval_deg[0], degrees), Angle(crval_deg[1], degrees)),
+ cdMatrix=cd,
+ )
+ exp.setWcs(wcs)
+ return exp
+
+
+def circular_footprint(xc, yc, r=4):
+ spans = []
+ for y in range(yc - r, yc + r + 1):
+ dy = y - yc
+ dx = int(np.floor(np.sqrt(max(0, r * r - dy * dy))))
+ spans.append(afwGeom.Span(y, xc - dx, xc + dx))
+ ss = afwGeom.SpanSet(spans)
+ return afwDetect.Footprint(ss)
+
+
+def make_src_cat(wcs, xy_list, radius=4):
+
+ schema = afwTable.SourceTable.makeMinimalSchema()
+
+ xKey = schema.addField(afwTable.FieldD("centroid_x", "centroid x", "pixel"))
+ yKey = schema.addField(afwTable.FieldD("centroid_y", "centroid y", "pixel"))
+
+ schema.getAliasMap().set("slot_Centroid", "centroid")
+ table = afwTable.SourceTable.make(schema)
+
+ cat = afwTable.SourceCatalog(table)
+ for x, y in xy_list:
+ rec = cat.addNew()
+ rec.set(xKey, float(x))
+ rec.set(yKey, float(y))
+ rec.setCoord(wcs.pixelToSky(Point2D(float(x), float(y))))
+ rec.setFootprint(circular_footprint(int(x), int(y), radius))
+ return cat
+
+
+def make_ref_cat_from_sources(wcs, src_xy, arcsec_offset, offset_axis="ra"):
+ """Make a reference catalog by offsetting each source's sky coord."""
+ schema = afwTable.SourceTable.makeMinimalSchema()
+ table = afwTable.SourceTable.make(schema)
+ cat = afwTable.SourceCatalog(table)
+ deg_off = arcsec_offset / 3600.0
+ for x, y in src_xy:
+ sp = wcs.pixelToSky(Point2D(float(x), float(y)))
+ ra = sp.getRa().asDegrees()
+ dec = sp.getDec().asDegrees()
+ if offset_axis == "ra":
+ ra += deg_off
+ else:
+ dec += deg_off
+ rec = cat.addNew()
+ rec.setCoord(SpherePoint(Angle(ra, degrees), Angle(dec, degrees)))
+ return cat
+
+
+def mask_bit(exp, plane):
+ try:
+ return exp.mask.addMaskPlane(plane)
+ except Exception:
+ return exp.mask.getPlaneBitMask(plane)
+
+
+def ensure_plane_bitmask(mask, name):
+ try:
+ bit_index = mask.addMaskPlane(name)
+ return 1 << bit_index
+ except Exception:
+ return mask.getPlaneBitMask(name)
+
+
+def paint_footprint(mask, footprint, bit):
+ """OR the given bit into mask pixels inside the footprint (clip to bounds)."""
+ bbox = mask.getBBox()
+ for span in footprint.getSpans():
+ y = span.getY()
+ for x in range(span.getMinX(), span.getMaxX() + 1):
+ if bbox.contains(x, y):
+ mask.array[y, x] |= bit
+
+
+# ----------------- tests -----------------
+
+
+class TestMaskReferenceSourcesMany(utilsTests.TestCase):
+ """Build 300x300 exposure with 8 DETECTED sources.
+ Make 5 reference objects close enough to match
+ Verify REFERENCE is set exactly over DETECTED for matches only.
+ """
+
+ def setUp(self):
+ self.exp = make_exposure(300, 300, scale_arcsec=0.2)
+ self.cfg = MaskReferenceSourcesConfig()
+ self.cfg.mask_plane_name = "REFERENCE"
+ self.cfg.matching_radius = 1.2 # arcsec
+ self.task = MaskReferenceSourcesTask(config=self.cfg)
+
+ # 8 sources, well-separated (no overlapping footprints)
+ self.src_xy = [
+ (40, 40),
+ (90, 60),
+ (140, 80),
+ (190, 100),
+ (240, 120),
+ (60, 200),
+ (160, 220),
+ (260, 240),
+ ]
+ self.src_cat = make_src_cat(self.exp.getWcs(), self.src_xy, radius=4)
+
+ self.detected_bit = ensure_plane_bitmask(self.exp.mask, "DETECTED")
+ for rec in self.src_cat:
+ paint_footprint(self.exp.mask, rec.getFootprint(), self.detected_bit)
+
+ # Build reference catalog:
+ # - First 5: +0.8" in RA (=> should match within 1.2")
+ # - Last 3: not in the reference catalog
+ refs = make_ref_cat_from_sources(
+ self.exp.getWcs(), self.src_xy[:5], arcsec_offset=0.8, offset_axis="ra"
+ )
+
+ # Concatenate (order shouldn't matter)
+ self.ref_cat = afwTable.SourceCatalog(refs.getTable())
+ for rec in refs:
+ self.ref_cat.append(rec)
+
+ self.reference_bit = ensure_plane_bitmask(
+ self.exp.mask, self.cfg.mask_plane_name
+ )
+
+ def test_masking(self):
+ self.assertEqual(int((self.exp.mask.array & self.reference_bit).sum()), 0)
+
+ # Run the algorithm
+ result = self.task.run(
+ ref_catalog=self.ref_cat,
+ difference_image=self.exp,
+ matching_sources=self.src_cat,
+ matching_image=self.exp,
+ )
+ out = result.masked_difference_image
+ mask_arr = out.mask.array
+
+ # 1) REFERENCE must be a subset of DETECTED everywhere
+ ref_only = (mask_arr & self.reference_bit) & ~(
+ (mask_arr & self.detected_bit) > 0
+ )
+ self.assertEqual(
+ int(ref_only.sum()), 0, "REFERENCE set outside DETECTED footprint(s)"
+ )
+
+ # 2) Exactly 5 sources matched; REFERENCE pixels == sum of their footprint areas
+ matches = self.task._matchSources(self.ref_cat, self.src_cat, self.exp.getWcs())
+ self.assertEqual(len(matches), 5, "Expected 5 matches within 1.2 arcsec")
+ matched_src_ids = {id(src) for _, src in matches}
+
+ expected_pixels = 0
+ for src in self.src_cat:
+ if id(src) in matched_src_ids:
+ expected_pixels += src.getFootprint().getArea()
+
+ actual_pixels = ((mask_arr & self.reference_bit) == self.reference_bit).sum()
+ self.assertEqual(
+ actual_pixels,
+ expected_pixels,
+ "REFERENCE mask pixels should equal sum of matched footprint areas",
+ )
+
+ # 3) Unmatched sources: DETECTED-only, no REFERENCE bits inside their footprints
+ for src in self.src_cat:
+ if id(src) not in matched_src_ids:
+ # Collect REFERENCE pixels inside this footprint
+ in_fp = 0
+ bbox = self.exp.mask.getBBox()
+ for span in src.getFootprint().getSpans():
+ y = span.getY()
+ for x in range(span.getMinX(), span.getMaxX() + 1):
+ if bbox.contains(x, y) and (
+ mask_arr[y, x] & self.reference_bit
+ ):
+ in_fp += 1
+ self.assertEqual(
+ in_fp, 0, "Unmatched source has REFERENCE bits in its footprint"
+ )
+
+
+# --------------- boilerplate ---------------
+
+
+def setup_module(module):
+ utilsTests.init()
+
+
+if __name__ == "__main__":
+ utilsTests.init()
+ unittest.main()