From 6b1b0960f77d8443bb25dbcc5f80359ecb663925 Mon Sep 17 00:00:00 2001 From: Lee Kelvin Date: Fri, 24 Oct 2025 08:33:42 -0700 Subject: [PATCH 01/12] Add brightStarCutout.py --- .../brightStarSubtraction/brightStarCutout.py | 661 ++++++++++++++++++ 1 file changed, 661 insertions(+) create mode 100644 python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py new file mode 100644 index 000000000..a9d58a762 --- /dev/null +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py @@ -0,0 +1,661 @@ +# This file is part of pipe_tasks. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +"""Extract bright star cutouts; normalize and warp, optionally fit the PSF.""" + +__all__ = ["BrightStarCutoutConnections", "BrightStarCutoutConfig", "BrightStarCutoutTask"] + +from typing import Any, Iterable, cast + +import astropy.units as u +import numpy as np +from astropy.coordinates import SkyCoord +from astropy.table import Table +from lsst.afw.cameraGeom import FIELD_ANGLE, PIXELS +from lsst.afw.detection import Footprint, FootprintSet, Threshold +from lsst.afw.geom import SkyWcs, SpanSet, makeModifiedWcs +from lsst.afw.geom.transformFactory import makeTransform +from lsst.afw.image import ExposureF, ImageD, ImageF, MaskedImageF +from lsst.afw.math import BackgroundList, FixedKernel, WarpingControl, warpImage +from lsst.daf.butler import DataCoordinate +from lsst.geom import ( + AffineTransform, + Box2I, + Extent2D, + Extent2I, + Point2D, + Point2I, + SpherePoint, + arcseconds, + floor, + radians, +) +from lsst.meas.algorithms import ( + BrightStarStamp, + BrightStarStamps, + KernelPsf, + LoadReferenceObjectsConfig, + ReferenceObjectLoader, + WarpedPsf, +) +from lsst.pex.config import ChoiceField, ConfigField, Field, ListField +from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct +from lsst.pipe.base.connectionTypes import Input, Output, PrerequisiteInput +from lsst.utils.timer import timeMethod + +NEIGHBOR_MASK_PLANE = "NEIGHBOR" + + +class BrightStarCutoutConnections( + PipelineTaskConnections, + dimensions=("instrument", "visit", "detector"), +): + """Connections for BrightStarCutoutTask.""" + + refCat = PrerequisiteInput( + name="gaia_dr3_20230707", + storageClass="SimpleCatalog", + doc="Reference catalog that contains bright star positions.", + dimensions=("skypix",), + multiple=True, + deferLoad=True, + ) + inputExposure = Input( + name="calexp", + storageClass="ExposureF", + doc="Background-subtracted input exposure from which to extract bright star stamp cutouts.", + dimensions=("visit", "detector"), + ) + inputBackground = Input( + name="calexpBackground", + storageClass="Background", + doc="Background model for the input exposure, to be added back on during processing.", + dimensions=("visit", "detector"), + ) + extendedPsf = Input( + name="extendedPsf2", + storageClass="ImageF", + doc="Extended PSF model, built from stacking bright star cutouts.", + dimensions=("band",), + ) + brightStarStamps = Output( + name="brightStarStamps", + storageClass="BrightStarStamps", + doc="Set of preprocessed postage stamp cutouts, each centered on a single bright star.", + dimensions=("visit", "detector"), + ) + + def __init__(self, *, config: "BrightStarCutoutConfig | None" = None): + super().__init__(config=config) + assert config is not None + if not config.useExtendedPsf: + self.inputs.remove("extendedPsf") + + +class BrightStarCutoutConfig( + PipelineTaskConfig, + pipelineConnections=BrightStarCutoutConnections, +): + """Configuration parameters for BrightStarCutoutTask.""" + + # Star selection + magRange = ListField[float]( + doc="Magnitude range in Gaia G. Cutouts will be made for all stars in this range.", + default=[0, 18], + ) + excludeArcsecRadius = Field[float]( + doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeArcsecRadius`` are not used.", + default=5, + ) + excludeMagRange = ListField[float]( + doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeArcsecRadius`` are not used.", + default=[0, 20], + ) + minAreaFraction = Field[float]( + doc="Minimum fraction of the stamp area, post-masking, that must remain for a cutout to be retained.", + default=0.1, + ) + badMaskPlanes = ListField[str]( + doc="Mask planes that identify excluded pixels for the calculation of ``minAreaFraction`` and, " + "optionally, fitting of the PSF.", + default=[ + "BAD", + "CR", + "CROSSTALK", + "EDGE", + "NO_DATA", + "SAT", + "SUSPECT", + "UNMASKEDNAN", + NEIGHBOR_MASK_PLANE, + ], + ) + + # Cutout geometry + stampSize = ListField[int]( + doc="Size of the stamps to be extracted, in pixels.", + default=(251, 251), + ) + stampSizePadding = Field[float]( + doc="Multiplicative factor applied to the cutout stamp size, to guard against post-warp data loss.", + default=1.1, + ) + warpingKernelName = ChoiceField[str]( + doc="Warping kernel.", + default="lanczos5", + allowed={ + "bilinear": "bilinear interpolation", + "lanczos3": "Lanczos kernel of order 3", + "lanczos4": "Lanczos kernel of order 4", + "lanczos5": "Lanczos kernel of order 5", + }, + ) + maskWarpingKernelName = ChoiceField[str]( + doc="Warping kernel for mask.", + default="bilinear", + allowed={ + "bilinear": "bilinear interpolation", + "lanczos3": "Lanczos kernel of order 3", + "lanczos4": "Lanczos kernel of order 4", + "lanczos5": "Lanczos kernel of order 5", + }, + ) + + # PSF Fitting + useExtendedPsf = Field[bool]( + doc="Use the extended PSF model to normalize bright star cutouts.", + default=False, + ) + doFitPsf = Field[bool]( + doc="Fit a scaled PSF and a pedestal to each bright star cutout.", + default=True, + ) + useMedianVariance = Field[bool]( + doc="Use the median of the variance plane for PSF fitting.", + default=False, + ) + psfMaskedFluxFracThreshold = Field[float]( + doc="Maximum allowed fraction of masked PSF flux for PSF fitting to occur.", + default=0.97, + ) + + # Misc + loadReferenceObjectsConfig = ConfigField[LoadReferenceObjectsConfig]( + doc="Reference object loader for astrometric calibration.", + ) + + +class BrightStarCutoutTask(PipelineTask): + """Extract bright star cutouts; normalize and warp to the same pixel grid. + + The BrightStarCutoutTask is used to extract, process, and store small image + cutouts (or "postage stamps") around bright stars. + This task essentially consists of three principal steps. + First, it identifies bright stars within an exposure using a reference + catalog and extracts a stamp around each. + Second, it shifts and warps each stamp to remove optical distortions and + sample all stars on the same pixel grid. + Finally, it optionally fits a PSF plus plane flux model to the cutout. + This final fitting procedure may be used to normalize each bright star + stamp prior to stacking when producing extended PSF models. + """ + + ConfigClass = BrightStarCutoutConfig + _DefaultName = "brightStarCutout" + config: BrightStarCutoutConfig + + def __init__(self, initInputs=None, *args, **kwargs): + super().__init__(*args, **kwargs) + stampSize = Extent2D(*self.config.stampSize.list()) + stampRadius = floor(stampSize / 2) + self.stampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy(stampRadius) + paddedStampSize = stampSize * self.config.stampSizePadding + self.paddedStampRadius = floor(paddedStampSize / 2) + self.paddedStampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy( + self.paddedStampRadius + ) + + def runQuantum(self, butlerQC, inputRefs, outputRefs): + inputs = butlerQC.get(inputRefs) + inputs["dataId"] = butlerQC.quantum.dataId + refObjLoader = ReferenceObjectLoader( + dataIds=[ref.datasetRef.dataId for ref in inputRefs.refCat], + refCats=inputs.pop("refCat"), + name=self.config.connections.refCat, + config=self.config.loadReferenceObjectsConfig, + ) + extendedPsf = inputs.pop("extendedPsf", None) + output = self.run(**inputs, extendedPsf=extendedPsf, refObjLoader=refObjLoader) + # Only ingest Stamp if it exists; prevents ingesting an empty FITS file + if output: + butlerQC.put(output, outputRefs) + + @timeMethod + def run( + self, + inputExposure: ExposureF, + inputBackground: BackgroundList, + extendedPsf: ImageF | None, + refObjLoader: ReferenceObjectLoader, + dataId: dict[str, Any] | DataCoordinate, + ): + """Identify bright stars within an exposure using a reference catalog, + extract stamps around each, warp/shift stamps onto a common frame and + then optionally fit a PSF plus plane model. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The background-subtracted image to extract bright star stamps. + inputBackground : `~lsst.afw.math.BackgroundList` + The background model associated with the input exposure. + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional + Loader to find objects within a reference catalog. + dataId : `dict` or `~lsst.daf.butler.DataCoordinate` + The dataId of the exposure that bright stars are extracted from. + Both 'visit' and 'detector' will be persisted in the output data. + + Returns + ------- + brightStarResults : `~lsst.pipe.base.Struct` + Results as a struct with attributes: + + ``brightStarStamps`` + (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`) + """ + wcs = inputExposure.getWcs() + bbox = inputExposure.getBBox() + warpingControl = WarpingControl(self.config.warpingKernelName, self.config.maskWarpingKernelName) + + refCatBright = self._getRefCatBright(refObjLoader, wcs, bbox) + zipRaDec = zip(refCatBright["coord_ra"] * radians, refCatBright["coord_dec"] * radians) + spherePoints = [SpherePoint(ra, dec) for ra, dec in zipRaDec] + pixCoords = wcs.skyToPixel(spherePoints) + + # Restore original subtracted background + inputMI = inputExposure.getMaskedImage() + inputMI += inputBackground.getImage() + + # Set up NEIGHBOR mask plane; associate footprints with stars + inputExposure.mask.addMaskPlane(NEIGHBOR_MASK_PLANE) + allFootprints, associations = self._associateFootprints(inputExposure, pixCoords, plane="DETECTED") + + # TODO: If we eventually have better PhotoCalibs (eg FGCM), apply here + inputMI = inputExposure.getPhotoCalib().calibrateImage(inputMI, False) + + # Set up transform + detector = inputExposure.detector + pixelScale = wcs.getPixelScale().asArcseconds() * arcseconds + pixToFocalPlaneTan = detector.getTransform(PIXELS, FIELD_ANGLE).then( + makeTransform(AffineTransform.makeScaling(1 / pixelScale.asRadians())) + ) + + # Loop over each bright star + stamps, goodFracs, stamps_fitPsfResults = [], [], [] + for starIndex, (obj, pixCoord) in enumerate(zip(refCatBright, pixCoords)): # type: ignore + footprintIndex = associations.get(starIndex, None) + stampMI = MaskedImageF(self.paddedStampBBox) + + # Set NEIGHBOR footprints in the mask plane + if footprintIndex: + neighborFootprints = [fp for i, fp in enumerate(allFootprints) if i != footprintIndex] + self._setFootprints(inputExposure, neighborFootprints, NEIGHBOR_MASK_PLANE) + else: + self._setFootprints(inputExposure, allFootprints, NEIGHBOR_MASK_PLANE) + + # Define linear shifting to recenter stamps + coordFocalPlaneTan = pixToFocalPlaneTan.applyForward(pixCoord) # center of warped star + shift = makeTransform(AffineTransform(Point2D(0, 0) - coordFocalPlaneTan)) + angle = np.arctan2(coordFocalPlaneTan.getY(), coordFocalPlaneTan.getX()) * radians + rotation = makeTransform(AffineTransform.makeRotation(-angle)) + pixToPolar = pixToFocalPlaneTan.then(shift).then(rotation) + + # Apply the warp to the star stamp (in-place) + warpImage(stampMI, inputExposure.maskedImage, pixToPolar, warpingControl) + + # Trim to the base stamp size, check mask coverage, update metadata + stampMI = stampMI[self.stampBBox] + badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes) + goodFrac = np.sum(stampMI.mask.array & badMaskBitMask == 0) / stampMI.mask.array.size + goodFracs.append(goodFrac) + if goodFrac < self.config.minAreaFraction: + continue + + # Fit a scaled PSF and a pedestal to each bright star cutout + psf = WarpedPsf(inputExposure.getPsf(), pixToPolar, warpingControl) + constantPsf = KernelPsf(FixedKernel(psf.computeKernelImage(Point2D(0, 0)))) + if self.config.useExtendedPsf: + psfImage = extendedPsf # Assumed to be warped, center at [0,0] + else: + psfImage = constantPsf.computeKernelImage(constantPsf.getAveragePosition()) + fitPsfResults = {} + if self.config.doFitPsf: + fitPsfResults = self._fitPsf(stampMI, psfImage) + stamps_fitPsfResults.append(fitPsfResults) + + # Save the stamp if the PSF fit was successful or no fit requested + if fitPsfResults or not self.config.doFitPsf: + stamp = BrightStarStamp( + maskedImage=stampMI, + # TODO: what to do about this PSF? + psf=constantPsf, + wcs=makeModifiedWcs(pixToPolar, wcs, False), + visit=cast(int, dataId["visit"]), + detector=cast(int, dataId["detector"]), + refId=obj["id"], + refMag=obj["mag"], + position=pixCoord, + scale=fitPsfResults.get("scale", None), + scaleErr=fitPsfResults.get("scaleErr", None), + pedestal=fitPsfResults.get("pedestal", None), + pedestalErr=fitPsfResults.get("pedestalErr", None), + pedestalScaleCov=fitPsfResults.get("pedestalScaleCov", None), + xGradient=fitPsfResults.get("xGradient", None), + yGradient=fitPsfResults.get("yGradient", None), + globalReducedChiSquared=fitPsfResults.get("globalReducedChiSquared", None), + globalDegreesOfFreedom=fitPsfResults.get("globalDegreesOfFreedom", None), + psfReducedChiSquared=fitPsfResults.get("psfReducedChiSquared", None), + psfDegreesOfFreedom=fitPsfResults.get("psfDegreesOfFreedom", None), + psfMaskedFluxFrac=fitPsfResults.get("psfMaskedFluxFrac", None), + ) + stamps.append(stamp) + + self.log.info( + "Extracted %i bright star stamp%s. " + "Excluded %i star%s: insufficient area (%i), PSF fit failure (%i).", + len(stamps), + "" if len(stamps) == 1 else "s", + len(refCatBright) - len(stamps), + "" if len(refCatBright) - len(stamps) == 1 else "s", + np.sum(np.array(goodFracs) < self.config.minAreaFraction), + ( + np.sum(np.isnan([x.get("pedestal", np.nan) for x in stamps_fitPsfResults])) + if self.config.doFitPsf + else 0 + ), + ) + brightStarStamps = BrightStarStamps(stamps) + return Struct(brightStarStamps=brightStarStamps) + + def _getRefCatBright(self, refObjLoader: ReferenceObjectLoader, wcs: SkyWcs, bbox: Box2I) -> Table: + """Get a bright star subset of the reference catalog. + + Trim the reference catalog to only those objects within the exposure + bounding box dilated by half the bright star stamp size. + This ensures all stars that overlap the exposure are included. + + Parameters + ---------- + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader` + Loader to find objects within a reference catalog. + wcs : `~lsst.afw.geom.SkyWcs` + World coordinate system. + bbox : `~lsst.geom.Box2I` + Bounding box of the exposure. + + Returns + ------- + refCatBright : `~astropy.table.Table` + Bright star subset of the reference catalog. + """ + dilatedBBox = bbox.dilatedBy(self.paddedStampRadius) + withinExposure = refObjLoader.loadPixelBox(dilatedBBox, wcs, filterName="phot_g_mean") + refCatFull = withinExposure.refCat + fluxField: str = withinExposure.fluxField + + proxFluxRange = sorted(((self.config.excludeMagRange * u.ABmag).to(u.nJy)).to_value()) + brightFluxRange = sorted(((self.config.magRange * u.ABmag).to(u.nJy)).to_value()) + + subsetStars = (refCatFull[fluxField] > np.min((proxFluxRange[0], brightFluxRange[0]))) & ( + refCatFull[fluxField] < np.max((proxFluxRange[1], brightFluxRange[1])) + ) + refCatSubset = Table(refCatFull.extract("id", "coord_ra", "coord_dec", fluxField, where=subsetStars)) + + proxStars = (refCatSubset[fluxField] >= proxFluxRange[0]) & ( + refCatSubset[fluxField] <= proxFluxRange[1] + ) + brightStars = (refCatSubset[fluxField] >= brightFluxRange[0]) & ( + refCatSubset[fluxField] <= brightFluxRange[1] + ) + + coords = SkyCoord(refCatSubset["coord_ra"], refCatSubset["coord_dec"], unit="rad") + excludeArcsecRadius = self.config.excludeArcsecRadius * u.arcsec # type: ignore + refCatBrightIsolated = [] + for coord in cast(Iterable[SkyCoord], coords[brightStars]): + neighbors = coords[proxStars] + seps = coord.separation(neighbors).to(u.arcsec) + tooClose = (seps > 0) & (seps <= excludeArcsecRadius) # not self matched + refCatBrightIsolated.append(not tooClose.any()) + + refCatBright = cast(Table, refCatSubset[brightStars][refCatBrightIsolated]) + + fluxNanojansky = refCatBright[fluxField][:] * u.nJy # type: ignore + refCatBright["mag"] = fluxNanojansky.to(u.ABmag).to_value() # AB magnitudes + + self.log.info( + "Identified %i of %i star%s which satisfy: frame overlap; in the range %s mag; no neighboring " + "stars within %s arcsec.", + len(refCatBright), + len(refCatFull), + "" if len(refCatFull) == 1 else "s", + self.config.magRange, + self.config.excludeArcsecRadius, + ) + + return refCatBright + + def _associateFootprints( + self, inputExposure: ExposureF, pixCoords: list[Point2D], plane: str + ) -> tuple[list[Footprint], dict[int, int]]: + """Associate footprints from a given mask plane with specific objects. + + Footprints from the given mask plane are associated with objects at the + coordinates provided, where possible. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The input exposure with a mask plane. + pixCoords : `list` [`~lsst.geom.Point2D`] + The pixel coordinates of the objects. + plane : `str` + The mask plane used to identify masked pixels. + + Returns + ------- + footprints : `list` [`~lsst.afw.detection.Footprint`] + The footprints from the input exposure. + associations : `dict`[int, int] + Association indices between objects (key) and footprints (value). + """ + detThreshold = Threshold(inputExposure.mask.getPlaneBitMask(plane), Threshold.BITMASK) + footprintSet = FootprintSet(inputExposure.mask, detThreshold) + footprints = footprintSet.getFootprints() + associations = {} + for starIndex, pixCoord in enumerate(pixCoords): + for footprintIndex, footprint in enumerate(footprints): + if footprint.contains(Point2I(pixCoord)): + associations[starIndex] = footprintIndex + break + self.log.debug( + "Associated %i of %i star%s to one each of the %i %s footprint%s.", + len(associations), + len(pixCoords), + "" if len(pixCoords) == 1 else "s", + len(footprints), + plane, + "" if len(footprints) == 1 else "s", + ) + return footprints, associations + + def _setFootprints(self, inputExposure: ExposureF, footprints: list, maskPlane: str): + """Set footprints in a given mask plane. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The input exposure to modify. + footprints : `list` [`~lsst.afw.detection.Footprint`] + The footprints to set in the mask plane. + maskPlane : `str` + The mask plane to set the footprints in. + + Notes + ----- + This method modifies the ``inputExposure`` object in-place. + """ + detThreshold = Threshold(inputExposure.mask.getPlaneBitMask(maskPlane), Threshold.BITMASK) + detThresholdValue = int(detThreshold.getValue()) + footprintSet = FootprintSet(inputExposure.mask, detThreshold) + + # Wipe any existing footprints in the mask plane + inputExposure.mask.clearMaskPlane(int(np.log2(detThresholdValue))) + + # Set the footprints in the mask plane + footprintSet.setFootprints(footprints) + footprintSet.setMask(inputExposure.mask, maskPlane) + + def _fitPsf(self, stampMI: MaskedImageF, psfImage: ImageD | ImageF) -> dict[str, Any]: + """Fit a scaled PSF and a pedestal to each bright star cutout. + + Parameters + ---------- + stampMI : `~lsst.afw.image.MaskedImageF` + The masked image of the bright star cutout. + psfImage : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` + The PSF model to fit. + + Returns + ------- + fitPsfResults : `dict`[`str`, `float`] + The result of the PSF fitting, with keys: + + ``scale`` : `float` + The scale factor. + ``scaleErr`` : `float` + The error on the scale factor. + ``pedestal`` : `float` + The pedestal value. + ``pedestalErr`` : `float` + The error on the pedestal value. + ``pedestalScaleCov`` : `float` + The covariance between the pedestal and scale factor. + ``xGradient`` : `float` + The gradient in the x-direction. + ``yGradient`` : `float` + The gradient in the y-direction. + ``globalReducedChiSquared`` : `float` + The global reduced chi-squared goodness-of-fit. + ``globalDegreesOfFreedom`` : `int` + The global number of degrees of freedom. + ``psfReducedChiSquared`` : `float` + The PSF BBox reduced chi-squared goodness-of-fit. + ``psfDegreesOfFreedom`` : `int` + The PSF BBox number of degrees of freedom. + ``psfMaskedFluxFrac`` : `float` + The fraction of the PSF image flux masked by bad pixels. + """ + badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes) + + # Calculate the fraction of the PSF image flux masked by bad pixels + psfMaskedPixels = ImageF(psfImage.getBBox()) + psfMaskedPixels.array[:, :] = (stampMI.mask[psfImage.getBBox()].array & badMaskBitMask).astype(bool) + # TODO: This is np.float64, else FITS metadata serialization fails + psfMaskedFluxFrac = np.dot(psfImage.array.flat, psfMaskedPixels.array.flat).astype(np.float64) + if psfMaskedFluxFrac > self.config.psfMaskedFluxFracThreshold: + return {} # Handle cases where the PSF image is mostly masked + + # Create a padded version of the input constant PSF image + paddedPsfImage = ImageF(stampMI.getBBox()) + paddedPsfImage[psfImage.getBBox()] = psfImage.convertF() + + # Create consistently masked data + badSpans = SpanSet.fromMask(stampMI.mask, badMaskBitMask) + goodSpans = SpanSet(stampMI.getBBox()).intersectNot(badSpans) + varianceData = goodSpans.flatten(stampMI.variance.array, stampMI.getXY0()) + if self.config.useMedianVariance: + varianceData = np.median(varianceData) + sigmaData = np.sqrt(varianceData) + imageData = goodSpans.flatten(stampMI.image.array, stampMI.getXY0()) # B + imageData /= sigmaData + psfData = goodSpans.flatten(paddedPsfImage.array, paddedPsfImage.getXY0()) + psfData /= sigmaData + + # Fit the PSF scale factor and global pedestal + nData = len(imageData) + coefficientMatrix = np.ones((nData, 4), dtype=float) # A + coefficientMatrix[:, 0] = psfData + coefficientMatrix[:, 1] /= sigmaData + coefficientMatrix[:, 2:] = goodSpans.indices().T + coefficientMatrix[:, 2] /= sigmaData + coefficientMatrix[:, 3] /= sigmaData + try: + solutions, sumSquaredResiduals, *_ = np.linalg.lstsq(coefficientMatrix, imageData, rcond=None) + covarianceMatrix = np.linalg.inv(np.dot(coefficientMatrix.transpose(), coefficientMatrix)) # C + except np.linalg.LinAlgError: + return {} # Handle singular matrix errors + if sumSquaredResiduals.size == 0: + return {} # Handle cases where sum of the squared residuals are empty + scale = solutions[0] + if scale <= 0: + return {} # Handle cases where the PSF scale fit has failed + scaleErr = np.sqrt(covarianceMatrix[0, 0]) + pedestal = solutions[1] + pedestalErr = np.sqrt(covarianceMatrix[1, 1]) + scalePedestalCov = covarianceMatrix[0, 1] + xGradient = solutions[3] + yGradient = solutions[2] + + # Calculate global (whole image) reduced chi-squared + globalChiSquared = np.sum(sumSquaredResiduals) + globalDegreesOfFreedom = nData - 4 + globalReducedChiSquared = globalChiSquared / globalDegreesOfFreedom + + # Calculate PSF BBox reduced chi-squared + psfBBoxGoodSpans = goodSpans.clippedTo(psfImage.getBBox()) + psfBBoxGoodSpansX, psfBBoxGoodSpansY = psfBBoxGoodSpans.indices() + psfBBoxData = psfBBoxGoodSpans.flatten(stampMI.image.array, stampMI.getXY0()) + psfBBoxModel = ( + psfBBoxGoodSpans.flatten(paddedPsfImage.array, stampMI.getXY0()) * scale + + pedestal + + psfBBoxGoodSpansX * xGradient + + psfBBoxGoodSpansY * yGradient + ) + psfBBoxVariance = psfBBoxGoodSpans.flatten(stampMI.variance.array, stampMI.getXY0()) + psfBBoxResiduals = (psfBBoxData - psfBBoxModel) ** 2 / psfBBoxVariance + psfBBoxChiSquared = np.sum(psfBBoxResiduals) + psfBBoxDegreesOfFreedom = len(psfBBoxData) - 4 + psfBBoxReducedChiSquared = psfBBoxChiSquared / psfBBoxDegreesOfFreedom + + return dict( + scale=scale, + scaleErr=scaleErr, + pedestal=pedestal, + pedestalErr=pedestalErr, + xGradient=xGradient, + yGradient=yGradient, + pedestalScaleCov=scalePedestalCov, + globalReducedChiSquared=globalReducedChiSquared, + globalDegreesOfFreedom=globalDegreesOfFreedom, + psfReducedChiSquared=psfBBoxReducedChiSquared, + psfDegreesOfFreedom=psfBBoxDegreesOfFreedom, + psfMaskedFluxFrac=psfMaskedFluxFrac, + ) From b3d73c3883784048a8cb803df2de67a86ef4047a Mon Sep 17 00:00:00 2001 From: Lee Kelvin Date: Fri, 24 Oct 2025 08:35:05 -0700 Subject: [PATCH 02/12] Merge updates to brightStarCutout.py from DM-48377 --- .../brightStarSubtraction/brightStarCutout.py | 361 ++++++++++++++---- 1 file changed, 288 insertions(+), 73 deletions(-) diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py index a9d58a762..3a678a5bf 100644 --- a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py @@ -23,13 +23,15 @@ __all__ = ["BrightStarCutoutConnections", "BrightStarCutoutConfig", "BrightStarCutoutTask"] +import math +from copy import deepcopy from typing import Any, Iterable, cast import astropy.units as u import numpy as np from astropy.coordinates import SkyCoord from astropy.table import Table -from lsst.afw.cameraGeom import FIELD_ANGLE, PIXELS +from lsst.afw.cameraGeom import FIELD_ANGLE, FOCAL_PLANE, PIXELS from lsst.afw.detection import Footprint, FootprintSet, Threshold from lsst.afw.geom import SkyWcs, SpanSet, makeModifiedWcs from lsst.afw.geom.transformFactory import makeTransform @@ -38,6 +40,7 @@ from lsst.daf.butler import DataCoordinate from lsst.geom import ( AffineTransform, + Angle, Box2I, Extent2D, Extent2I, @@ -148,8 +151,6 @@ class BrightStarCutoutConfig( NEIGHBOR_MASK_PLANE, ], ) - - # Cutout geometry stampSize = ListField[int]( doc="Size of the stamps to be extracted, in pixels.", default=(251, 251), @@ -178,6 +179,10 @@ class BrightStarCutoutConfig( "lanczos5": "Lanczos kernel of order 5", }, ) + scalePsfModel = Field[bool]( + doc="If True, uses a scale factor to bring the PSF model data to the same level as the star data.", + default=True, + ) # PSF Fitting useExtendedPsf = Field[bool]( @@ -196,6 +201,14 @@ class BrightStarCutoutConfig( doc="Maximum allowed fraction of masked PSF flux for PSF fitting to occur.", default=0.97, ) + fitIterations = Field[int]( + doc="Number of iterations over pedestal-gradient and scaling fit.", + default=5, + ) + offFrameMagLim = Field[float]( + doc="Stars fainter than this limit are only included if they appear within the frame boundaries.", + default=15.0, + ) # Misc loadReferenceObjectsConfig = ConfigField[LoadReferenceObjectsConfig]( @@ -232,6 +245,7 @@ def __init__(self, initInputs=None, *args, **kwargs): self.paddedStampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy( self.paddedStampRadius ) + self.modelScale = 1 def runQuantum(self, butlerQC, inputRefs, outputRefs): inputs = butlerQC.get(inputRefs) @@ -311,15 +325,18 @@ def run( # Loop over each bright star stamps, goodFracs, stamps_fitPsfResults = [], [], [] for starIndex, (obj, pixCoord) in enumerate(zip(refCatBright, pixCoords)): # type: ignore + # Excluding faint stars that are not within the frame. + if obj["mag"] > self.config.offFrameMagLim and not self.star_in_frame(pixCoord, bbox): + continue footprintIndex = associations.get(starIndex, None) stampMI = MaskedImageF(self.paddedStampBBox) # Set NEIGHBOR footprints in the mask plane if footprintIndex: neighborFootprints = [fp for i, fp in enumerate(allFootprints) if i != footprintIndex] - self._setFootprints(inputExposure, neighborFootprints, NEIGHBOR_MASK_PLANE) + self._setFootprints(inputMI, neighborFootprints, NEIGHBOR_MASK_PLANE) else: - self._setFootprints(inputExposure, allFootprints, NEIGHBOR_MASK_PLANE) + self._setFootprints(inputMI, allFootprints, NEIGHBOR_MASK_PLANE) # Define linear shifting to recenter stamps coordFocalPlaneTan = pixToFocalPlaneTan.applyForward(pixCoord) # center of warped star @@ -329,7 +346,7 @@ def run( pixToPolar = pixToFocalPlaneTan.then(shift).then(rotation) # Apply the warp to the star stamp (in-place) - warpImage(stampMI, inputExposure.maskedImage, pixToPolar, warpingControl) + warpImage(stampMI, inputMI, pixToPolar, warpingControl) # Trim to the base stamp size, check mask coverage, update metadata stampMI = stampMI[self.stampBBox] @@ -343,38 +360,61 @@ def run( psf = WarpedPsf(inputExposure.getPsf(), pixToPolar, warpingControl) constantPsf = KernelPsf(FixedKernel(psf.computeKernelImage(Point2D(0, 0)))) if self.config.useExtendedPsf: - psfImage = extendedPsf # Assumed to be warped, center at [0,0] + psfImage = deepcopy(extendedPsf) # Assumed to be warped, center at [0,0] else: psfImage = constantPsf.computeKernelImage(constantPsf.getAveragePosition()) + # TODO: maybe we want to generate a smaller psf in case the following happens? + # The following could happen for when the user chooses small stampSize ~(50, 50) + if ( + psfImage.array.shape[0] > stampMI.image.array.shape[0] + or psfImage.array.shape[1] > stampMI.image.array.shape[1] + ): + continue + # Computing an scale factor that brings the model to the similar level of the star. + self.computeModelScale(stampMI, psfImage) + psfImage.array *= self.modelScale # ####### model scale correction ######## + fitPsfResults = {} + if self.config.doFitPsf: fitPsfResults = self._fitPsf(stampMI, psfImage) stamps_fitPsfResults.append(fitPsfResults) # Save the stamp if the PSF fit was successful or no fit requested if fitPsfResults or not self.config.doFitPsf: + distance_mm, theta_angle = self.star_location_on_focal(pixCoord, detector) + stamp = BrightStarStamp( - maskedImage=stampMI, - # TODO: what to do about this PSF? + stamp_im=stampMI, psf=constantPsf, wcs=makeModifiedWcs(pixToPolar, wcs, False), visit=cast(int, dataId["visit"]), detector=cast(int, dataId["detector"]), - refId=obj["id"], - refMag=obj["mag"], + ref_id=obj["id"], + ref_mag=obj["mag"], position=pixCoord, + focal_plane_radius=distance_mm, + focal_plane_angle=theta_angle, # TODO: add the lsst.geom.Angle here scale=fitPsfResults.get("scale", None), - scaleErr=fitPsfResults.get("scaleErr", None), + scale_err=fitPsfResults.get("scaleErr", None), pedestal=fitPsfResults.get("pedestal", None), - pedestalErr=fitPsfResults.get("pedestalErr", None), - pedestalScaleCov=fitPsfResults.get("pedestalScaleCov", None), - xGradient=fitPsfResults.get("xGradient", None), - yGradient=fitPsfResults.get("yGradient", None), - globalReducedChiSquared=fitPsfResults.get("globalReducedChiSquared", None), - globalDegreesOfFreedom=fitPsfResults.get("globalDegreesOfFreedom", None), - psfReducedChiSquared=fitPsfResults.get("psfReducedChiSquared", None), - psfDegreesOfFreedom=fitPsfResults.get("psfDegreesOfFreedom", None), - psfMaskedFluxFrac=fitPsfResults.get("psfMaskedFluxFrac", None), + pedestal_err=fitPsfResults.get("pedestalErr", None), + pedestal_scale_cov=fitPsfResults.get("pedestalScaleCov", None), + gradient_x=fitPsfResults.get("xGradient", None), + gradient_y=fitPsfResults.get("yGradient", None), + global_reduced_chi_squared=fitPsfResults.get("globalReducedChiSquared", None), + global_degrees_of_freedom=fitPsfResults.get("globalDegreesOfFreedom", None), + psf_reduced_chi_squared=fitPsfResults.get("psfReducedChiSquared", None), + psf_degrees_of_freedom=fitPsfResults.get("psfDegreesOfFreedom", None), + psf_masked_flux_fraction=fitPsfResults.get("psfMaskedFluxFrac", None), + ) + print( + obj["mag"], + fitPsfResults.get("globalReducedChiSquared", None), + fitPsfResults.get("globalDegreesOfFreedom", None), + fitPsfResults.get("psfReducedChiSquared", None), + fitPsfResults.get("psfDegreesOfFreedom", None), + fitPsfResults.get("psfMaskedFluxFrac", None), ) stamps.append(stamp) @@ -395,6 +435,25 @@ def run( brightStarStamps = BrightStarStamps(stamps) return Struct(brightStarStamps=brightStarStamps) + def star_location_on_focal(self, pixCoord, detector): + star_focal_plane_coords = detector.transform(pixCoord, PIXELS, FOCAL_PLANE) + star_x_fp = star_focal_plane_coords.getX() + star_y_fp = star_focal_plane_coords.getY() + distance_mm = np.sqrt(star_x_fp**2 + star_y_fp**2) + theta_rad = math.atan2(star_y_fp, star_x_fp) + theta_angle = Angle(theta_rad, radians) + return distance_mm, theta_angle + + def star_in_frame(self, pixCoord, inputExposureBBox): + if ( + pixCoord[0] < 0 + or pixCoord[1] < 0 + or pixCoord[0] > inputExposureBBox.getDimensions()[0] + or pixCoord[1] > inputExposureBBox.getDimensions()[1] + ): + return False + return True + def _getRefCatBright(self, refObjLoader: ReferenceObjectLoader, wcs: SkyWcs, bbox: Box2I) -> Table: """Get a bright star subset of the reference catalog. @@ -578,73 +637,109 @@ def _fitPsf(self, stampMI: MaskedImageF, psfImage: ImageD | ImageF) -> dict[str, # Calculate the fraction of the PSF image flux masked by bad pixels psfMaskedPixels = ImageF(psfImage.getBBox()) psfMaskedPixels.array[:, :] = (stampMI.mask[psfImage.getBBox()].array & badMaskBitMask).astype(bool) - # TODO: This is np.float64, else FITS metadata serialization fails - psfMaskedFluxFrac = np.dot(psfImage.array.flat, psfMaskedPixels.array.flat).astype(np.float64) + psfMaskedFluxFrac = ( + np.dot(psfImage.array.flat, psfMaskedPixels.array.flat).astype(np.float64) / psfImage.array.sum() + ) if psfMaskedFluxFrac > self.config.psfMaskedFluxFracThreshold: return {} # Handle cases where the PSF image is mostly masked - # Create a padded version of the input constant PSF image - paddedPsfImage = ImageF(stampMI.getBBox()) - paddedPsfImage[psfImage.getBBox()] = psfImage.convertF() - - # Create consistently masked data - badSpans = SpanSet.fromMask(stampMI.mask, badMaskBitMask) - goodSpans = SpanSet(stampMI.getBBox()).intersectNot(badSpans) - varianceData = goodSpans.flatten(stampMI.variance.array, stampMI.getXY0()) + # Generating good spans for gradient-pedestal fitting (including the star DETECTED mask). + gradientGoodSpans = self.generate_gradient_spans(stampMI, badMaskBitMask) + varianceData = gradientGoodSpans.flatten(stampMI.variance.array, stampMI.getXY0()) if self.config.useMedianVariance: varianceData = np.median(varianceData) sigmaData = np.sqrt(varianceData) - imageData = goodSpans.flatten(stampMI.image.array, stampMI.getXY0()) # B - imageData /= sigmaData - psfData = goodSpans.flatten(paddedPsfImage.array, paddedPsfImage.getXY0()) - psfData /= sigmaData - # Fit the PSF scale factor and global pedestal + for i in range(self.config.fitIterations): + # Gradient-pedestal fitting: + if i: + # if i > 0, there should be scale factor from the previous fit iteration. Therefore, we can + # remove the star using the scale factor. + stamp = self.remove_star(stampMI, scale, paddedPsfImage) # noqa: F821 + else: + stamp = deepcopy(stampMI.image.array) + + imageDataGr = gradientGoodSpans.flatten(stamp, stampMI.getXY0()) / sigmaData # B + nData = len(imageDataGr) + coefficientMatrix = np.ones((nData, 3), dtype=float) # A + coefficientMatrix[:, 0] /= sigmaData + coefficientMatrix[:, 1:] = gradientGoodSpans.indices().T + coefficientMatrix[:, 1] /= sigmaData + coefficientMatrix[:, 2] /= sigmaData + + try: + grSolutions, grSumSquaredResiduals, *_ = np.linalg.lstsq( + coefficientMatrix, imageDataGr, rcond=None + ) + covarianceMatrix = np.linalg.inv( + np.dot(coefficientMatrix.transpose(), coefficientMatrix) + ) # C + except np.linalg.LinAlgError: + return {} # Handle singular matrix errors + if grSumSquaredResiduals.size == 0: + return {} # Handle cases where sum of the squared residuals are empty + + pedestal = grSolutions[0] + pedestalErr = np.sqrt(covarianceMatrix[0, 0]) + scalePedestalCov = None + xGradient = grSolutions[2] + yGradient = grSolutions[1] + + # Scale fitting: + updatedStampMI = deepcopy(stampMI) + self._removePedestalAndGradient(updatedStampMI, pedestal, xGradient, yGradient) + + # Create a padded version of the input constant PSF image + paddedPsfImage = ImageF(updatedStampMI.getBBox()) + paddedPsfImage[psfImage.getBBox()] = psfImage.convertF() + + # Generating a mask plane while considering bad pixels in the psf model. + mask = self.add_psf_mask(paddedPsfImage, updatedStampMI) + # Create consistently masked data + scaleGoodSpans = self.generate_good_spans(mask, updatedStampMI.getBBox(), badMaskBitMask) + + imageData = scaleGoodSpans.flatten(updatedStampMI.image.array, updatedStampMI.getXY0()) + psfData = scaleGoodSpans.flatten(paddedPsfImage.array, paddedPsfImage.getXY0()) + scaleCoefficientMatrix = psfData.reshape(psfData.shape[0], 1) + + try: + scaleSolution, scaleSumSquaredResiduals, *_ = np.linalg.lstsq( + scaleCoefficientMatrix, imageData, rcond=None + ) + except np.linalg.LinAlgError: + return {} # Handle singular matrix errors + if scaleSumSquaredResiduals.size == 0: + return {} # Handle cases where sum of the squared residuals are empty + scale = scaleSolution[0] + if scale <= 0: + return {} # Handle cases where the PSF scale fit has failed + # TODO: calculate scale error and store it. + scaleErr = None + + scale *= self.modelScale # ####### model scale correction ######## nData = len(imageData) - coefficientMatrix = np.ones((nData, 4), dtype=float) # A - coefficientMatrix[:, 0] = psfData - coefficientMatrix[:, 1] /= sigmaData - coefficientMatrix[:, 2:] = goodSpans.indices().T - coefficientMatrix[:, 2] /= sigmaData - coefficientMatrix[:, 3] /= sigmaData - try: - solutions, sumSquaredResiduals, *_ = np.linalg.lstsq(coefficientMatrix, imageData, rcond=None) - covarianceMatrix = np.linalg.inv(np.dot(coefficientMatrix.transpose(), coefficientMatrix)) # C - except np.linalg.LinAlgError: - return {} # Handle singular matrix errors - if sumSquaredResiduals.size == 0: - return {} # Handle cases where sum of the squared residuals are empty - scale = solutions[0] - if scale <= 0: - return {} # Handle cases where the PSF scale fit has failed - scaleErr = np.sqrt(covarianceMatrix[0, 0]) - pedestal = solutions[1] - pedestalErr = np.sqrt(covarianceMatrix[1, 1]) - scalePedestalCov = covarianceMatrix[0, 1] - xGradient = solutions[3] - yGradient = solutions[2] - - # Calculate global (whole image) reduced chi-squared - globalChiSquared = np.sum(sumSquaredResiduals) - globalDegreesOfFreedom = nData - 4 - globalReducedChiSquared = globalChiSquared / globalDegreesOfFreedom + + # Calculate global (whole image) reduced chi-squared (scaling fit is assumed as the main fitting + # process here.) + globalChiSquared = np.sum(scaleSumSquaredResiduals) + globalDegreesOfFreedom = nData - 1 + globalReducedChiSquared = np.float64(globalChiSquared / globalDegreesOfFreedom) # Calculate PSF BBox reduced chi-squared - psfBBoxGoodSpans = goodSpans.clippedTo(psfImage.getBBox()) - psfBBoxGoodSpansX, psfBBoxGoodSpansY = psfBBoxGoodSpans.indices() - psfBBoxData = psfBBoxGoodSpans.flatten(stampMI.image.array, stampMI.getXY0()) + psfBBoxscaleGoodSpans = scaleGoodSpans.clippedTo(psfImage.getBBox()) + psfBBoxscaleGoodSpansX, psfBBoxscaleGoodSpansY = psfBBoxscaleGoodSpans.indices() + psfBBoxData = psfBBoxscaleGoodSpans.flatten(stampMI.image.array, stampMI.getXY0()) + paddedPsfImage.array /= self.modelScale # ####### model scale correction ######## psfBBoxModel = ( - psfBBoxGoodSpans.flatten(paddedPsfImage.array, stampMI.getXY0()) * scale + psfBBoxscaleGoodSpans.flatten(paddedPsfImage.array, stampMI.getXY0()) * scale + pedestal - + psfBBoxGoodSpansX * xGradient - + psfBBoxGoodSpansY * yGradient + + psfBBoxscaleGoodSpansX * xGradient + + psfBBoxscaleGoodSpansY * yGradient ) - psfBBoxVariance = psfBBoxGoodSpans.flatten(stampMI.variance.array, stampMI.getXY0()) - psfBBoxResiduals = (psfBBoxData - psfBBoxModel) ** 2 / psfBBoxVariance + psfBBoxResiduals = (psfBBoxData - psfBBoxModel) ** 2 # / psfBBoxVariance psfBBoxChiSquared = np.sum(psfBBoxResiduals) - psfBBoxDegreesOfFreedom = len(psfBBoxData) - 4 + psfBBoxDegreesOfFreedom = len(psfBBoxData) - 1 psfBBoxReducedChiSquared = psfBBoxChiSquared / psfBBoxDegreesOfFreedom - return dict( scale=scale, scaleErr=scaleErr, @@ -659,3 +754,123 @@ def _fitPsf(self, stampMI: MaskedImageF, psfImage: ImageD | ImageF) -> dict[str, psfDegreesOfFreedom=psfBBoxDegreesOfFreedom, psfMaskedFluxFrac=psfMaskedFluxFrac, ) + + def add_psf_mask(self, psfImage, stampMI, maskZeros=True): + """ + Creates a new mask by adding PSF bad pixels to an existing stamp mask. + + This method identifies "bad" pixels in the PSF image (NaNs and + optionally zeros/non-positives) and adds them to a deep copy + of the input stamp's mask. + + Args: + psfImage : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` + The PSF image object. + stampMI: `~lsst.afw.image.MaskedImageF` + The masked image of the bright star cutout. + maskZeros (bool, optional): If True (default), mask pixels + where the PSF is <= 0. If False, only mask pixels < 0. + + Returns: + Any: A new mask object (deep copy) with the PSF mask planes added. + """ + cond = np.isnan(psfImage.array) + if maskZeros: + cond |= psfImage.array <= 0 + else: + cond |= psfImage.array < 0 + mask = deepcopy(stampMI.mask) + mask.array[cond] = np.bitwise_or(mask.array[cond], 1) + return mask + + def _removePedestalAndGradient(self, stampMI, pedestal, xGradient, yGradient): + """Apply fitted pedestal and gradients to a single bright star stamp.""" + stampBBox = stampMI.getBBox() + xGrid, yGrid = np.meshgrid(stampBBox.getX().arange(), stampBBox.getY().arange()) + xPlane = ImageF((xGrid * xGradient).astype(np.float32), xy0=stampMI.getXY0()) + yPlane = ImageF((yGrid * yGradient).astype(np.float32), xy0=stampMI.getXY0()) + stampMI -= pedestal + stampMI -= xPlane + stampMI -= yPlane + + def remove_star(self, stampMI, scale, psfImage): + """ + Subtracts a scaled PSF model from a star image. + + This performs a simple subtraction: `image - (psf * scale)`. + + Args: + stampMI: `~lsst.afw.image.MaskedImageF` + The masked image of the bright star cutout. + scale (float): The scaling factor to apply to the PSF. + psfImage: `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` + The PSF image object. + + Returns: + np.ndarray: A new 2D numpy array containing the star-subtracted + image. + """ + star_removed_cutout = stampMI.image.array - psfImage.array * scale + return star_removed_cutout + + def computeModelScale(self, stampMI, psfImage): + """ + Computes the scaling factor of the given model against a star. + + Args: + stampMI : `~lsst.afw.image.MaskedImageF` + The masked image of the bright star cutout. + psfImage : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` + The given PSF model. + """ + cond = stampMI.mask.array == 0 + self.starMedianValue = np.median(stampMI.image.array[cond]).astype(np.float64) + + psfPos = psfImage.array > 0 + + imageArray = stampMI.image.array - self.starMedianValue + imageArrayPos = imageArray > 0 + self.modelScale = np.nanmean(imageArray[imageArrayPos]) / np.nanmean(psfImage.array[psfPos]) + + def generate_gradient_spans(self, stampMI, badMaskBitMask): + """ + Generates spans of "good" pixels for gradient fitting. + + This method creates a combined bitmask by OR-ing the provided + `badMaskBitMask` with the "DETECTED" plane from the stamp's mask. + It then calls `self.generate_good_spans` to find all pixel spans + not covered by this combined mask. + + Args: + stampMI: `~lsst.afw.image.MaskedImageF` + The masked image of the bright star cutout. + badMaskBitMask (int): A bitmask representing planes to be + considered "bad" for gradient fitting. + + Returns: + gradientGoodSpans: A SpanSet object containing the "good" spans. + """ + detectedMaskBitMask = stampMI.mask.getPlaneBitMask("DETECTED") + gradientBitMask = np.bitwise_or(badMaskBitMask, detectedMaskBitMask) + + gradientGoodSpans = self.generate_good_spans(stampMI.mask, stampMI.getBBox(), gradientBitMask) + return gradientGoodSpans + + def generate_good_spans(self, mask, bBox, badBitMask): + """ + Generates a SpanSet of "good" pixels from a mask. + + This method identifies all spans within a given bounding box (`bBox`) + that are *not* flagged by the `badBitMask` in the provided `mask`. + + Args: + mask (lsst.afw.image.MaskedImageF.mask): The mask object (e.g., `stampMI.mask`). + bBox (lsst.geom.Box2I): The bounding box of the image (e.g., `stampMI.getBBox()`). + badBitMask (int): The combined bitmask of planes to exclude. + + Returns: + goodSpans: A SpanSet object representing all "good" spans. + """ + badSpans = SpanSet.fromMask(mask, badBitMask) + goodSpans = SpanSet(bBox).intersectNot(badSpans) + return goodSpans From fb6221a41e0a89aedff54b8f82be59cfceb364da Mon Sep 17 00:00:00 2001 From: Lee Kelvin Date: Fri, 24 Oct 2025 12:00:40 -0700 Subject: [PATCH 03/12] Restore test_brightStarCutout.py from original effort --- tests/test_brightStarCutout.py | 102 +++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 tests/test_brightStarCutout.py diff --git a/tests/test_brightStarCutout.py b/tests/test_brightStarCutout.py new file mode 100644 index 000000000..67b88d02f --- /dev/null +++ b/tests/test_brightStarCutout.py @@ -0,0 +1,102 @@ +# This file is part of pipe_tasks. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import unittest + +import lsst.afw.cameraGeom.testUtils +import lsst.afw.image +import lsst.utils.tests +import numpy as np +from lsst.afw.image import ImageD, ImageF, MaskedImageF +from lsst.afw.math import FixedKernel +from lsst.geom import Point2I +from lsst.meas.algorithms import KernelPsf +from lsst.pipe.tasks.brightStarSubtraction import BrightStarCutoutConfig, BrightStarCutoutTask + + +class BrightStarCutoutTestCase(lsst.utils.tests.TestCase): + def setUp(self): + # Fit values + self.scale = 2.34e5 + self.pedestal = 3210.1 + self.xGradient = 5.432 + self.yGradient = 10.987 + + # Create a pedestal + 2D plane + xCoords = np.linspace(-50, 50, 101) + yCoords = np.linspace(-50, 50, 101) + xPlane, yPlane = np.meshgrid(xCoords, yCoords) + pedestal = np.ones_like(xPlane) * self.pedestal + + # Create a pseudo-PSF + dist_from_center = np.sqrt(xPlane**2 + yPlane**2) + psfArray = np.exp(-dist_from_center / 5) + psfArray /= np.sum(psfArray) + fixedKernel = FixedKernel(ImageD(psfArray)) + self.psf = KernelPsf(fixedKernel) + + # Bring everything together to construct a stamp masked image + stampArray = psfArray * self.scale + pedestal + xPlane * self.xGradient + yPlane * self.yGradient + stampIm = ImageF((stampArray).astype(np.float32)) + stampVa = ImageF(stampIm.getBBox(), 654.321) + self.stampMI = MaskedImageF(image=stampIm, variance=stampVa) + self.stampMI.setXY0(Point2I(-50, -50)) + + # Ensure that all mask planes required by the task are in-place; + # new mask plane entries will be created as necessary + badMaskPlanes = [ + "BAD", + "CR", + "CROSSTALK", + "EDGE", + "NO_DATA", + "SAT", + "SUSPECT", + "UNMASKEDNAN", + "NEIGHBOR", + ] + _ = [self.stampMI.mask.addMaskPlane(mask) for mask in badMaskPlanes] + + def test_fitPsf(self): + """Test the PSF fitting method.""" + brightStarCutoutConfig = BrightStarCutoutConfig() + brightStarCutoutTask = BrightStarCutoutTask(config=brightStarCutoutConfig) + fitPsfResults = brightStarCutoutTask._fitPsf( + self.stampMI, + self.psf, + ) + self.assertAlmostEqual(fitPsfResults["scale"], self.scale, delta=1e-3) + self.assertAlmostEqual(fitPsfResults["pedestal"], self.pedestal, delta=1e-5) + self.assertAlmostEqual(fitPsfResults["xGradient"], self.xGradient, delta=1e-7) + self.assertAlmostEqual(fitPsfResults["yGradient"], self.yGradient, delta=1e-7) + + +def setup_module(module): + lsst.utils.tests.init() + + +class MemoryTestCase(lsst.utils.tests.MemoryTestCase): + pass + + +if __name__ == "__main__": + lsst.utils.tests.init() + unittest.main() From 4c9a95b22d11b9d722d3ef261e997a7fdd98fa9e Mon Sep 17 00:00:00 2001 From: Lee Kelvin Date: Fri, 24 Oct 2025 12:02:58 -0700 Subject: [PATCH 04/12] Merge updates to test_brightStarCutout.py from DM-48377 --- tests/test_brightStarCutout.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/test_brightStarCutout.py b/tests/test_brightStarCutout.py index 67b88d02f..88da41440 100644 --- a/tests/test_brightStarCutout.py +++ b/tests/test_brightStarCutout.py @@ -51,7 +51,8 @@ def setUp(self): psfArray = np.exp(-dist_from_center / 5) psfArray /= np.sum(psfArray) fixedKernel = FixedKernel(ImageD(psfArray)) - self.psf = KernelPsf(fixedKernel) + psf = KernelPsf(fixedKernel) + self.psf = psf.computeKernelImage(psf.getAveragePosition()) # Bring everything together to construct a stamp masked image stampArray = psfArray * self.scale + pedestal + xPlane * self.xGradient + yPlane * self.yGradient @@ -83,10 +84,10 @@ def test_fitPsf(self): self.stampMI, self.psf, ) - self.assertAlmostEqual(fitPsfResults["scale"], self.scale, delta=1e-3) - self.assertAlmostEqual(fitPsfResults["pedestal"], self.pedestal, delta=1e-5) - self.assertAlmostEqual(fitPsfResults["xGradient"], self.xGradient, delta=1e-7) - self.assertAlmostEqual(fitPsfResults["yGradient"], self.yGradient, delta=1e-7) + assert abs(fitPsfResults["scale"] - self.scale) / self.scale < 1e-6 + assert abs(fitPsfResults["pedestal"] - self.pedestal) / self.pedestal < 1e-6 + assert abs(fitPsfResults["xGradient"] - self.xGradient) / self.xGradient < 1e-6 + assert abs(fitPsfResults["yGradient"] - self.yGradient) / self.yGradient < 1e-6 def setup_module(module): From ae5ed283cb3c2de8d6cf514526e9023dd71e65c7 Mon Sep 17 00:00:00 2001 From: Lee Kelvin Date: Tue, 28 Oct 2025 08:45:56 -0700 Subject: [PATCH 05/12] Mid-refactor --- .../brightStarSubtraction/brightStarCutout.py | 197 +++++++++--------- 1 file changed, 103 insertions(+), 94 deletions(-) diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py index 3a678a5bf..2c612fe7e 100644 --- a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py @@ -73,7 +73,7 @@ class BrightStarCutoutConnections( ): """Connections for BrightStarCutoutTask.""" - refCat = PrerequisiteInput( + ref_cat = PrerequisiteInput( name="gaia_dr3_20230707", storageClass="SimpleCatalog", doc="Reference catalog that contains bright star positions.", @@ -81,26 +81,26 @@ class BrightStarCutoutConnections( multiple=True, deferLoad=True, ) - inputExposure = Input( - name="calexp", + input_image = Input( + name="preliminary_visit_image", storageClass="ExposureF", doc="Background-subtracted input exposure from which to extract bright star stamp cutouts.", dimensions=("visit", "detector"), ) - inputBackground = Input( - name="calexpBackground", + input_background = Input( + name="preliminary_visit_image_background", storageClass="Background", doc="Background model for the input exposure, to be added back on during processing.", dimensions=("visit", "detector"), ) - extendedPsf = Input( - name="extendedPsf2", + extended_psf = Input( + name="extended_psf", storageClass="ImageF", doc="Extended PSF model, built from stacking bright star cutouts.", dimensions=("band",), ) - brightStarStamps = Output( - name="brightStarStamps", + bright_star_stamps = Output( + name="bright_star_stamps", storageClass="BrightStarStamps", doc="Set of preprocessed postage stamp cutouts, each centered on a single bright star.", dimensions=("visit", "detector"), @@ -110,7 +110,7 @@ def __init__(self, *, config: "BrightStarCutoutConfig | None" = None): super().__init__(config=config) assert config is not None if not config.useExtendedPsf: - self.inputs.remove("extendedPsf") + self.inputs.remove("extended_psf") class BrightStarCutoutConfig( @@ -124,18 +124,22 @@ class BrightStarCutoutConfig( doc="Magnitude range in Gaia G. Cutouts will be made for all stars in this range.", default=[0, 18], ) - excludeArcsecRadius = Field[float]( - doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeArcsecRadius`` are not used.", + excludeRadiusArcsec = Field[float]( + doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeRadiusArcsec`` are not used.", default=5, ) excludeMagRange = ListField[float]( - doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeArcsecRadius`` are not used.", + doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeRadiusArcsec`` are not used.", default=[0, 20], ) minAreaFraction = Field[float]( doc="Minimum fraction of the stamp area, post-masking, that must remain for a cutout to be retained.", default=0.1, ) + # offFrameMagLim = Field[float]( + # doc="Stars fainter than this limit are only included if they appear within the frame boundaries.", + # default=15.0, + # ) badMaskPlanes = ListField[str]( doc="Mask planes that identify excluded pixels for the calculation of ``minAreaFraction`` and, " "optionally, fitting of the PSF.", @@ -151,6 +155,8 @@ class BrightStarCutoutConfig( NEIGHBOR_MASK_PLANE, ], ) + + # Stamp configuration stampSize = ListField[int]( doc="Size of the stamps to be extracted, in pixels.", default=(251, 251), @@ -179,18 +185,18 @@ class BrightStarCutoutConfig( "lanczos5": "Lanczos kernel of order 5", }, ) - scalePsfModel = Field[bool]( - doc="If True, uses a scale factor to bring the PSF model data to the same level as the star data.", - default=True, - ) + # scalePsfModel = Field[bool]( + # doc="If True, uses a scale factor to bring the PSF model data to the same level as the star data.", + # default=True, + # ) # PSF Fitting useExtendedPsf = Field[bool]( - doc="Use the extended PSF model to normalize bright star cutouts.", + doc="Use the extended PSF model to estimate the bright star cutout normalization factor.", default=False, ) doFitPsf = Field[bool]( - doc="Fit a scaled PSF and a pedestal to each bright star cutout.", + doc="Fit a scaled PSF and a simple background to each bright star cutout.", default=True, ) useMedianVariance = Field[bool]( @@ -202,13 +208,9 @@ class BrightStarCutoutConfig( default=0.97, ) fitIterations = Field[int]( - doc="Number of iterations over pedestal-gradient and scaling fit.", + doc="Number of iterations to constrain PSF fitting.", default=5, ) - offFrameMagLim = Field[float]( - doc="Stars fainter than this limit are only included if they appear within the frame boundaries.", - default=15.0, - ) # Misc loadReferenceObjectsConfig = ConfigField[LoadReferenceObjectsConfig]( @@ -226,7 +228,7 @@ class BrightStarCutoutTask(PipelineTask): catalog and extracts a stamp around each. Second, it shifts and warps each stamp to remove optical distortions and sample all stars on the same pixel grid. - Finally, it optionally fits a PSF plus plane flux model to the cutout. + Finally, it optionally fits a PSF and a simple background model. This final fitting procedure may be used to normalize each bright star stamp prior to stacking when producing extended PSF models. """ @@ -237,39 +239,39 @@ class BrightStarCutoutTask(PipelineTask): def __init__(self, initInputs=None, *args, **kwargs): super().__init__(*args, **kwargs) - stampSize = Extent2D(*self.config.stampSize.list()) - stampRadius = floor(stampSize / 2) - self.stampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy(stampRadius) - paddedStampSize = stampSize * self.config.stampSizePadding - self.paddedStampRadius = floor(paddedStampSize / 2) - self.paddedStampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy( - self.paddedStampRadius + stamp_size = Extent2D(*self.config.stampSize.list()) + stamp_radius = floor(stamp_size / 2) + self.stamp_bbox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy(stamp_radius) + padded_stamp_size = stamp_size * self.config.stampSizePadding + self.padded_stamp_radius = floor(padded_stamp_size / 2) + self.padded_stamp_bbox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy( + self.padded_stamp_radius ) - self.modelScale = 1 - - def runQuantum(self, butlerQC, inputRefs, outputRefs): - inputs = butlerQC.get(inputRefs) - inputs["dataId"] = butlerQC.quantum.dataId - refObjLoader = ReferenceObjectLoader( - dataIds=[ref.datasetRef.dataId for ref in inputRefs.refCat], - refCats=inputs.pop("refCat"), - name=self.config.connections.refCat, + # self.modelScale = 1 + + def runQuantum(self, butlerQC, input_refs, output_refs): + inputs = butlerQC.get(input_refs) + inputs["data_id"] = butlerQC.quantum.dataId + ref_obj_loader = ReferenceObjectLoader( + dataIds=[ref.datasetRef.dataId for ref in input_refs.ref_cat], + refCats=inputs.pop("ref_cat"), + name=self.config.connections.ref_cat, config=self.config.loadReferenceObjectsConfig, ) - extendedPsf = inputs.pop("extendedPsf", None) - output = self.run(**inputs, extendedPsf=extendedPsf, refObjLoader=refObjLoader) + extended_psf = inputs.pop("extended_psf", None) + output = self.run(**inputs, extended_psf=extended_psf, ref_obj_loader=ref_obj_loader) # Only ingest Stamp if it exists; prevents ingesting an empty FITS file if output: - butlerQC.put(output, outputRefs) + butlerQC.put(output, output_refs) @timeMethod def run( self, - inputExposure: ExposureF, - inputBackground: BackgroundList, - extendedPsf: ImageF | None, - refObjLoader: ReferenceObjectLoader, - dataId: dict[str, Any] | DataCoordinate, + input_image: ExposureF, + input_background: BackgroundList, + extended_psf: ImageF | None, + ref_obj_loader: ReferenceObjectLoader, + data_id: dict[str, Any] | DataCoordinate, ): """Identify bright stars within an exposure using a reference catalog, extract stamps around each, warp/shift stamps onto a common frame and @@ -277,29 +279,32 @@ def run( Parameters ---------- - inputExposure : `~lsst.afw.image.ExposureF` - The background-subtracted image to extract bright star stamps. - inputBackground : `~lsst.afw.math.BackgroundList` + input_image : `~lsst.afw.image.ExposureF` + The background-subtracted image to extract bright star stamps from. + input_background : `~lsst.afw.math.BackgroundList` The background model associated with the input exposure. - refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional + extended_psf : `~lsst.afw.image.ImageF` | `None` + The extended PSF model, built from stacking bright star cutouts. + ref_obj_loader : + `~lsst.meas.algorithms.ReferenceObjectLoader`, optional Loader to find objects within a reference catalog. - dataId : `dict` or `~lsst.daf.butler.DataCoordinate` - The dataId of the exposure that bright stars are extracted from. + data_id : `dict` or `~lsst.daf.butler.DataCoordinate` + The data ID of the detector that bright stars are extracted from. Both 'visit' and 'detector' will be persisted in the output data. Returns ------- - brightStarResults : `~lsst.pipe.base.Struct` + bright_star_stamps_results : `~lsst.pipe.base.Struct` Results as a struct with attributes: - ``brightStarStamps`` + ``bright_star_stamps`` (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`) """ - wcs = inputExposure.getWcs() - bbox = inputExposure.getBBox() - warpingControl = WarpingControl(self.config.warpingKernelName, self.config.maskWarpingKernelName) + wcs = input_image.getWcs() + bbox = input_image.getBBox() - refCatBright = self._getRefCatBright(refObjLoader, wcs, bbox) + # Get reference catalog stars + ref_cat = self._get_ref_cat(ref_obj_loader, wcs, bbox) zipRaDec = zip(refCatBright["coord_ra"] * radians, refCatBright["coord_dec"] * radians) spherePoints = [SpherePoint(ra, dec) for ra, dec in zipRaDec] pixCoords = wcs.skyToPixel(spherePoints) @@ -323,6 +328,7 @@ def run( ) # Loop over each bright star + warpingControl = WarpingControl(self.config.warpingKernelName, self.config.maskWarpingKernelName) stamps, goodFracs, stamps_fitPsfResults = [], [], [] for starIndex, (obj, pixCoord) in enumerate(zip(refCatBright, pixCoords)): # type: ignore # Excluding faint stars that are not within the frame. @@ -454,8 +460,8 @@ def star_in_frame(self, pixCoord, inputExposureBBox): return False return True - def _getRefCatBright(self, refObjLoader: ReferenceObjectLoader, wcs: SkyWcs, bbox: Box2I) -> Table: - """Get a bright star subset of the reference catalog. + def _get_ref_cat(self, ref_obj_loader: ReferenceObjectLoader, wcs: SkyWcs, bbox: Box2I) -> Table: + """Get a subset of the reference catalog. Trim the reference catalog to only those objects within the exposure bounding box dilated by half the bright star stamp size. @@ -463,48 +469,51 @@ def _getRefCatBright(self, refObjLoader: ReferenceObjectLoader, wcs: SkyWcs, bbo Parameters ---------- - refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader` + ref_obj_loader : `~lsst.meas.algorithms.ReferenceObjectLoader` Loader to find objects within a reference catalog. wcs : `~lsst.afw.geom.SkyWcs` World coordinate system. bbox : `~lsst.geom.Box2I` - Bounding box of the exposure. + Bounding box of the image. Returns ------- - refCatBright : `~astropy.table.Table` - Bright star subset of the reference catalog. + ref_cat : `~astropy.table.Table` + Subset of the reference catalog. """ - dilatedBBox = bbox.dilatedBy(self.paddedStampRadius) - withinExposure = refObjLoader.loadPixelBox(dilatedBBox, wcs, filterName="phot_g_mean") - refCatFull = withinExposure.refCat - fluxField: str = withinExposure.fluxField - - proxFluxRange = sorted(((self.config.excludeMagRange * u.ABmag).to(u.nJy)).to_value()) - brightFluxRange = sorted(((self.config.magRange * u.ABmag).to(u.nJy)).to_value()) - - subsetStars = (refCatFull[fluxField] > np.min((proxFluxRange[0], brightFluxRange[0]))) & ( - refCatFull[fluxField] < np.max((proxFluxRange[1], brightFluxRange[1])) + # Get all stars within a dilated bbox + dilated_bbox = bbox.dilatedBy(self.padded_stamp_radius) + within_dilated_bbox = ref_obj_loader.loadPixelBox(dilated_bbox, wcs, filterName="phot_g_mean") + ref_cat_full = within_dilated_bbox.refCat + flux_field: str = within_dilated_bbox.fluxField + + # Trim to stars within the desired magnitude range + flux_range_nearby = sorted(((self.config.excludeMagRange * u.ABmag).to(u.nJy)).to_value()) + flux_range_bright = sorted(((self.config.magRange * u.ABmag).to(u.nJy)).to_value()) + stars_magnitude_limited = ( + ref_cat_full[flux_field] > np.min((flux_range_nearby[0], flux_range_bright[0])) + ) & (ref_cat_full[flux_field] < np.max((flux_range_nearby[1], flux_range_bright[1]))) + ref_cat_subset = Table( + ref_cat_full.extract("id", "coord_ra", "coord_dec", flux_field, where=stars_magnitude_limited) ) - refCatSubset = Table(refCatFull.extract("id", "coord_ra", "coord_dec", fluxField, where=subsetStars)) - - proxStars = (refCatSubset[fluxField] >= proxFluxRange[0]) & ( - refCatSubset[fluxField] <= proxFluxRange[1] + stars_nearby = (ref_cat_subset[flux_field] >= flux_range_nearby[0]) & ( + ref_cat_subset[flux_field] <= flux_range_nearby[1] ) - brightStars = (refCatSubset[fluxField] >= brightFluxRange[0]) & ( - refCatSubset[fluxField] <= brightFluxRange[1] + stars_bright = (ref_cat_subset[flux_field] >= flux_range_bright[0]) & ( + ref_cat_subset[flux_field] <= flux_range_bright[1] ) - coords = SkyCoord(refCatSubset["coord_ra"], refCatSubset["coord_dec"], unit="rad") - excludeArcsecRadius = self.config.excludeArcsecRadius * u.arcsec # type: ignore - refCatBrightIsolated = [] - for coord in cast(Iterable[SkyCoord], coords[brightStars]): - neighbors = coords[proxStars] - seps = coord.separation(neighbors).to(u.arcsec) - tooClose = (seps > 0) & (seps <= excludeArcsecRadius) # not self matched - refCatBrightIsolated.append(not tooClose.any()) - - refCatBright = cast(Table, refCatSubset[brightStars][refCatBrightIsolated]) + # Exclude stars with bright enough neighbors in a specified radius + coords = SkyCoord(ref_cat_subset["coord_ra"], ref_cat_subset["coord_dec"], unit="rad") + exclude_radius_arcsec = self.config.excludeRadiusArcsec * u.arcsec + ref_cat_bright_isolated = [] + for coord in cast(Iterable[SkyCoord], coords[stars_bright]): + neighbors = coords[stars_nearby] + separations = coord.separation(neighbors).to(u.arcsec) + too_close = (separations > 0) & (separations <= exclude_radius_arcsec) # ensure not self matched + ref_cat_bright_isolated.append(not too_close.any()) + ref_cat_bright = cast(Table, ref_cat_subset[stars_bright][ref_cat_bright_isolated]) + breakpoint() fluxNanojansky = refCatBright[fluxField][:] * u.nJy # type: ignore refCatBright["mag"] = fluxNanojansky.to(u.ABmag).to_value() # AB magnitudes @@ -652,7 +661,7 @@ def _fitPsf(self, stampMI: MaskedImageF, psfImage: ImageD | ImageF) -> dict[str, for i in range(self.config.fitIterations): # Gradient-pedestal fitting: - if i: + if i > 0: # if i > 0, there should be scale factor from the previous fit iteration. Therefore, we can # remove the star using the scale factor. stamp = self.remove_star(stampMI, scale, paddedPsfImage) # noqa: F821 From 19b3dd16a9b3819c282dfe324ed28d4378c5fe0e Mon Sep 17 00:00:00 2001 From: Amir Ebadati Bazkiaei Date: Fri, 5 Dec 2025 02:32:55 +0000 Subject: [PATCH 06/12] implementing non-linear fitting and cleaning up --- .../tasks/brightStarSubtraction/__init__.py | 2 + .../brightStarSubtraction/brightStarCutout.py | 746 ++++++++++-------- .../brightStarSubtraction/brightStarStack.py | 292 +++++++ tests/test_brightStarCutout.py | 50 +- 4 files changed, 730 insertions(+), 360 deletions(-) create mode 100644 python/lsst/pipe/tasks/brightStarSubtraction/__init__.py create mode 100644 python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py b/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py new file mode 100644 index 000000000..fe5088369 --- /dev/null +++ b/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py @@ -0,0 +1,2 @@ +from .brightStarCutout import * +from .brightStarStack import * diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py index 2c612fe7e..fdc4f50ba 100644 --- a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py @@ -23,15 +23,13 @@ __all__ = ["BrightStarCutoutConnections", "BrightStarCutoutConfig", "BrightStarCutoutTask"] -import math -from copy import deepcopy from typing import Any, Iterable, cast import astropy.units as u import numpy as np from astropy.coordinates import SkyCoord from astropy.table import Table -from lsst.afw.cameraGeom import FIELD_ANGLE, FOCAL_PLANE, PIXELS +from lsst.afw.cameraGeom import FIELD_ANGLE, PIXELS, FOCAL_PLANE from lsst.afw.detection import Footprint, FootprintSet, Threshold from lsst.afw.geom import SkyWcs, SpanSet, makeModifiedWcs from lsst.afw.geom.transformFactory import makeTransform @@ -40,7 +38,6 @@ from lsst.daf.butler import DataCoordinate from lsst.geom import ( AffineTransform, - Angle, Box2I, Extent2D, Extent2I, @@ -50,6 +47,7 @@ arcseconds, floor, radians, + Angle, ) from lsst.meas.algorithms import ( BrightStarStamp, @@ -63,6 +61,9 @@ from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct from lsst.pipe.base.connectionTypes import Input, Output, PrerequisiteInput from lsst.utils.timer import timeMethod +from copy import deepcopy +import math + NEIGHBOR_MASK_PLANE = "NEIGHBOR" @@ -73,7 +74,7 @@ class BrightStarCutoutConnections( ): """Connections for BrightStarCutoutTask.""" - ref_cat = PrerequisiteInput( + refCat = PrerequisiteInput( name="gaia_dr3_20230707", storageClass="SimpleCatalog", doc="Reference catalog that contains bright star positions.", @@ -81,26 +82,26 @@ class BrightStarCutoutConnections( multiple=True, deferLoad=True, ) - input_image = Input( - name="preliminary_visit_image", + inputExposure = Input( + name="calexp", storageClass="ExposureF", doc="Background-subtracted input exposure from which to extract bright star stamp cutouts.", dimensions=("visit", "detector"), ) - input_background = Input( - name="preliminary_visit_image_background", + inputBackground = Input( + name="calexpBackground", storageClass="Background", doc="Background model for the input exposure, to be added back on during processing.", dimensions=("visit", "detector"), ) - extended_psf = Input( - name="extended_psf", + extendedPsf = Input( + name="extendedPsf2", storageClass="ImageF", doc="Extended PSF model, built from stacking bright star cutouts.", dimensions=("band",), ) - bright_star_stamps = Output( - name="bright_star_stamps", + brightStarStamps = Output( + name="brightStarStamps", storageClass="BrightStarStamps", doc="Set of preprocessed postage stamp cutouts, each centered on a single bright star.", dimensions=("visit", "detector"), @@ -109,8 +110,8 @@ class BrightStarCutoutConnections( def __init__(self, *, config: "BrightStarCutoutConfig | None" = None): super().__init__(config=config) assert config is not None - if not config.useExtendedPsf: - self.inputs.remove("extended_psf") + if not config.use_extended_psf: + self.inputs.remove("extendedPsf") class BrightStarCutoutConfig( @@ -120,28 +121,26 @@ class BrightStarCutoutConfig( """Configuration parameters for BrightStarCutoutTask.""" # Star selection - magRange = ListField[float]( + mag_range = ListField[float]( doc="Magnitude range in Gaia G. Cutouts will be made for all stars in this range.", default=[0, 18], ) - excludeRadiusArcsec = Field[float]( - doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeRadiusArcsec`` are not used.", + exclude_arcsec_radius = Field[float]( + doc="Stars with a star in the range ``exclude_mag_range`` mag in ``exclude_arcsec_radius`` are not " + "used.", default=5, ) - excludeMagRange = ListField[float]( - doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeRadiusArcsec`` are not used.", + exclude_mag_range = ListField[float]( + doc="Stars with a star in the range ``exclude_mag_range`` mag in ``exclude_arcsec_radius`` are not " + "used.", default=[0, 20], ) - minAreaFraction = Field[float]( + min_area_fraction = Field[float]( doc="Minimum fraction of the stamp area, post-masking, that must remain for a cutout to be retained.", default=0.1, ) - # offFrameMagLim = Field[float]( - # doc="Stars fainter than this limit are only included if they appear within the frame boundaries.", - # default=15.0, - # ) - badMaskPlanes = ListField[str]( - doc="Mask planes that identify excluded pixels for the calculation of ``minAreaFraction`` and, " + bad_mask_planes = ListField[str]( + doc="Mask planes that identify excluded pixels for the calculation of ``min_area_fraction`` and, " "optionally, fitting of the PSF.", default=[ "BAD", @@ -155,17 +154,15 @@ class BrightStarCutoutConfig( NEIGHBOR_MASK_PLANE, ], ) - - # Stamp configuration - stampSize = ListField[int]( + stamp_size = ListField[int]( doc="Size of the stamps to be extracted, in pixels.", default=(251, 251), ) - stampSizePadding = Field[float]( + stamp_size_padding = Field[float]( doc="Multiplicative factor applied to the cutout stamp size, to guard against post-warp data loss.", default=1.1, ) - warpingKernelName = ChoiceField[str]( + warping_kernel_name = ChoiceField[str]( doc="Warping kernel.", default="lanczos5", allowed={ @@ -175,7 +172,7 @@ class BrightStarCutoutConfig( "lanczos5": "Lanczos kernel of order 5", }, ) - maskWarpingKernelName = ChoiceField[str]( + mask_warping_kernel_name = ChoiceField[str]( doc="Warping kernel for mask.", default="bilinear", allowed={ @@ -185,35 +182,36 @@ class BrightStarCutoutConfig( "lanczos5": "Lanczos kernel of order 5", }, ) - # scalePsfModel = Field[bool]( - # doc="If True, uses a scale factor to bring the PSF model data to the same level as the star data.", - # default=True, - # ) + off_frame_mag_lim = Field[float]( + doc="Stars fainter than this limit are only included if they appear within the frame boundaries.", + default=15.0, + ) # PSF Fitting - useExtendedPsf = Field[bool]( - doc="Use the extended PSF model to estimate the bright star cutout normalization factor.", + use_extended_psf = Field[bool]( + doc="Use the extended PSF model to normalize bright star cutouts.", default=False, ) - doFitPsf = Field[bool]( - doc="Fit a scaled PSF and a simple background to each bright star cutout.", + do_fit_psf = Field[bool]( + doc="Fit a scaled PSF and a pedestal to each bright star cutout.", default=True, ) - useMedianVariance = Field[bool]( + use_median_variance = Field[bool]( doc="Use the median of the variance plane for PSF fitting.", default=False, ) - psfMaskedFluxFracThreshold = Field[float]( + psf_masked_flux_frac_threshold = Field[float]( doc="Maximum allowed fraction of masked PSF flux for PSF fitting to occur.", default=0.97, ) - fitIterations = Field[int]( - doc="Number of iterations to constrain PSF fitting.", + fit_iterations = Field[int]( + doc="Number of iterations over pedestal-gradient and scaling fit.", default=5, ) # Misc - loadReferenceObjectsConfig = ConfigField[LoadReferenceObjectsConfig]( + + load_reference_objects_config = ConfigField[LoadReferenceObjectsConfig]( doc="Reference object loader for astrometric calibration.", ) @@ -228,7 +226,7 @@ class BrightStarCutoutTask(PipelineTask): catalog and extracts a stamp around each. Second, it shifts and warps each stamp to remove optical distortions and sample all stars on the same pixel grid. - Finally, it optionally fits a PSF and a simple background model. + Finally, it optionally fits a PSF plus plane flux model to the cutout. This final fitting procedure may be used to normalize each bright star stamp prior to stacking when producing extended PSF models. """ @@ -237,41 +235,41 @@ class BrightStarCutoutTask(PipelineTask): _DefaultName = "brightStarCutout" config: BrightStarCutoutConfig - def __init__(self, initInputs=None, *args, **kwargs): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - stamp_size = Extent2D(*self.config.stampSize.list()) + stamp_size = Extent2D(*self.config.stamp_size.list()) stamp_radius = floor(stamp_size / 2) self.stamp_bbox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy(stamp_radius) - padded_stamp_size = stamp_size * self.config.stampSizePadding + padded_stamp_size = stamp_size * self.config.stamp_size_padding self.padded_stamp_radius = floor(padded_stamp_size / 2) self.padded_stamp_bbox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy( self.padded_stamp_radius ) - # self.modelScale = 1 - - def runQuantum(self, butlerQC, input_refs, output_refs): - inputs = butlerQC.get(input_refs) - inputs["data_id"] = butlerQC.quantum.dataId - ref_obj_loader = ReferenceObjectLoader( - dataIds=[ref.datasetRef.dataId for ref in input_refs.ref_cat], - refCats=inputs.pop("ref_cat"), - name=self.config.connections.ref_cat, - config=self.config.loadReferenceObjectsConfig, + self.model_scale = 1 + + def runQuantum(self, butlerQC, inputRefs, outputRefs): + inputs = butlerQC.get(inputRefs) + inputs["dataId"] = butlerQC.quantum.dataId + refObjLoader = ReferenceObjectLoader( + dataIds=[ref.datasetRef.dataId for ref in inputRefs.refCat], + refCats=inputs.pop("refCat"), + name=self.config.connections.refCat, + config=self.config.load_reference_objects_config, ) - extended_psf = inputs.pop("extended_psf", None) - output = self.run(**inputs, extended_psf=extended_psf, ref_obj_loader=ref_obj_loader) + extendedPsf = inputs.pop("extendedPsf", None) + output = self.run(**inputs, extendedPsf=extendedPsf, refObjLoader=refObjLoader) # Only ingest Stamp if it exists; prevents ingesting an empty FITS file if output: - butlerQC.put(output, output_refs) + butlerQC.put(output, outputRefs) @timeMethod def run( self, - input_image: ExposureF, - input_background: BackgroundList, - extended_psf: ImageF | None, - ref_obj_loader: ReferenceObjectLoader, - data_id: dict[str, Any] | DataCoordinate, + inputExposure: ExposureF, + inputBackground: BackgroundList, + extendedPsf: ImageF | None, + refObjLoader: ReferenceObjectLoader, + dataId: dict[str, Any] | DataCoordinate, ): """Identify bright stars within an exposure using a reference catalog, extract stamps around each, warp/shift stamps onto a common frame and @@ -279,35 +277,36 @@ def run( Parameters ---------- - input_image : `~lsst.afw.image.ExposureF` - The background-subtracted image to extract bright star stamps from. - input_background : `~lsst.afw.math.BackgroundList` + inputExposure : `~lsst.afw.image.ExposureF` + The background-subtracted image to extract bright star stamps. + inputBackground : `~lsst.afw.math.BackgroundList` The background model associated with the input exposure. - extended_psf : `~lsst.afw.image.ImageF` | `None` - The extended PSF model, built from stacking bright star cutouts. - ref_obj_loader : - `~lsst.meas.algorithms.ReferenceObjectLoader`, optional + extendedPsf: `~lsst.afw.image.ImageF` + The extended PSF model from previous iteration(s). + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional Loader to find objects within a reference catalog. - data_id : `dict` or `~lsst.daf.butler.DataCoordinate` - The data ID of the detector that bright stars are extracted from. + dataId : `dict` or `~lsst.daf.butler.DataCoordinate` + The dataId of the exposure that bright stars are extracted from. Both 'visit' and 'detector' will be persisted in the output data. Returns ------- - bright_star_stamps_results : `~lsst.pipe.base.Struct` + brightStarResults : `~lsst.pipe.base.Struct` Results as a struct with attributes: - ``bright_star_stamps`` + ``brightStarStamps`` (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`) """ - wcs = input_image.getWcs() - bbox = input_image.getBBox() + wcs = inputExposure.getWcs() + bbox = inputExposure.getBBox() + warping_control = WarpingControl( + self.config.warping_kernel_name, self.config.mask_warping_kernel_name + ) - # Get reference catalog stars - ref_cat = self._get_ref_cat(ref_obj_loader, wcs, bbox) - zipRaDec = zip(refCatBright["coord_ra"] * radians, refCatBright["coord_dec"] * radians) - spherePoints = [SpherePoint(ra, dec) for ra, dec in zipRaDec] - pixCoords = wcs.skyToPixel(spherePoints) + ref_cat_bright = self._get_ref_cat_bright(refObjLoader, wcs, bbox) + zip_ra_dec = zip(ref_cat_bright["coord_ra"] * radians, ref_cat_bright["coord_dec"] * radians) + sphere_points = [SpherePoint(ra, dec) for ra, dec in zip_ra_dec] + pix_coords = wcs.skyToPixel(sphere_points) # Restore original subtracted background inputMI = inputExposure.getMaskedImage() @@ -315,112 +314,106 @@ def run( # Set up NEIGHBOR mask plane; associate footprints with stars inputExposure.mask.addMaskPlane(NEIGHBOR_MASK_PLANE) - allFootprints, associations = self._associateFootprints(inputExposure, pixCoords, plane="DETECTED") + all_footprints, associations = self._associate_footprints(inputExposure, pix_coords, plane="DETECTED") # TODO: If we eventually have better PhotoCalibs (eg FGCM), apply here inputMI = inputExposure.getPhotoCalib().calibrateImage(inputMI, False) # Set up transform detector = inputExposure.detector - pixelScale = wcs.getPixelScale().asArcseconds() * arcseconds - pixToFocalPlaneTan = detector.getTransform(PIXELS, FIELD_ANGLE).then( - makeTransform(AffineTransform.makeScaling(1 / pixelScale.asRadians())) + pixel_scale = wcs.getPixelScale().asArcseconds() * arcseconds + pix_to_focal_plane_tan = detector.getTransform(PIXELS, FIELD_ANGLE).then( + makeTransform(AffineTransform.makeScaling(1 / pixel_scale.asRadians())) ) # Loop over each bright star - warpingControl = WarpingControl(self.config.warpingKernelName, self.config.maskWarpingKernelName) - stamps, goodFracs, stamps_fitPsfResults = [], [], [] - for starIndex, (obj, pixCoord) in enumerate(zip(refCatBright, pixCoords)): # type: ignore + stamps, good_fracs, stamps_fit_psf_results = [], [], [] + for star_index, (obj, pix_coord) in enumerate(zip(ref_cat_bright, pix_coords)): # type: ignore # Excluding faint stars that are not within the frame. - if obj["mag"] > self.config.offFrameMagLim and not self.star_in_frame(pixCoord, bbox): + if obj["mag"] > self.config.off_frame_mag_lim and not self.star_in_frame(pix_coord, bbox): continue - footprintIndex = associations.get(starIndex, None) - stampMI = MaskedImageF(self.paddedStampBBox) + footprint_index = associations.get(star_index, None) + stampMI = MaskedImageF(self.padded_stamp_bbox) # Set NEIGHBOR footprints in the mask plane - if footprintIndex: - neighborFootprints = [fp for i, fp in enumerate(allFootprints) if i != footprintIndex] - self._setFootprints(inputMI, neighborFootprints, NEIGHBOR_MASK_PLANE) + if footprint_index: + neighbor_footprints = [fp for i, fp in enumerate(all_footprints) if i != footprint_index] + self._set_footprints(inputMI, neighbor_footprints, NEIGHBOR_MASK_PLANE) else: - self._setFootprints(inputMI, allFootprints, NEIGHBOR_MASK_PLANE) + self._set_footprints(inputMI, all_footprints, NEIGHBOR_MASK_PLANE) # Define linear shifting to recenter stamps - coordFocalPlaneTan = pixToFocalPlaneTan.applyForward(pixCoord) # center of warped star - shift = makeTransform(AffineTransform(Point2D(0, 0) - coordFocalPlaneTan)) - angle = np.arctan2(coordFocalPlaneTan.getY(), coordFocalPlaneTan.getX()) * radians + coord_focal_plane_tan = pix_to_focal_plane_tan.applyForward(pix_coord) # center of warped star + shift = makeTransform(AffineTransform(Point2D(0, 0) - coord_focal_plane_tan)) + angle = np.arctan2(coord_focal_plane_tan.getY(), coord_focal_plane_tan.getX()) * radians rotation = makeTransform(AffineTransform.makeRotation(-angle)) - pixToPolar = pixToFocalPlaneTan.then(shift).then(rotation) + pix_to_polar = pix_to_focal_plane_tan.then(shift).then(rotation) # Apply the warp to the star stamp (in-place) - warpImage(stampMI, inputMI, pixToPolar, warpingControl) + warpImage(stampMI, inputMI, pix_to_polar, warping_control) # Trim to the base stamp size, check mask coverage, update metadata - stampMI = stampMI[self.stampBBox] - badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes) - goodFrac = np.sum(stampMI.mask.array & badMaskBitMask == 0) / stampMI.mask.array.size - goodFracs.append(goodFrac) - if goodFrac < self.config.minAreaFraction: + stampMI = stampMI[self.stamp_bbox] + bad_mask_bit_mask = stampMI.mask.getPlaneBitMask(self.config.bad_mask_planes) + good_frac = np.sum(stampMI.mask.array & bad_mask_bit_mask == 0) / stampMI.mask.array.size + good_fracs.append(good_frac) + if good_frac < self.config.min_area_fraction: continue # Fit a scaled PSF and a pedestal to each bright star cutout - psf = WarpedPsf(inputExposure.getPsf(), pixToPolar, warpingControl) - constantPsf = KernelPsf(FixedKernel(psf.computeKernelImage(Point2D(0, 0)))) - if self.config.useExtendedPsf: - psfImage = deepcopy(extendedPsf) # Assumed to be warped, center at [0,0] + psf = WarpedPsf(inputExposure.getPsf(), pix_to_polar, warping_control) + constant_psf = KernelPsf(FixedKernel(psf.computeKernelImage(Point2D(0, 0)))) + if self.config.use_extended_psf: + psf_image = deepcopy(extendedPsf) # Assumed to be warped, center at [0,0] else: - psfImage = constantPsf.computeKernelImage(constantPsf.getAveragePosition()) + psf_image = constant_psf.computeKernelImage(constant_psf.getAveragePosition()) # TODO: maybe we want to generate a smaller psf in case the following happens? - # The following could happen for when the user chooses small stampSize ~(50, 50) + # The following could happen for when the user chooses small stamp_size ~(50, 50) if ( - psfImage.array.shape[0] > stampMI.image.array.shape[0] - or psfImage.array.shape[1] > stampMI.image.array.shape[1] + psf_image.array.shape[0] > stampMI.image.array.shape[0] + or psf_image.array.shape[1] > stampMI.image.array.shape[1] ): continue # Computing an scale factor that brings the model to the similar level of the star. - self.computeModelScale(stampMI, psfImage) - psfImage.array *= self.modelScale # ####### model scale correction ######## + self.estimate_model_scale_value(stampMI, psf_image) + psf_image.array *= self.model_scale # ####### model scale correction ######## - fitPsfResults = {} + fit_psf_results = {} - if self.config.doFitPsf: - fitPsfResults = self._fitPsf(stampMI, psfImage) - stamps_fitPsfResults.append(fitPsfResults) + if self.config.do_fit_psf: + fit_psf_results = self._fit_psf(stampMI, psf_image) + stamps_fit_psf_results.append(fit_psf_results) # Save the stamp if the PSF fit was successful or no fit requested - if fitPsfResults or not self.config.doFitPsf: - distance_mm, theta_angle = self.star_location_on_focal(pixCoord, detector) + if fit_psf_results or not self.config.do_fit_psf: + distance_mm, theta_angle = self.star_location_on_focal(pix_coord, detector) stamp = BrightStarStamp( stamp_im=stampMI, - psf=constantPsf, - wcs=makeModifiedWcs(pixToPolar, wcs, False), + psf=constant_psf, + wcs=makeModifiedWcs(pix_to_polar, wcs, False), visit=cast(int, dataId["visit"]), detector=cast(int, dataId["detector"]), ref_id=obj["id"], ref_mag=obj["mag"], - position=pixCoord, + position=pix_coord, focal_plane_radius=distance_mm, focal_plane_angle=theta_angle, # TODO: add the lsst.geom.Angle here - scale=fitPsfResults.get("scale", None), - scale_err=fitPsfResults.get("scaleErr", None), - pedestal=fitPsfResults.get("pedestal", None), - pedestal_err=fitPsfResults.get("pedestalErr", None), - pedestal_scale_cov=fitPsfResults.get("pedestalScaleCov", None), - gradient_x=fitPsfResults.get("xGradient", None), - gradient_y=fitPsfResults.get("yGradient", None), - global_reduced_chi_squared=fitPsfResults.get("globalReducedChiSquared", None), - global_degrees_of_freedom=fitPsfResults.get("globalDegreesOfFreedom", None), - psf_reduced_chi_squared=fitPsfResults.get("psfReducedChiSquared", None), - psf_degrees_of_freedom=fitPsfResults.get("psfDegreesOfFreedom", None), - psf_masked_flux_fraction=fitPsfResults.get("psfMaskedFluxFrac", None), - ) - print( - obj["mag"], - fitPsfResults.get("globalReducedChiSquared", None), - fitPsfResults.get("globalDegreesOfFreedom", None), - fitPsfResults.get("psfReducedChiSquared", None), - fitPsfResults.get("psfDegreesOfFreedom", None), - fitPsfResults.get("psfMaskedFluxFrac", None), + scale=fit_psf_results.get("scale", None), + scale_err=fit_psf_results.get("scale_err", None), + pedestal=fit_psf_results.get("pedestal", None), + pedestal_err=fit_psf_results.get("pedestal_err", None), + pedestal_scale_cov=fit_psf_results.get("pedestal_scale_cov", None), + gradient_x=fit_psf_results.get("x_gradient", None), + gradient_y=fit_psf_results.get("y_gradient", None), + curvature_x=fit_psf_results.get("curvature_x", None), + curvature_y=fit_psf_results.get("curvature_y", None), + curvature_xy=fit_psf_results.get("curvature_xy", None), + global_reduced_chi_squared=fit_psf_results.get("global_reduced_chi_squared", None), + global_degrees_of_freedom=fit_psf_results.get("global_degrees_of_freedom", None), + psf_reduced_chi_squared=fit_psf_results.get("psf_reduced_chi_squared", None), + psf_degrees_of_freedom=fit_psf_results.get("psf_degrees_of_freedom", None), + psf_masked_flux_fraction=fit_psf_results.get("psf_masked_flux_frac", None), ) stamps.append(stamp) @@ -429,39 +422,70 @@ def run( "Excluded %i star%s: insufficient area (%i), PSF fit failure (%i).", len(stamps), "" if len(stamps) == 1 else "s", - len(refCatBright) - len(stamps), - "" if len(refCatBright) - len(stamps) == 1 else "s", - np.sum(np.array(goodFracs) < self.config.minAreaFraction), + len(ref_cat_bright) - len(stamps), + "" if len(ref_cat_bright) - len(stamps) == 1 else "s", + np.sum(np.array(good_fracs) < self.config.min_area_fraction), ( - np.sum(np.isnan([x.get("pedestal", np.nan) for x in stamps_fitPsfResults])) - if self.config.doFitPsf + np.sum(np.isnan([x.get("pedestal", np.nan) for x in stamps_fit_psf_results])) + if self.config.do_fit_psf else 0 ), ) brightStarStamps = BrightStarStamps(stamps) return Struct(brightStarStamps=brightStarStamps) - def star_location_on_focal(self, pixCoord, detector): - star_focal_plane_coords = detector.transform(pixCoord, PIXELS, FOCAL_PLANE) + def star_location_on_focal(self, pix_coord, detector): + """ + Calculates the radial coordinates of a star on the focal plane. + + Transforms the given pixel coordinates to the focal plane and computes + the radial distance and angle relative to the optical axis. + + Args: + pix_coord: `~lsst.geom.Point2D` or tuple + The (x, y) coordinates of the star on the + detector in pixels. + detector (Detector): `~lsst.afw.cameraGeom.Detector` + The detector object capable of transforming coordinates + from PIXELS to FOCAL_PLANE. + + Returns: + tuple: A tuple containing: + - distance_mm (float): The radial distance from the center in millimeters. + - theta_angle (Angle): The azimuthal angle of the star on the focal plane. + """ + star_focal_plane_coords = detector.transform(pix_coord, PIXELS, FOCAL_PLANE) star_x_fp = star_focal_plane_coords.getX() star_y_fp = star_focal_plane_coords.getY() - distance_mm = np.sqrt(star_x_fp**2 + star_y_fp**2) + distance_mm = np.sqrt(star_x_fp ** 2 + star_y_fp ** 2) theta_rad = math.atan2(star_y_fp, star_x_fp) theta_angle = Angle(theta_rad, radians) return distance_mm, theta_angle - def star_in_frame(self, pixCoord, inputExposureBBox): + def star_in_frame(self, pix_coord, inputExposureBBox): + """ + Checks if a star's pixel coordinates lie within the exposure boundaries. + + Args: + pix_coord: `~lsst.geom.Point2D` or tuple + The (x, y) pixel coordinates of the star. + inputExposureBBox : `~lsst.geom.Box2I` + Bounding box of the exposure. + + Returns: + bool: True if the coordinates are within the frame limits, False otherwise. + """ if ( - pixCoord[0] < 0 - or pixCoord[1] < 0 - or pixCoord[0] > inputExposureBBox.getDimensions()[0] - or pixCoord[1] > inputExposureBBox.getDimensions()[1] + pix_coord[0] < 0 + or pix_coord[1] < 0 + or pix_coord[0] > inputExposureBBox.getDimensions()[0] + or pix_coord[1] > inputExposureBBox.getDimensions()[1] ): return False return True - def _get_ref_cat(self, ref_obj_loader: ReferenceObjectLoader, wcs: SkyWcs, bbox: Box2I) -> Table: - """Get a subset of the reference catalog. + def _get_ref_cat_bright(self, refObjLoader: ReferenceObjectLoader, wcs: SkyWcs, bbox: Box2I) -> Table: + """Get a bright star subset of the reference catalog. Trim the reference catalog to only those objects within the exposure bounding box dilated by half the bright star stamp size. @@ -469,69 +493,68 @@ def _get_ref_cat(self, ref_obj_loader: ReferenceObjectLoader, wcs: SkyWcs, bbox: Parameters ---------- - ref_obj_loader : `~lsst.meas.algorithms.ReferenceObjectLoader` + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader` Loader to find objects within a reference catalog. wcs : `~lsst.afw.geom.SkyWcs` World coordinate system. bbox : `~lsst.geom.Box2I` - Bounding box of the image. + Bounding box of the exposure. Returns ------- - ref_cat : `~astropy.table.Table` - Subset of the reference catalog. + ref_cat_bright : `~astropy.table.Table` + Bright star subset of the reference catalog. """ - # Get all stars within a dilated bbox dilated_bbox = bbox.dilatedBy(self.padded_stamp_radius) - within_dilated_bbox = ref_obj_loader.loadPixelBox(dilated_bbox, wcs, filterName="phot_g_mean") - ref_cat_full = within_dilated_bbox.refCat - flux_field: str = within_dilated_bbox.fluxField - - # Trim to stars within the desired magnitude range - flux_range_nearby = sorted(((self.config.excludeMagRange * u.ABmag).to(u.nJy)).to_value()) - flux_range_bright = sorted(((self.config.magRange * u.ABmag).to(u.nJy)).to_value()) - stars_magnitude_limited = ( - ref_cat_full[flux_field] > np.min((flux_range_nearby[0], flux_range_bright[0])) - ) & (ref_cat_full[flux_field] < np.max((flux_range_nearby[1], flux_range_bright[1]))) + within_exposure = refObjLoader.loadPixelBox(dilated_bbox, wcs, filterName="phot_g_mean") + ref_cat_full = within_exposure.refCat + flux_field: str = within_exposure.fluxField + + prox_flux_range = sorted(((self.config.exclude_mag_range * u.ABmag).to(u.nJy)).to_value()) + bright_flux_range = sorted(((self.config.mag_range * u.ABmag).to(u.nJy)).to_value()) + + subset_stars = (ref_cat_full[flux_field] > np.min((prox_flux_range[0], bright_flux_range[0]))) & ( + ref_cat_full[flux_field] < np.max((prox_flux_range[1], bright_flux_range[1])) + ) ref_cat_subset = Table( - ref_cat_full.extract("id", "coord_ra", "coord_dec", flux_field, where=stars_magnitude_limited) + ref_cat_full.extract("id", "coord_ra", "coord_dec", flux_field, where=subset_stars) ) - stars_nearby = (ref_cat_subset[flux_field] >= flux_range_nearby[0]) & ( - ref_cat_subset[flux_field] <= flux_range_nearby[1] + + prox_stars = (ref_cat_subset[flux_field] >= prox_flux_range[0]) & ( + ref_cat_subset[flux_field] <= prox_flux_range[1] ) - stars_bright = (ref_cat_subset[flux_field] >= flux_range_bright[0]) & ( - ref_cat_subset[flux_field] <= flux_range_bright[1] + bright_stars = (ref_cat_subset[flux_field] >= bright_flux_range[0]) & ( + ref_cat_subset[flux_field] <= bright_flux_range[1] ) - # Exclude stars with bright enough neighbors in a specified radius coords = SkyCoord(ref_cat_subset["coord_ra"], ref_cat_subset["coord_dec"], unit="rad") - exclude_radius_arcsec = self.config.excludeRadiusArcsec * u.arcsec + exclude_arcsec_radius = self.config.exclude_arcsec_radius * u.arcsec # type: ignore ref_cat_bright_isolated = [] - for coord in cast(Iterable[SkyCoord], coords[stars_bright]): - neighbors = coords[stars_nearby] - separations = coord.separation(neighbors).to(u.arcsec) - too_close = (separations > 0) & (separations <= exclude_radius_arcsec) # ensure not self matched + for coord in cast(Iterable[SkyCoord], coords[bright_stars]): + neighbors = coords[prox_stars] + seps = coord.separation(neighbors).to(u.arcsec) + too_close = (seps > 0) & (seps <= exclude_arcsec_radius) # not self matched ref_cat_bright_isolated.append(not too_close.any()) - ref_cat_bright = cast(Table, ref_cat_subset[stars_bright][ref_cat_bright_isolated]) - breakpoint() - fluxNanojansky = refCatBright[fluxField][:] * u.nJy # type: ignore - refCatBright["mag"] = fluxNanojansky.to(u.ABmag).to_value() # AB magnitudes + ref_cat_bright = cast(Table, ref_cat_subset[bright_stars][ref_cat_bright_isolated]) + + flux_nanojansky = ref_cat_bright[flux_field][:] * u.nJy # type: ignore + ref_cat_bright["mag"] = flux_nanojansky.to(u.ABmag).to_value() # AB magnitudes self.log.info( "Identified %i of %i star%s which satisfy: frame overlap; in the range %s mag; no neighboring " "stars within %s arcsec.", - len(refCatBright), - len(refCatFull), - "" if len(refCatFull) == 1 else "s", - self.config.magRange, - self.config.excludeArcsecRadius, + len(ref_cat_bright), + len(ref_cat_full), + "" if len(ref_cat_full) == 1 else "s", + self.config.mag_range, + self.config.exclude_arcsec_radius, ) - return refCatBright + return ref_cat_bright - def _associateFootprints( - self, inputExposure: ExposureF, pixCoords: list[Point2D], plane: str + def _associate_footprints( + self, inputExposure: ExposureF, pix_coords: list[Point2D], plane: str ) -> tuple[list[Footprint], dict[int, int]]: """Associate footprints from a given mask plane with specific objects. @@ -542,7 +565,7 @@ def _associateFootprints( ---------- inputExposure : `~lsst.afw.image.ExposureF` The input exposure with a mask plane. - pixCoords : `list` [`~lsst.geom.Point2D`] + pix_coords : `list` [`~lsst.geom.Point2D`] The pixel coordinates of the objects. plane : `str` The mask plane used to identify masked pixels. @@ -554,27 +577,27 @@ def _associateFootprints( associations : `dict`[int, int] Association indices between objects (key) and footprints (value). """ - detThreshold = Threshold(inputExposure.mask.getPlaneBitMask(plane), Threshold.BITMASK) - footprintSet = FootprintSet(inputExposure.mask, detThreshold) + det_threshold = Threshold(inputExposure.mask.getPlaneBitMask(plane), Threshold.BITMASK) + footprintSet = FootprintSet(inputExposure.mask, det_threshold) footprints = footprintSet.getFootprints() associations = {} - for starIndex, pixCoord in enumerate(pixCoords): - for footprintIndex, footprint in enumerate(footprints): - if footprint.contains(Point2I(pixCoord)): - associations[starIndex] = footprintIndex + for star_index, pix_coord in enumerate(pix_coords): + for footprint_index, footprint in enumerate(footprints): + if footprint.contains(Point2I(pix_coord)): + associations[star_index] = footprint_index break self.log.debug( "Associated %i of %i star%s to one each of the %i %s footprint%s.", len(associations), - len(pixCoords), - "" if len(pixCoords) == 1 else "s", + len(pix_coords), + "" if len(pix_coords) == 1 else "s", len(footprints), plane, "" if len(footprints) == 1 else "s", ) return footprints, associations - def _setFootprints(self, inputExposure: ExposureF, footprints: list, maskPlane: str): + def _set_footprints(self, inputExposure: ExposureF, footprints: list, mask_plane: str): """Set footprints in a given mask plane. Parameters @@ -583,188 +606,217 @@ def _setFootprints(self, inputExposure: ExposureF, footprints: list, maskPlane: The input exposure to modify. footprints : `list` [`~lsst.afw.detection.Footprint`] The footprints to set in the mask plane. - maskPlane : `str` + mask_plane : `str` The mask plane to set the footprints in. Notes ----- This method modifies the ``inputExposure`` object in-place. """ - detThreshold = Threshold(inputExposure.mask.getPlaneBitMask(maskPlane), Threshold.BITMASK) - detThresholdValue = int(detThreshold.getValue()) - footprintSet = FootprintSet(inputExposure.mask, detThreshold) + det_threshold = Threshold(inputExposure.mask.getPlaneBitMask(mask_plane), Threshold.BITMASK) + det_threshold_value = int(det_threshold.getValue()) + footprint_set = FootprintSet(inputExposure.mask, det_threshold) # Wipe any existing footprints in the mask plane - inputExposure.mask.clearMaskPlane(int(np.log2(detThresholdValue))) + inputExposure.mask.clearMaskPlane(int(np.log2(det_threshold_value))) # Set the footprints in the mask plane - footprintSet.setFootprints(footprints) - footprintSet.setMask(inputExposure.mask, maskPlane) + footprint_set.setFootprints(footprints) + footprint_set.setMask(inputExposure.mask, mask_plane) - def _fitPsf(self, stampMI: MaskedImageF, psfImage: ImageD | ImageF) -> dict[str, Any]: + def _fit_psf(self, stampMI: MaskedImageF, psf_image: ImageD | ImageF) -> dict[str, Any]: """Fit a scaled PSF and a pedestal to each bright star cutout. Parameters ---------- stampMI : `~lsst.afw.image.MaskedImageF` The masked image of the bright star cutout. - psfImage : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` + psf_image : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` The PSF model to fit. Returns ------- - fitPsfResults : `dict`[`str`, `float`] + fit_psf_results : `dict`[`str`, `float`] The result of the PSF fitting, with keys: ``scale`` : `float` The scale factor. - ``scaleErr`` : `float` + ``scale_err`` : `float` The error on the scale factor. ``pedestal`` : `float` The pedestal value. - ``pedestalErr`` : `float` + ``pedestal_err`` : `float` The error on the pedestal value. - ``pedestalScaleCov`` : `float` + ``pedestal_scale_cov`` : `float` The covariance between the pedestal and scale factor. - ``xGradient`` : `float` + ``x_gradient`` : `float` The gradient in the x-direction. - ``yGradient`` : `float` + ``y_gradient`` : `float` The gradient in the y-direction. - ``globalReducedChiSquared`` : `float` + ``global_reduced_chi_squared`` : `float` The global reduced chi-squared goodness-of-fit. - ``globalDegreesOfFreedom`` : `int` + ``global_degrees_of_freedom`` : `int` The global number of degrees of freedom. - ``psfReducedChiSquared`` : `float` + ``psf_reduced_chi_squared`` : `float` The PSF BBox reduced chi-squared goodness-of-fit. - ``psfDegreesOfFreedom`` : `int` + ``psf_degrees_of_freedom`` : `int` The PSF BBox number of degrees of freedom. - ``psfMaskedFluxFrac`` : `float` + ``psf_masked_flux_frac`` : `float` The fraction of the PSF image flux masked by bad pixels. """ - badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes) + bad_mask_bit_mask = stampMI.mask.getPlaneBitMask(self.config.bad_mask_planes) # Calculate the fraction of the PSF image flux masked by bad pixels - psfMaskedPixels = ImageF(psfImage.getBBox()) - psfMaskedPixels.array[:, :] = (stampMI.mask[psfImage.getBBox()].array & badMaskBitMask).astype(bool) - psfMaskedFluxFrac = ( - np.dot(psfImage.array.flat, psfMaskedPixels.array.flat).astype(np.float64) / psfImage.array.sum() + psf_masked_pixels = ImageF(psf_image.getBBox()) + psf_masked_pixels.array[:, :] = (stampMI.mask[psf_image.getBBox()].array & bad_mask_bit_mask).astype( + bool ) - if psfMaskedFluxFrac > self.config.psfMaskedFluxFracThreshold: + psf_masked_flux_frac = ( + np.dot(psf_image.array.flat, psf_masked_pixels.array.flat).astype(np.float64) + / psf_image.array.sum() + ) + if psf_masked_flux_frac > self.config.psf_masked_flux_frac_threshold: return {} # Handle cases where the PSF image is mostly masked # Generating good spans for gradient-pedestal fitting (including the star DETECTED mask). - gradientGoodSpans = self.generate_gradient_spans(stampMI, badMaskBitMask) - varianceData = gradientGoodSpans.flatten(stampMI.variance.array, stampMI.getXY0()) - if self.config.useMedianVariance: - varianceData = np.median(varianceData) - sigmaData = np.sqrt(varianceData) + gradient_good_spans = self.generate_gradient_spans(stampMI, bad_mask_bit_mask) + variance_data = gradient_good_spans.flatten(stampMI.variance.array, stampMI.getXY0()) + if self.config.use_median_variance: + variance_data = np.median(variance_data) + sigma_data = np.sqrt(variance_data) - for i in range(self.config.fitIterations): + for i in range(self.config.fit_iterations): # Gradient-pedestal fitting: - if i > 0: + if i: # if i > 0, there should be scale factor from the previous fit iteration. Therefore, we can # remove the star using the scale factor. - stamp = self.remove_star(stampMI, scale, paddedPsfImage) # noqa: F821 + stamp = self.remove_star(stampMI, scale, padded_psf_image) # noqa: F821 else: stamp = deepcopy(stampMI.image.array) - imageDataGr = gradientGoodSpans.flatten(stamp, stampMI.getXY0()) / sigmaData # B - nData = len(imageDataGr) - coefficientMatrix = np.ones((nData, 3), dtype=float) # A - coefficientMatrix[:, 0] /= sigmaData - coefficientMatrix[:, 1:] = gradientGoodSpans.indices().T - coefficientMatrix[:, 1] /= sigmaData - coefficientMatrix[:, 2] /= sigmaData + image_data_gr = gradient_good_spans.flatten(stamp, stampMI.getXY0()) / sigma_data # B + n_data = len(image_data_gr) + + xy = gradient_good_spans.indices() + y = xy[0, :] + x = xy[1, :] + coefficient_matrix = np.ones((n_data, 6), dtype=float) # A + coefficient_matrix[:, 0] /= sigma_data + coefficient_matrix[:, 1] = y / sigma_data + coefficient_matrix[:, 2] = x / sigma_data + coefficient_matrix[:, 3] = y ** 2 / sigma_data + coefficient_matrix[:, 4] = x ** 2 / sigma_data + coefficient_matrix[:, 5] = x * y / sigma_data + # scikit might have a fitting tool try: - grSolutions, grSumSquaredResiduals, *_ = np.linalg.lstsq( - coefficientMatrix, imageDataGr, rcond=None + gr_solutions, gr_sum_squared_residuals, *_ = np.linalg.lstsq( + coefficient_matrix, image_data_gr, rcond=None ) - covarianceMatrix = np.linalg.inv( - np.dot(coefficientMatrix.transpose(), coefficientMatrix) + covariance_matrix = np.linalg.inv( + np.dot(coefficient_matrix.transpose(), coefficient_matrix) ) # C except np.linalg.LinAlgError: return {} # Handle singular matrix errors - if grSumSquaredResiduals.size == 0: + if gr_sum_squared_residuals.size == 0: return {} # Handle cases where sum of the squared residuals are empty - pedestal = grSolutions[0] - pedestalErr = np.sqrt(covarianceMatrix[0, 0]) - scalePedestalCov = None - xGradient = grSolutions[2] - yGradient = grSolutions[1] + pedestal = gr_solutions[0] + pedestal_err = np.sqrt(covariance_matrix[0, 0]) + pedestal_scale_cov = None + x_gradient = gr_solutions[2] + y_gradient = gr_solutions[1] + x_curvature = gr_solutions[4] + y_curvature = gr_solutions[3] + curvature_xy = gr_solutions[5] # Scale fitting: updatedStampMI = deepcopy(stampMI) - self._removePedestalAndGradient(updatedStampMI, pedestal, xGradient, yGradient) + self._removePedestalAndGradient( + updatedStampMI, pedestal, x_gradient, y_gradient, x_curvature, y_curvature, curvature_xy + ) # Create a padded version of the input constant PSF image - paddedPsfImage = ImageF(updatedStampMI.getBBox()) - paddedPsfImage[psfImage.getBBox()] = psfImage.convertF() + padded_psf_image = ImageF(updatedStampMI.getBBox()) + padded_psf_image[psf_image.getBBox()] = psf_image.convertF() # Generating a mask plane while considering bad pixels in the psf model. - mask = self.add_psf_mask(paddedPsfImage, updatedStampMI) + mask = self.add_psf_mask(padded_psf_image, updatedStampMI) # Create consistently masked data - scaleGoodSpans = self.generate_good_spans(mask, updatedStampMI.getBBox(), badMaskBitMask) + scale_good_spans = self.generate_good_spans(mask, updatedStampMI.getBBox(), bad_mask_bit_mask) + + variance_data_scale = scale_good_spans.flatten(stampMI.variance.array, stampMI.getXY0()) + if self.config.use_median_variance: + variance_data_scale = np.median(variance_data_scale) + sigma_data_scale = np.sqrt(variance_data_scale) - imageData = scaleGoodSpans.flatten(updatedStampMI.image.array, updatedStampMI.getXY0()) - psfData = scaleGoodSpans.flatten(paddedPsfImage.array, paddedPsfImage.getXY0()) - scaleCoefficientMatrix = psfData.reshape(psfData.shape[0], 1) + image_data = scale_good_spans.flatten(updatedStampMI.image.array, updatedStampMI.getXY0()) + psf_data = scale_good_spans.flatten(padded_psf_image.array, padded_psf_image.getXY0()) + image_data /= sigma_data_scale + psf_data /= sigma_data_scale + scale_coefficient_matrix = psf_data.reshape(psf_data.shape[0], 1) try: - scaleSolution, scaleSumSquaredResiduals, *_ = np.linalg.lstsq( - scaleCoefficientMatrix, imageData, rcond=None + scale_solution, scale_sum_squared_residuals, *_ = np.linalg.lstsq( + scale_coefficient_matrix, image_data, rcond=None ) except np.linalg.LinAlgError: return {} # Handle singular matrix errors - if scaleSumSquaredResiduals.size == 0: + if scale_sum_squared_residuals.size == 0: return {} # Handle cases where sum of the squared residuals are empty - scale = scaleSolution[0] + scale = scale_solution[0] if scale <= 0: return {} # Handle cases where the PSF scale fit has failed - # TODO: calculate scale error and store it. - scaleErr = None - scale *= self.modelScale # ####### model scale correction ######## - nData = len(imageData) + scale *= self.model_scale # ####### model scale correction ######## + n_data = len(image_data) - # Calculate global (whole image) reduced chi-squared (scaling fit is assumed as the main fitting + scale_covariance_matrix = np.linalg.inv( + np.dot(scale_coefficient_matrix.transpose(), scale_coefficient_matrix) + ) # C + scale_err = scale_covariance_matrix[0].astype(float)[0] + + # Calculate global (whole image) reduced chi-squared (scale fit is assumed as the main fitting # process here.) - globalChiSquared = np.sum(scaleSumSquaredResiduals) - globalDegreesOfFreedom = nData - 1 - globalReducedChiSquared = np.float64(globalChiSquared / globalDegreesOfFreedom) + global_chi_squared = np.sum(scale_sum_squared_residuals) + global_degrees_of_freedom = n_data - 1 + global_reduced_chi_squared = np.float64(global_chi_squared / global_degrees_of_freedom) # Calculate PSF BBox reduced chi-squared - psfBBoxscaleGoodSpans = scaleGoodSpans.clippedTo(psfImage.getBBox()) - psfBBoxscaleGoodSpansX, psfBBoxscaleGoodSpansY = psfBBoxscaleGoodSpans.indices() - psfBBoxData = psfBBoxscaleGoodSpans.flatten(stampMI.image.array, stampMI.getXY0()) - paddedPsfImage.array /= self.modelScale # ####### model scale correction ######## - psfBBoxModel = ( - psfBBoxscaleGoodSpans.flatten(paddedPsfImage.array, stampMI.getXY0()) * scale + psf_bbox_scale_good_spans = scale_good_spans.clippedTo(psf_image.getBBox()) + psf_bbox_scale_good_spans_x, psf_bbox_scale_good_spans_y = psf_bbox_scale_good_spans.indices() + psf_bbox_data = psf_bbox_scale_good_spans.flatten(stampMI.image.array, stampMI.getXY0()) + padded_psf_image.array /= self.model_scale # ####### model scale correction ######## + psf_bbox_model = ( + psf_bbox_scale_good_spans.flatten(padded_psf_image.array, stampMI.getXY0()) * scale + pedestal - + psfBBoxscaleGoodSpansX * xGradient - + psfBBoxscaleGoodSpansY * yGradient + + psf_bbox_scale_good_spans_x * x_gradient + + psf_bbox_scale_good_spans_y * y_gradient ) - psfBBoxResiduals = (psfBBoxData - psfBBoxModel) ** 2 # / psfBBoxVariance - psfBBoxChiSquared = np.sum(psfBBoxResiduals) - psfBBoxDegreesOfFreedom = len(psfBBoxData) - 1 - psfBBoxReducedChiSquared = psfBBoxChiSquared / psfBBoxDegreesOfFreedom + psf_bbox_residuals = (psf_bbox_data - psf_bbox_model) ** 2 # / psfBBoxVariance + psf_bbox_chi_squared = np.sum(psf_bbox_residuals) + psf_bbox_degrees_of_freedom = len(psf_bbox_data) - 1 + psf_bbox_reduced_chi_squared = psf_bbox_chi_squared / psf_bbox_degrees_of_freedom + return dict( scale=scale, - scaleErr=scaleErr, + scale_err=scale_err, pedestal=pedestal, - pedestalErr=pedestalErr, - xGradient=xGradient, - yGradient=yGradient, - pedestalScaleCov=scalePedestalCov, - globalReducedChiSquared=globalReducedChiSquared, - globalDegreesOfFreedom=globalDegreesOfFreedom, - psfReducedChiSquared=psfBBoxReducedChiSquared, - psfDegreesOfFreedom=psfBBoxDegreesOfFreedom, - psfMaskedFluxFrac=psfMaskedFluxFrac, + pedestal_err=pedestal_err, + x_gradient=x_gradient, + y_gradient=y_gradient, + curvature_x=x_curvature, + curvature_y=y_curvature, + curvature_xy=curvature_xy, + pedestal_scale_cov=pedestal_scale_cov, + global_reduced_chi_squared=global_reduced_chi_squared, + global_degrees_of_freedom=global_degrees_of_freedom, + psf_reduced_chi_squared=psf_bbox_reduced_chi_squared, + psf_degrees_of_freedom=psf_bbox_degrees_of_freedom, + psf_masked_flux_frac=psf_masked_flux_frac, ) - def add_psf_mask(self, psfImage, stampMI, maskZeros=True): + def add_psf_mask(self, psf_image, stampMI, maskZeros=True): """ Creates a new mask by adding PSF bad pixels to an existing stamp mask. @@ -773,7 +825,7 @@ def add_psf_mask(self, psfImage, stampMI, maskZeros=True): of the input stamp's mask. Args: - psfImage : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` + psf_image : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` The PSF image object. stampMI: `~lsst.afw.image.MaskedImageF` The masked image of the bright star cutout. @@ -783,26 +835,34 @@ def add_psf_mask(self, psfImage, stampMI, maskZeros=True): Returns: Any: A new mask object (deep copy) with the PSF mask planes added. """ - cond = np.isnan(psfImage.array) + cond = np.isnan(psf_image.array) if maskZeros: - cond |= psfImage.array <= 0 + cond |= psf_image.array <= 0 else: - cond |= psfImage.array < 0 + cond |= psf_image.array < 0 mask = deepcopy(stampMI.mask) mask.array[cond] = np.bitwise_or(mask.array[cond], 1) return mask - def _removePedestalAndGradient(self, stampMI, pedestal, xGradient, yGradient): + def _removePedestalAndGradient( + self, stampMI, pedestal, x_gradient, y_gradient, x_curvature, y_curvature, curvature_xy + ): """Apply fitted pedestal and gradients to a single bright star stamp.""" - stampBBox = stampMI.getBBox() - xGrid, yGrid = np.meshgrid(stampBBox.getX().arange(), stampBBox.getY().arange()) - xPlane = ImageF((xGrid * xGradient).astype(np.float32), xy0=stampMI.getXY0()) - yPlane = ImageF((yGrid * yGradient).astype(np.float32), xy0=stampMI.getXY0()) + stamp_bbox = stampMI.getBBox() + x_grid, y_grid = np.meshgrid(stamp_bbox.getX().arange(), stamp_bbox.getY().arange()) + x_plane = ImageF((x_grid * x_gradient).astype(np.float32), xy0=stampMI.getXY0()) + y_plane = ImageF((y_grid * y_gradient).astype(np.float32), xy0=stampMI.getXY0()) + x_curve = ImageF((x_grid ** 2 * x_curvature).astype(np.float32), xy0=stampMI.getXY0()) + y_curve = ImageF((y_grid ** 2 * y_curvature).astype(np.float32), xy0=stampMI.getXY0()) + curvature_xy = ImageF((x_grid * y_grid * curvature_xy).astype(np.float32), xy0=stampMI.getXY0()) stampMI -= pedestal - stampMI -= xPlane - stampMI -= yPlane + stampMI -= x_plane + stampMI -= y_plane + stampMI -= x_curve + stampMI -= y_curve + stampMI -= curvature_xy - def remove_star(self, stampMI, scale, psfImage): + def remove_star(self, stampMI, scale, psf_image): """ Subtracts a scaled PSF model from a star image. @@ -812,74 +872,76 @@ def remove_star(self, stampMI, scale, psfImage): stampMI: `~lsst.afw.image.MaskedImageF` The masked image of the bright star cutout. scale (float): The scaling factor to apply to the PSF. - psfImage: `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` + psf_image: `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` The PSF image object. Returns: np.ndarray: A new 2D numpy array containing the star-subtracted image. """ - star_removed_cutout = stampMI.image.array - psfImage.array * scale + star_removed_cutout = stampMI.image.array - psf_image.array * scale return star_removed_cutout - def computeModelScale(self, stampMI, psfImage): + def estimate_model_scale_value(self, stampMI, psf_image): """ Computes the scaling factor of the given model against a star. Args: stampMI : `~lsst.afw.image.MaskedImageF` The masked image of the bright star cutout. - psfImage : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` + psf_image : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` The given PSF model. """ cond = stampMI.mask.array == 0 - self.starMedianValue = np.median(stampMI.image.array[cond]).astype(np.float64) + self.star_median = np.median(stampMI.image.array[cond]).astype(np.float64) - psfPos = psfImage.array > 0 + psf_positives = psf_image.array > 0 - imageArray = stampMI.image.array - self.starMedianValue - imageArrayPos = imageArray > 0 - self.modelScale = np.nanmean(imageArray[imageArrayPos]) / np.nanmean(psfImage.array[psfPos]) + image_array = stampMI.image.array - self.star_median + image_array_positives = image_array > 0 + self.model_scale = np.nanmean(image_array[image_array_positives]) / np.nanmean( + psf_image.array[psf_positives] + ) - def generate_gradient_spans(self, stampMI, badMaskBitMask): + def generate_gradient_spans(self, stampMI, bad_mask_bit_mask): """ Generates spans of "good" pixels for gradient fitting. This method creates a combined bitmask by OR-ing the provided - `badMaskBitMask` with the "DETECTED" plane from the stamp's mask. + `bad_mask_bit_mask` with the "DETECTED" plane from the stamp's mask. It then calls `self.generate_good_spans` to find all pixel spans not covered by this combined mask. Args: stampMI: `~lsst.afw.image.MaskedImageF` The masked image of the bright star cutout. - badMaskBitMask (int): A bitmask representing planes to be + bad_mask_bit_mask (int): A bitmask representing planes to be considered "bad" for gradient fitting. Returns: - gradientGoodSpans: A SpanSet object containing the "good" spans. + gradient_good_spans: A SpanSet object containing the "good" spans. """ - detectedMaskBitMask = stampMI.mask.getPlaneBitMask("DETECTED") - gradientBitMask = np.bitwise_or(badMaskBitMask, detectedMaskBitMask) + bit_mask_detected = stampMI.mask.getPlaneBitMask("DETECTED") + gradient_bit_mask = np.bitwise_or(bad_mask_bit_mask, bit_mask_detected) - gradientGoodSpans = self.generate_good_spans(stampMI.mask, stampMI.getBBox(), gradientBitMask) - return gradientGoodSpans + gradient_good_spans = self.generate_good_spans(stampMI.mask, stampMI.getBBox(), gradient_bit_mask) + return gradient_good_spans - def generate_good_spans(self, mask, bBox, badBitMask): + def generate_good_spans(self, mask, bBox, bad_bit_mask): """ Generates a SpanSet of "good" pixels from a mask. This method identifies all spans within a given bounding box (`bBox`) - that are *not* flagged by the `badBitMask` in the provided `mask`. + that are *not* flagged by the `bad_bit_mask` in the provided `mask`. Args: mask (lsst.afw.image.MaskedImageF.mask): The mask object (e.g., `stampMI.mask`). bBox (lsst.geom.Box2I): The bounding box of the image (e.g., `stampMI.getBBox()`). - badBitMask (int): The combined bitmask of planes to exclude. + bad_bit_mask (int): The combined bitmask of planes to exclude. Returns: - goodSpans: A SpanSet object representing all "good" spans. + good_spans: A SpanSet object representing all "good" spans. """ - badSpans = SpanSet.fromMask(mask, badBitMask) - goodSpans = SpanSet(bBox).intersectNot(badSpans) - return goodSpans + bad_spans = SpanSet.fromMask(mask, bad_bit_mask) + good_spans = SpanSet(bBox).intersectNot(bad_spans) + return good_spans diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py new file mode 100644 index 000000000..bf7d0641c --- /dev/null +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py @@ -0,0 +1,292 @@ +# This file is part of pipe_tasks. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +"""Stack bright star postage stamp cutouts to produce an extended PSF model.""" + +__all__ = ["BrightStarStackConnections", "BrightStarStackConfig", "BrightStarStackTask"] + +import numpy as np +from lsst.afw.image import ImageF +from lsst.afw.math import StatisticsControl, statisticsStack, stringToStatisticsProperty +from lsst.geom import Point2I +from lsst.meas.algorithms import BrightStarStamps +from lsst.pex.config import Field, ListField +from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct +from lsst.pipe.base.connectionTypes import Input, Output +from lsst.utils.timer import timeMethod + +NEIGHBOR_MASK_PLANE = "NEIGHBOR" + + +class BrightStarStackConnections( + PipelineTaskConnections, + dimensions=("instrument", "detector"), +): + """Connections for BrightStarStackTask.""" + + brightStarStamps = Input( + name="brightStarStamps", + storageClass="BrightStarStamps", + doc="Set of preprocessed postage stamp cutouts, each centered on a single bright star.", + dimensions=("visit", "detector"), + multiple=True, + deferLoad=True, + ) + extendedPsf = Output( + name="extendedPsf2", # extendedPsfDetector ??? + storageClass="ImageF", # stamp_imF + doc="Extended PSF model, built from stacking bright star cutouts.", + dimensions=("band",), + ) + + +class BrightStarStackConfig( + PipelineTaskConfig, + pipelineConnections=BrightStarStackConnections, +): + """Configuration parameters for BrightStarStackTask.""" + + global_reduced_chi_squared_threshold = Field[float]( + doc="Threshold for global reduced chi-squared for bright star stamps.", + default=5.0, + ) + psf_reduced_chi_squared_threshold = Field[float]( + doc="Threshold for PSF reduced chi-squared for bright star stamps.", + default=50.0, + ) + bright_star_threshold = Field[float]( + doc="Stars brighter than this magnitude, are considered as bright stars.", + default=12.0, + ) + bright_global_reduced_chi_squared_threshold = Field[float]( + doc="Threshold for global reduced chi-squared for bright star stamps.", + default=250.0, + ) + psf_bright_reduced_chi_squared_threshold = Field[float]( + doc="Threshold for PSF reduced chi-squared for bright star stamps.", + default=400.0, + ) + + bad_mask_planes = ListField[str]( + doc="Mask planes that identify excluded (masked) pixels.", + default=[ + "BAD", + "CR", + "CROSSTALK", + "EDGE", + "NO_DATA", + "SAT", + "SUSPECT", + "UNMASKEDNAN", + NEIGHBOR_MASK_PLANE, + ], + ) + stack_type = Field[str]( + default="WEIGHTED_MEDIAN", + doc="Statistic name to use for stacking (from `~lsst.afw.math.Property`)", + ) + stack_num_sigma_clip = Field[float]( + doc="Number of sigma to use for clipping when stacking.", + default=3.0, + ) + stack_num_iter = Field[int]( + doc="Number of iterations to use for clipping when stacking.", + default=5, + ) + magnitude_bins = ListField[int]( + doc="Only used if stack_type == WEIGHTED_MEDIAN. Bins of magnitudes for weighting purposes.", + default=[20, 19, 18, 17, 16, 15, 13, 10], + ) + subset_stamp_number = ListField[int]( + doc="Only used if stack_type == WEIGHTED_MEDIAN. Number of stamps per subset to generate stacked " + "images for. The length of this parameter must be equal to the length of magnitude_bins plus one.", + default=[300, 200, 150, 100, 100, 100, 1], + ) + min_focal_plane_radius = Field[float]( + doc="Minimum distance to focal plane center in mm. Stars with a focal plane radius smaller than " + "this will be omitted.", + default=-1.0, + ) + max_focal_plane_radius = Field[float]( + doc="Maximum distance to focal plane center in mm. Stars with a focal plane radius greater than " + "this will be omitted.", + default=2000.0, + ) + + +class BrightStarStackTask(PipelineTask): + """Stack bright star postage stamps to produce an extended PSF model.""" + + ConfigClass = BrightStarStackConfig + _DefaultName = "brightStarStack" + config: BrightStarStackConfig + + def __init__(self, initInputs=None, *args, **kwargs): + super().__init__(*args, **kwargs) + + def runQuantum(self, butlerQC, inputRefs, outputRefs): + inputs = butlerQC.get(inputRefs) + output = self.run(**inputs) + butlerQC.put(output, outputRefs) + + def _applyStampFit(self, stamp): + """Apply fitted stamp components to a single bright star stamp.""" + stampMI = stamp.stamp_im + stamp_bbox = stampMI.getBBox() + + x_grid, y_grid = np.meshgrid(stamp_bbox.getX().arange(), stamp_bbox.getY().arange()) + + x_plane = ImageF((x_grid * stamp.gradient_x).astype(np.float32), xy0=stampMI.getXY0()) + y_plane = ImageF((y_grid * stamp.gradient_y).astype(np.float32), xy0=stampMI.getXY0()) + + x_curve = ImageF((x_grid**2 * stamp.curvature_x).astype(np.float32), xy0=stampMI.getXY0()) + y_curve = ImageF((y_grid**2 * stamp.curvature_y).astype(np.float32), xy0=stampMI.getXY0()) + xy_curve = ImageF((x_grid * y_grid * stamp.curvature_xy).astype(np.float32), xy0=stampMI.getXY0()) + + stampMI -= stamp.pedestal + stampMI -= x_plane + stampMI -= y_plane + stampMI -= x_curve + stampMI -= y_curve + stampMI -= xy_curve + stampMI /= stamp.scale + + @timeMethod + def run( + self, + brightStarStamps: BrightStarStamps, + ): + """Identify bright stars within an exposure using a reference catalog, + extract stamps around each, then preprocess them. + + Bright star preprocessing steps are: shifting, warping and potentially + rotating them to the same pixel grid; computing their annular flux, + and; normalizing them. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The image from which bright star stamps should be extracted. + inputBackground : `~lsst.afw.image.Background` + The background model for the input exposure. + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional + Loader to find objects within a reference catalog. + dataId : `dict` or `~lsst.daf.butler.DataCoordinate` + The dataId of the exposure (including detector) that bright stars + should be extracted from. + + Returns + ------- + brightStarResults : `~lsst.pipe.base.Struct` + Results as a struct with attributes: + + ``brightStarStamps`` + (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`) + """ + if self.config.stack_type == "WEIGHTED_MEDIAN": + stack_type_property = stringToStatisticsProperty("MEDIAN") + else: + stack_type_property = stringToStatisticsProperty(self.config.stack_type) + statistics_control = StatisticsControl( + numSigmaClip=self.config.stack_num_sigma_clip, + numIter=self.config.stack_num_iter, + ) + + mag_bins_dict = {} + subset_stampMIs = {} + self.metadata["psf_star_count"] = {} + self.metadata["psf_star_count"]["all"] = 0 + for i in range(len(self.config.subset_stamp_number)): + self.metadata["psf_star_count"][str(self.config.magnitude_bins[i + 1])] = 0 + for stampsDDH in brightStarStamps: + stamps = stampsDDH.get() + self.metadata["psf_star_count"]["all"] += len(stamps) + for stamp in stamps: + if stamp.ref_mag >= self.config.bright_star_threshold: + global_reduced_chi_squared_threshold = self.config.global_reduced_chi_squared_threshold + psf_reduced_chi_squared_threshold = self.config.psf_reduced_chi_squared_threshold + else: + global_reduced_chi_squared_threshold = ( + self.config.bright_global_reduced_chi_squared_threshold + ) + psf_reduced_chi_squared_threshold = self.config.psf_bright_reduced_chi_squared_threshold + for i in range(len(self.config.subset_stamp_number)): + if ( + stamp.global_reduced_chi_squared > global_reduced_chi_squared_threshold + or stamp.psf_reduced_chi_squared > psf_reduced_chi_squared_threshold + or stamp.focal_plane_radius < self.config.min_focal_plane_radius + or stamp.focal_plane_radius > self.config.max_focal_plane_radius + ): + continue + + if ( + stamp.ref_mag < self.config.magnitude_bins[i] + and stamp.ref_mag > self.config.magnitude_bins[i + 1] + ): + if not self.config.magnitude_bins[i + 1] in mag_bins_dict.keys(): + mag_bins_dict[self.config.magnitude_bins[i + 1]] = [] + stampMI = stamp.stamp_im + self._applyStampFit(stamp) + mag_bins_dict[self.config.magnitude_bins[i + 1]].append(stampMI) + badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.bad_mask_planes) + statistics_control.setAndMask(badMaskBitMask) + if ( + len(mag_bins_dict[self.config.magnitude_bins[i + 1]]) + == self.config.subset_stamp_number[i] + ): + if self.config.magnitude_bins[i + 1] not in subset_stampMIs.keys(): + subset_stampMIs[self.config.magnitude_bins[i + 1]] = [] + subset_stampMIs[self.config.magnitude_bins[i + 1]].append( + statisticsStack( + mag_bins_dict[self.config.magnitude_bins[i + 1]], + stack_type_property, + statistics_control, + ) + ) + self.metadata["psf_star_count"][str(self.config.magnitude_bins[i + 1])] += len( + mag_bins_dict[self.config.magnitude_bins[i + 1]] + ) + mag_bins_dict[self.config.magnitude_bins[i + 1]] = [] + + for key in mag_bins_dict.keys(): + if key not in subset_stampMIs.keys(): + subset_stampMIs[key] = [] + subset_stampMIs[key].append( + statisticsStack(mag_bins_dict[key], stack_type_property, statistics_control) + ) + self.metadata["psf_star_count"][str(key)] += len(mag_bins_dict[key]) + + # TODO: which stamp mask plane to use here? + # TODO: Amir: there might be cases where subset_stampMIs is an empty list. What do we want to do + # then? + # Currently, we get an "IndexError: list index out of range" + final_subset_stampMIs = [] + for key in subset_stampMIs.keys(): + final_subset_stampMIs.extend(subset_stampMIs[key]) + badMaskBitMask = final_subset_stampMIs[0].mask.getPlaneBitMask(self.config.bad_mask_planes) + statistics_control.setAndMask(badMaskBitMask) + extendedPsfMI = statisticsStack(final_subset_stampMIs, stack_type_property, statistics_control) + + extendedPsfExtent = extendedPsfMI.getBBox().getDimensions() + extendedPsfOrigin = Point2I(-1 * (extendedPsfExtent.x // 2), -1 * (extendedPsfExtent.y // 2)) + extendedPsfMI.setXY0(extendedPsfOrigin) + + return Struct(extendedPsf=extendedPsfMI.getImage()) diff --git a/tests/test_brightStarCutout.py b/tests/test_brightStarCutout.py index 88da41440..be3c480de 100644 --- a/tests/test_brightStarCutout.py +++ b/tests/test_brightStarCutout.py @@ -37,26 +37,36 @@ def setUp(self): # Fit values self.scale = 2.34e5 self.pedestal = 3210.1 - self.xGradient = 5.432 - self.yGradient = 10.987 + self.x_gradient = 5.432 + self.y_gradient = 10.987 + self.curvature_x = 0.1 + self.curvature_y = -0.2 + self.curvature_xy = 1e-2 # Create a pedestal + 2D plane - xCoords = np.linspace(-50, 50, 101) - yCoords = np.linspace(-50, 50, 101) - xPlane, yPlane = np.meshgrid(xCoords, yCoords) - pedestal = np.ones_like(xPlane) * self.pedestal + x_coords = np.linspace(-50, 50, 101) + y_coords = np.linspace(-50, 50, 101) + x_plane, y_plane = np.meshgrid(x_coords, y_coords) + pedestal = np.ones_like(x_plane) * self.pedestal # Create a pseudo-PSF - dist_from_center = np.sqrt(xPlane**2 + yPlane**2) - psfArray = np.exp(-dist_from_center / 5) - psfArray /= np.sum(psfArray) - fixedKernel = FixedKernel(ImageD(psfArray)) - psf = KernelPsf(fixedKernel) + dist_from_center = np.sqrt(x_plane ** 2 + y_plane ** 2) + psf_array = np.exp(-dist_from_center / 5) + psf_array /= np.sum(psf_array) + fixed_kernel = FixedKernel(ImageD(psf_array)) + psf = KernelPsf(fixed_kernel) self.psf = psf.computeKernelImage(psf.getAveragePosition()) # Bring everything together to construct a stamp masked image - stampArray = psfArray * self.scale + pedestal + xPlane * self.xGradient + yPlane * self.yGradient - stampIm = ImageF((stampArray).astype(np.float32)) + stamp_array = ( + psf_array * self.scale + pedestal + x_plane * self.x_gradient + y_plane * self.y_gradient + ) + stamp_array += ( + x_plane**2 * self.curvature_x + + y_plane**2 * self.curvature_y + + x_plane * y_plane * self.curvature_xy + ) + stampIm = ImageF((stamp_array).astype(np.float32)) stampVa = ImageF(stampIm.getBBox(), 654.321) self.stampMI = MaskedImageF(image=stampIm, variance=stampVa) self.stampMI.setXY0(Point2I(-50, -50)) @@ -80,14 +90,18 @@ def test_fitPsf(self): """Test the PSF fitting method.""" brightStarCutoutConfig = BrightStarCutoutConfig() brightStarCutoutTask = BrightStarCutoutTask(config=brightStarCutoutConfig) - fitPsfResults = brightStarCutoutTask._fitPsf( + fit_psf_results = brightStarCutoutTask._fit_psf( self.stampMI, self.psf, ) - assert abs(fitPsfResults["scale"] - self.scale) / self.scale < 1e-6 - assert abs(fitPsfResults["pedestal"] - self.pedestal) / self.pedestal < 1e-6 - assert abs(fitPsfResults["xGradient"] - self.xGradient) / self.xGradient < 1e-6 - assert abs(fitPsfResults["yGradient"] - self.yGradient) / self.yGradient < 1e-6 + + assert abs(fit_psf_results["scale"] - self.scale) / self.scale < 1e-3 + assert abs(fit_psf_results["pedestal"] - self.pedestal) / self.pedestal < 1e-3 + assert abs(fit_psf_results["x_gradient"] - self.x_gradient) / self.x_gradient < 1e-3 + assert abs(fit_psf_results["y_gradient"] - self.y_gradient) / self.y_gradient < 1e-3 + assert abs(fit_psf_results["curvature_x"] - self.curvature_x) / self.curvature_x < 1e-3 + assert abs(fit_psf_results["curvature_y"] - self.curvature_y) / self.curvature_y < 1e-3 + assert abs(fit_psf_results["curvature_xy"] - self.curvature_xy) / self.curvature_xy < 1e-3 def setup_module(module): From 0b8fbb35f346bd35bbe854ef0afa362f7cd5302a Mon Sep 17 00:00:00 2001 From: Amir Ebadati Bazkiaei Date: Wed, 10 Dec 2025 05:27:02 +0000 Subject: [PATCH 07/12] working version of brightStarStack --- .../brightStarSubtraction/brightStarStack.py | 23 +++++++------------ 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py index bf7d0641c..d48c24400 100644 --- a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py @@ -38,7 +38,7 @@ class BrightStarStackConnections( PipelineTaskConnections, - dimensions=("instrument", "detector"), + dimensions=("instrument", "band"), ): """Connections for BrightStarStackTask.""" @@ -100,7 +100,7 @@ class BrightStarStackConfig( ], ) stack_type = Field[str]( - default="WEIGHTED_MEDIAN", + default="MEDIAN", doc="Statistic name to use for stacking (from `~lsst.afw.math.Property`)", ) stack_num_sigma_clip = Field[float]( @@ -201,10 +201,7 @@ def run( ``brightStarStamps`` (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`) """ - if self.config.stack_type == "WEIGHTED_MEDIAN": - stack_type_property = stringToStatisticsProperty("MEDIAN") - else: - stack_type_property = stringToStatisticsProperty(self.config.stack_type) + stack_type_property = stringToStatisticsProperty(self.config.stack_type) statistics_control = StatisticsControl( numSigmaClip=self.config.stack_num_sigma_clip, numIter=self.config.stack_num_iter, @@ -241,13 +238,13 @@ def run( stamp.ref_mag < self.config.magnitude_bins[i] and stamp.ref_mag > self.config.magnitude_bins[i + 1] ): + self._applyStampFit(stamp) if not self.config.magnitude_bins[i + 1] in mag_bins_dict.keys(): mag_bins_dict[self.config.magnitude_bins[i + 1]] = [] stampMI = stamp.stamp_im - self._applyStampFit(stamp) mag_bins_dict[self.config.magnitude_bins[i + 1]].append(stampMI) - badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.bad_mask_planes) - statistics_control.setAndMask(badMaskBitMask) + bad_mask_bit_mask = stampMI.mask.getPlaneBitMask(self.config.bad_mask_planes) + statistics_control.setAndMask(bad_mask_bit_mask) if ( len(mag_bins_dict[self.config.magnitude_bins[i + 1]]) == self.config.subset_stamp_number[i] @@ -274,15 +271,11 @@ def run( ) self.metadata["psf_star_count"][str(key)] += len(mag_bins_dict[key]) - # TODO: which stamp mask plane to use here? - # TODO: Amir: there might be cases where subset_stampMIs is an empty list. What do we want to do - # then? - # Currently, we get an "IndexError: list index out of range" final_subset_stampMIs = [] for key in subset_stampMIs.keys(): final_subset_stampMIs.extend(subset_stampMIs[key]) - badMaskBitMask = final_subset_stampMIs[0].mask.getPlaneBitMask(self.config.bad_mask_planes) - statistics_control.setAndMask(badMaskBitMask) + bad_mask_bit_mask = final_subset_stampMIs[0].mask.getPlaneBitMask(self.config.bad_mask_planes) + statistics_control.setAndMask(bad_mask_bit_mask) extendedPsfMI = statisticsStack(final_subset_stampMIs, stack_type_property, statistics_control) extendedPsfExtent = extendedPsfMI.getBBox().getDimensions() From 18a58dfd929692e3504c34ada9d2aba91c8d8836 Mon Sep 17 00:00:00 2001 From: Amir Ebadati Bazkiaei Date: Thu, 11 Dec 2025 10:49:21 +0000 Subject: [PATCH 08/12] including variance in extendedPsf and adding tests for brightStarStack module --- .../brightStarSubtraction/brightStarCutout.py | 6 +- .../brightStarSubtraction/brightStarStack.py | 6 +- tests/test_brightStarStack.py | 226 ++++++++++++++++++ 3 files changed, 232 insertions(+), 6 deletions(-) create mode 100644 tests/test_brightStarStack.py diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py index fdc4f50ba..8a58b6bc1 100644 --- a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py @@ -95,8 +95,8 @@ class BrightStarCutoutConnections( dimensions=("visit", "detector"), ) extendedPsf = Input( - name="extendedPsf2", - storageClass="ImageF", + name="extendedPsf", + storageClass="MaskedImageF", doc="Extended PSF model, built from stacking bright star cutouts.", dimensions=("band",), ) @@ -364,7 +364,7 @@ def run( psf = WarpedPsf(inputExposure.getPsf(), pix_to_polar, warping_control) constant_psf = KernelPsf(FixedKernel(psf.computeKernelImage(Point2D(0, 0)))) if self.config.use_extended_psf: - psf_image = deepcopy(extendedPsf) # Assumed to be warped, center at [0,0] + psf_image = deepcopy(extendedPsf.image) # Assumed to be warped, center at [0,0] else: psf_image = constant_psf.computeKernelImage(constant_psf.getAveragePosition()) # TODO: maybe we want to generate a smaller psf in case the following happens? diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py index d48c24400..bfafdfe0f 100644 --- a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py @@ -51,8 +51,8 @@ class BrightStarStackConnections( deferLoad=True, ) extendedPsf = Output( - name="extendedPsf2", # extendedPsfDetector ??? - storageClass="ImageF", # stamp_imF + name="extendedPsf", # extendedPsfDetector ??? + storageClass="MaskedImageF", # stamp_imF doc="Extended PSF model, built from stacking bright star cutouts.", dimensions=("band",), ) @@ -282,4 +282,4 @@ def run( extendedPsfOrigin = Point2I(-1 * (extendedPsfExtent.x // 2), -1 * (extendedPsfExtent.y // 2)) extendedPsfMI.setXY0(extendedPsfOrigin) - return Struct(extendedPsf=extendedPsfMI.getImage()) + return Struct(extendedPsf=extendedPsfMI) diff --git a/tests/test_brightStarStack.py b/tests/test_brightStarStack.py new file mode 100644 index 000000000..073464013 --- /dev/null +++ b/tests/test_brightStarStack.py @@ -0,0 +1,226 @@ +# This file is part of pipe_tasks. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import unittest + +import lsst.afw.image +import lsst.utils.tests +import numpy as np +from lsst.afw.image import ImageF, MaskedImageF +from lsst.pipe.tasks.brightStarSubtraction import BrightStarStackConfig, BrightStarStackTask + + +# Mock class to simulate the DeferredDatasetHandle (DDH) behavior +class MockHandle: + def __init__(self, content): + self.content = content + + def get(self): + return self.content + + +# Mock class to simulate a single BrightStarStamp +class MockStamp: + def __init__(self, stamp_im, mag, fit_params, stats): + self.stamp_im = stamp_im + self.ref_mag = mag + + # Unpack fit parameters + self.scale = fit_params.get("scale", 1.0) + self.pedestal = fit_params.get("pedestal", 0.0) + self.gradient_x = fit_params.get("gradient_x", 0.0) + self.gradient_y = fit_params.get("gradient_y", 0.0) + self.curvature_x = fit_params.get("curvature_x", 0.0) + self.curvature_y = fit_params.get("curvature_y", 0.0) + self.curvature_xy = fit_params.get("curvature_xy", 0.0) + + # Unpack statistics for filtering + self.global_reduced_chi_squared = stats.get("global_chi2", 1.0) + self.psf_reduced_chi_squared = stats.get("psf_chi2", 1.0) + self.bright_global_reduced_chi_squared = stats.get("bright_global_chi2", 1.0) + self.psf_bright_reduced_chi_squared = stats.get("psf_bright_chi2", 1.0) + self.bright_star_threshold = stats.get("brights_threshold", 100.0) + self.focal_plane_radius = stats.get("fp_radius", 100.0) + + +class BrightStarStackTestCase(lsst.utils.tests.TestCase): + def setUp(self): + # Define fit values + self.scale = 10.0 + self.pedestal = 50.0 + self.x_gradient = 0.5 + self.y_gradient = -0.5 + self.curvature_x = 0.01 + self.curvature_y = 0.01 + self.curvature_xy = 0.005 + + self.fit_params = { + "scale": self.scale, + "pedestal": self.pedestal, + "gradient_x": self.x_gradient, + "gradient_y": self.y_gradient, + "curvature_x": self.curvature_x, + "curvature_y": self.curvature_y, + "curvature_xy": self.curvature_xy, + } + + # Create the "Clean" PSF (a simple Gaussian) + self.dim = 51 + x_coords = np.linspace(-25, 25, self.dim) + y_coords = np.linspace(-25, 25, self.dim) + x_grid, y_grid = np.meshgrid(x_coords, y_coords) + + sigma = 5.0 + dist_sq = x_grid**2 + y_grid**2 + self.clean_array = np.exp(-dist_sq / (2 * sigma**2)) + + # Create the "star" Image (What the task receives) + # Apply scaling + star_array = self.clean_array * self.scale + + # Add background terms + x_indices, y_indices = np.meshgrid(np.arange(self.dim), np.arange(self.dim)) + + star_array += self.pedestal + star_array += x_indices * self.x_gradient + star_array += y_indices * self.y_gradient + star_array += (x_indices**2) * self.curvature_x + star_array += (y_indices**2) * self.curvature_y + star_array += (x_indices * y_indices) * self.curvature_xy + + # Create MaskedImage + stampIm = ImageF(star_array.astype(np.float32)) + stampVa = ImageF(stampIm.getBBox(), 1.0) + self.stampMI = MaskedImageF(image=stampIm, variance=stampVa) + + # Initialize the mask planes required + badMaskPlanes = [ + "BAD", + "CR", + "CROSSTALK", + "EDGE", + "NO_DATA", + "SAT", + "SUSPECT", + "UNMASKEDNAN", + "NEIGHBOR", + ] + _ = [self.stampMI.mask.addMaskPlane(mask) for mask in badMaskPlanes] + + def test_applyStampFit(self): + """Test that _applyStampFit correctly removes background and normalizes.""" + config = BrightStarStackConfig() + task = BrightStarStackTask(config=config) + + # Create a mock stamp + stamp_mi_copy = self.stampMI.clone() + mock_stamp = MockStamp(stamp_mi_copy, mag=10.0, fit_params=self.fit_params, stats={}) + + # Run the method + task._applyStampFit(mock_stamp) + + # The result should be the clean array (normalized to scale 1.0) + result_array = mock_stamp.stamp_im.image.array + + # Allow for small floating point discrepancies + np.testing.assert_allclose(result_array, self.clean_array, atol=1e-5) + + def test_run(self): + """Test the full run method: filtering, binning, and stacking.""" + config = BrightStarStackConfig() + # Set config to ensure our test stamps are included + config.magnitude_bins = [11, 9] + config.subset_stamp_number = [1] + config.stack_type = "MEDIAN" + + task = BrightStarStackTask(config=config) + + valid_stats = {"global_chi2": 1.0, "psf_chi2": 1.0, "fp_radius": 100.0} + invalid_stats = {"global_chi2": 1e9, "psf_chi2": 1e9, "fp_radius": 100.0} + + stamp1 = MockStamp(self.stampMI.clone(), mag=10.0, fit_params=self.fit_params, stats=valid_stats) + stamp2 = MockStamp(self.stampMI.clone(), mag=10.0, fit_params=self.fit_params, stats=valid_stats) + + # This stamp should be ignored + bad_stamp = MockStamp(self.stampMI.clone(), mag=10.0, fit_params=self.fit_params, stats=invalid_stats) + + # Create mock input structure + # brightStarStamps is a list of handles + input_stamps = [MockHandle([stamp1, bad_stamp]), MockHandle([stamp2])] + + result = task.run(brightStarStamps=input_stamps) + + # Verify output exists + self.assertIsNotNone(result.extendedPsf) + + # Verify output dimensions match input + self.assertEqual(result.extendedPsf.getDimensions(), self.stampMI.getDimensions()) + + # Verify the calculation + # Since we stacked identical "clean" stamps (after fit application), + # the result should match self.clean_array + result_array = result.extendedPsf.image.array + np.testing.assert_allclose(result_array, self.clean_array, atol=1e-5) + + def test_filtering_logic(self): + """Test that stamps outside focal plane radius or thresholds are skipped.""" + config = BrightStarStackConfig() + config.min_focal_plane_radius = 50.0 + config.max_focal_plane_radius = 150.0 + config.global_reduced_chi_squared_threshold = 5.0 + config.magnitude_bins = [15, 11, 9] + config.subset_stamp_number = [100, 1] + + task = BrightStarStackTask(config=config) + + good_stats = {"global_chi2": 1.0, "psf_chi2": 1.0, "fp_radius": 100.0} + bad_radius_low = {"global_chi2": 1.0, "psf_chi2": 1.0, "fp_radius": 10.0} + bad_radius_high = {"global_chi2": 1.0, "psf_chi2": 1.0, "fp_radius": 2000.0} + bad_chi2 = {"global_chi2": 100.0, "psf_chi2": 1.0, "fp_radius": 100.0} + + stamps = [ + MockStamp(self.stampMI.clone(), 10.0, self.fit_params, good_stats), + MockStamp(self.stampMI.clone(), 10.0, self.fit_params, bad_radius_low), + MockStamp(self.stampMI.clone(), 10.0, self.fit_params, bad_radius_high), + MockStamp(self.stampMI.clone(), 10.0, self.fit_params, bad_chi2), + ] + + input_stamps = [MockHandle(stamps)] + task.run(brightStarStamps=input_stamps) + + bin_key = "9" # Based on config magnitude_bins=[11, 9], the lower bound is 9 + + self.assertEqual(task.metadata["psf_star_count"]["all"], 4) + + self.assertEqual(task.metadata["psf_star_count"][bin_key], 2) + + +def setup_module(module): + lsst.utils.tests.init() + + +class MemoryTestCase(lsst.utils.tests.MemoryTestCase): + pass + + +if __name__ == "__main__": + lsst.utils.tests.init() + unittest.main() From 286fc0d2dac35ede1cc62f50cea77caf18242733 Mon Sep 17 00:00:00 2001 From: Amir Ebadati Bazkiaei Date: Thu, 18 Dec 2025 02:35:47 +0000 Subject: [PATCH 09/12] enabling stars-distance filtering in brightStarCutout --- .../brightStarSubtraction/brightStarCutout.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py index 8a58b6bc1..7533006c5 100644 --- a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py @@ -186,6 +186,16 @@ class BrightStarCutoutConfig( doc="Stars fainter than this limit are only included if they appear within the frame boundaries.", default=15.0, ) + min_focal_plane_radius = Field[float]( + doc="Minimum distance to focal plane center in mm. Stars with a focal plane radius smaller than " + "this will be omitted.", + default=-1.0, + ) + max_focal_plane_radius = Field[float]( + doc="Maximum distance to focal plane center in mm. Stars with a focal plane radius greater than " + "this will be omitted.", + default=2000.0, + ) # PSF Fitting use_extended_psf = Field[bool]( @@ -332,6 +342,11 @@ def run( # Excluding faint stars that are not within the frame. if obj["mag"] > self.config.off_frame_mag_lim and not self.star_in_frame(pix_coord, bbox): continue + + distance_mm, theta_angle = self.star_location_on_focal(pix_coord, detector) + + if distance_mm < self.config.min_focal_plane_radius or distance_mm > self.config.max_focal_plane_radius: + continue footprint_index = associations.get(star_index, None) stampMI = MaskedImageF(self.padded_stamp_bbox) @@ -386,7 +401,6 @@ def run( # Save the stamp if the PSF fit was successful or no fit requested if fit_psf_results or not self.config.do_fit_psf: - distance_mm, theta_angle = self.star_location_on_focal(pix_coord, detector) stamp = BrightStarStamp( stamp_im=stampMI, From cf87e1c1a7fded19312318bf11ae13d9b59c71d7 Mon Sep 17 00:00:00 2001 From: Amir Ebadati Bazkiaei Date: Thu, 18 Dec 2025 02:40:32 +0000 Subject: [PATCH 10/12] fixing the linter fail --- .../brightStarSubtraction/brightStarCutout.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py index 7533006c5..34089d58c 100644 --- a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py @@ -345,7 +345,10 @@ def run( distance_mm, theta_angle = self.star_location_on_focal(pix_coord, detector) - if distance_mm < self.config.min_focal_plane_radius or distance_mm > self.config.max_focal_plane_radius: + if ( + distance_mm < self.config.min_focal_plane_radius + or distance_mm > self.config.max_focal_plane_radius + ): continue footprint_index = associations.get(star_index, None) stampMI = MaskedImageF(self.padded_stamp_bbox) @@ -471,7 +474,7 @@ def star_location_on_focal(self, pix_coord, detector): star_focal_plane_coords = detector.transform(pix_coord, PIXELS, FOCAL_PLANE) star_x_fp = star_focal_plane_coords.getX() star_y_fp = star_focal_plane_coords.getY() - distance_mm = np.sqrt(star_x_fp ** 2 + star_y_fp ** 2) + distance_mm = np.sqrt(star_x_fp**2 + star_y_fp**2) theta_rad = math.atan2(star_y_fp, star_x_fp) theta_angle = Angle(theta_rad, radians) return distance_mm, theta_angle @@ -718,8 +721,8 @@ def _fit_psf(self, stampMI: MaskedImageF, psf_image: ImageD | ImageF) -> dict[st coefficient_matrix[:, 0] /= sigma_data coefficient_matrix[:, 1] = y / sigma_data coefficient_matrix[:, 2] = x / sigma_data - coefficient_matrix[:, 3] = y ** 2 / sigma_data - coefficient_matrix[:, 4] = x ** 2 / sigma_data + coefficient_matrix[:, 3] = y**2 / sigma_data + coefficient_matrix[:, 4] = x**2 / sigma_data coefficient_matrix[:, 5] = x * y / sigma_data # scikit might have a fitting tool @@ -866,8 +869,8 @@ def _removePedestalAndGradient( x_grid, y_grid = np.meshgrid(stamp_bbox.getX().arange(), stamp_bbox.getY().arange()) x_plane = ImageF((x_grid * x_gradient).astype(np.float32), xy0=stampMI.getXY0()) y_plane = ImageF((y_grid * y_gradient).astype(np.float32), xy0=stampMI.getXY0()) - x_curve = ImageF((x_grid ** 2 * x_curvature).astype(np.float32), xy0=stampMI.getXY0()) - y_curve = ImageF((y_grid ** 2 * y_curvature).astype(np.float32), xy0=stampMI.getXY0()) + x_curve = ImageF((x_grid**2 * x_curvature).astype(np.float32), xy0=stampMI.getXY0()) + y_curve = ImageF((y_grid**2 * y_curvature).astype(np.float32), xy0=stampMI.getXY0()) curvature_xy = ImageF((x_grid * y_grid * curvature_xy).astype(np.float32), xy0=stampMI.getXY0()) stampMI -= pedestal stampMI -= x_plane From 77980fbeb7047ba1a04669c45cc0a64ff9c87668 Mon Sep 17 00:00:00 2001 From: Amir Ebadati Bazkiaei Date: Thu, 8 Jan 2026 07:01:09 +0000 Subject: [PATCH 11/12] correcting mag_range default and modifying some docs --- .../tasks/brightStarSubtraction/brightStarCutout.py | 2 +- .../tasks/brightStarSubtraction/brightStarStack.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py index 34089d58c..788628c67 100644 --- a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py @@ -123,7 +123,7 @@ class BrightStarCutoutConfig( # Star selection mag_range = ListField[float]( doc="Magnitude range in Gaia G. Cutouts will be made for all stars in this range.", - default=[0, 18], + default=[10, 18], ) exclude_arcsec_radius = Field[float]( doc="Stars with a star in the range ``exclude_mag_range`` mag in ``exclude_arcsec_radius`` are not " diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py index bfafdfe0f..77f1989f9 100644 --- a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py @@ -65,11 +65,11 @@ class BrightStarStackConfig( """Configuration parameters for BrightStarStackTask.""" global_reduced_chi_squared_threshold = Field[float]( - doc="Threshold for global reduced chi-squared for bright star stamps.", + doc="Threshold for global reduced chi-squared for stamps.", default=5.0, ) psf_reduced_chi_squared_threshold = Field[float]( - doc="Threshold for PSF reduced chi-squared for bright star stamps.", + doc="Threshold for PSF reduced chi-squared for stamps.", default=50.0, ) bright_star_threshold = Field[float]( @@ -112,12 +112,12 @@ class BrightStarStackConfig( default=5, ) magnitude_bins = ListField[int]( - doc="Only used if stack_type == WEIGHTED_MEDIAN. Bins of magnitudes for weighting purposes.", + doc="Bins of magnitudes for weighting purposes.", default=[20, 19, 18, 17, 16, 15, 13, 10], ) subset_stamp_number = ListField[int]( - doc="Only used if stack_type == WEIGHTED_MEDIAN. Number of stamps per subset to generate stacked " - "images for. The length of this parameter must be equal to the length of magnitude_bins plus one.", + doc="Number of stamps per subset to generate stacked " + "images for. The length of this parameter must be equal to the length of magnitude_bins minus one.", default=[300, 200, 150, 100, 100, 100, 1], ) min_focal_plane_radius = Field[float]( From 5009719db23065de18d55dbc51f3059da4ef1dac Mon Sep 17 00:00:00 2001 From: Amir Ebadati Bazkiaei Date: Thu, 8 Jan 2026 23:29:13 +0000 Subject: [PATCH 12/12] removing stack cod from this PR --- .../tasks/brightStarSubtraction/__init__.py | 1 - .../brightStarSubtraction/brightStarStack.py | 285 ------------------ tests/test_brightStarStack.py | 226 -------------- 3 files changed, 512 deletions(-) delete mode 100644 python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py delete mode 100644 tests/test_brightStarStack.py diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py b/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py index fe5088369..fb8e3320d 100644 --- a/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py +++ b/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py @@ -1,2 +1 @@ from .brightStarCutout import * -from .brightStarStack import * diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py deleted file mode 100644 index 77f1989f9..000000000 --- a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py +++ /dev/null @@ -1,285 +0,0 @@ -# This file is part of pipe_tasks. -# -# Developed for the LSST Data Management System. -# This product includes software developed by the LSST Project -# (https://www.lsst.org). -# See the COPYRIGHT file at the top-level directory of this distribution -# for details of code ownership. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . - -"""Stack bright star postage stamp cutouts to produce an extended PSF model.""" - -__all__ = ["BrightStarStackConnections", "BrightStarStackConfig", "BrightStarStackTask"] - -import numpy as np -from lsst.afw.image import ImageF -from lsst.afw.math import StatisticsControl, statisticsStack, stringToStatisticsProperty -from lsst.geom import Point2I -from lsst.meas.algorithms import BrightStarStamps -from lsst.pex.config import Field, ListField -from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct -from lsst.pipe.base.connectionTypes import Input, Output -from lsst.utils.timer import timeMethod - -NEIGHBOR_MASK_PLANE = "NEIGHBOR" - - -class BrightStarStackConnections( - PipelineTaskConnections, - dimensions=("instrument", "band"), -): - """Connections for BrightStarStackTask.""" - - brightStarStamps = Input( - name="brightStarStamps", - storageClass="BrightStarStamps", - doc="Set of preprocessed postage stamp cutouts, each centered on a single bright star.", - dimensions=("visit", "detector"), - multiple=True, - deferLoad=True, - ) - extendedPsf = Output( - name="extendedPsf", # extendedPsfDetector ??? - storageClass="MaskedImageF", # stamp_imF - doc="Extended PSF model, built from stacking bright star cutouts.", - dimensions=("band",), - ) - - -class BrightStarStackConfig( - PipelineTaskConfig, - pipelineConnections=BrightStarStackConnections, -): - """Configuration parameters for BrightStarStackTask.""" - - global_reduced_chi_squared_threshold = Field[float]( - doc="Threshold for global reduced chi-squared for stamps.", - default=5.0, - ) - psf_reduced_chi_squared_threshold = Field[float]( - doc="Threshold for PSF reduced chi-squared for stamps.", - default=50.0, - ) - bright_star_threshold = Field[float]( - doc="Stars brighter than this magnitude, are considered as bright stars.", - default=12.0, - ) - bright_global_reduced_chi_squared_threshold = Field[float]( - doc="Threshold for global reduced chi-squared for bright star stamps.", - default=250.0, - ) - psf_bright_reduced_chi_squared_threshold = Field[float]( - doc="Threshold for PSF reduced chi-squared for bright star stamps.", - default=400.0, - ) - - bad_mask_planes = ListField[str]( - doc="Mask planes that identify excluded (masked) pixels.", - default=[ - "BAD", - "CR", - "CROSSTALK", - "EDGE", - "NO_DATA", - "SAT", - "SUSPECT", - "UNMASKEDNAN", - NEIGHBOR_MASK_PLANE, - ], - ) - stack_type = Field[str]( - default="MEDIAN", - doc="Statistic name to use for stacking (from `~lsst.afw.math.Property`)", - ) - stack_num_sigma_clip = Field[float]( - doc="Number of sigma to use for clipping when stacking.", - default=3.0, - ) - stack_num_iter = Field[int]( - doc="Number of iterations to use for clipping when stacking.", - default=5, - ) - magnitude_bins = ListField[int]( - doc="Bins of magnitudes for weighting purposes.", - default=[20, 19, 18, 17, 16, 15, 13, 10], - ) - subset_stamp_number = ListField[int]( - doc="Number of stamps per subset to generate stacked " - "images for. The length of this parameter must be equal to the length of magnitude_bins minus one.", - default=[300, 200, 150, 100, 100, 100, 1], - ) - min_focal_plane_radius = Field[float]( - doc="Minimum distance to focal plane center in mm. Stars with a focal plane radius smaller than " - "this will be omitted.", - default=-1.0, - ) - max_focal_plane_radius = Field[float]( - doc="Maximum distance to focal plane center in mm. Stars with a focal plane radius greater than " - "this will be omitted.", - default=2000.0, - ) - - -class BrightStarStackTask(PipelineTask): - """Stack bright star postage stamps to produce an extended PSF model.""" - - ConfigClass = BrightStarStackConfig - _DefaultName = "brightStarStack" - config: BrightStarStackConfig - - def __init__(self, initInputs=None, *args, **kwargs): - super().__init__(*args, **kwargs) - - def runQuantum(self, butlerQC, inputRefs, outputRefs): - inputs = butlerQC.get(inputRefs) - output = self.run(**inputs) - butlerQC.put(output, outputRefs) - - def _applyStampFit(self, stamp): - """Apply fitted stamp components to a single bright star stamp.""" - stampMI = stamp.stamp_im - stamp_bbox = stampMI.getBBox() - - x_grid, y_grid = np.meshgrid(stamp_bbox.getX().arange(), stamp_bbox.getY().arange()) - - x_plane = ImageF((x_grid * stamp.gradient_x).astype(np.float32), xy0=stampMI.getXY0()) - y_plane = ImageF((y_grid * stamp.gradient_y).astype(np.float32), xy0=stampMI.getXY0()) - - x_curve = ImageF((x_grid**2 * stamp.curvature_x).astype(np.float32), xy0=stampMI.getXY0()) - y_curve = ImageF((y_grid**2 * stamp.curvature_y).astype(np.float32), xy0=stampMI.getXY0()) - xy_curve = ImageF((x_grid * y_grid * stamp.curvature_xy).astype(np.float32), xy0=stampMI.getXY0()) - - stampMI -= stamp.pedestal - stampMI -= x_plane - stampMI -= y_plane - stampMI -= x_curve - stampMI -= y_curve - stampMI -= xy_curve - stampMI /= stamp.scale - - @timeMethod - def run( - self, - brightStarStamps: BrightStarStamps, - ): - """Identify bright stars within an exposure using a reference catalog, - extract stamps around each, then preprocess them. - - Bright star preprocessing steps are: shifting, warping and potentially - rotating them to the same pixel grid; computing their annular flux, - and; normalizing them. - - Parameters - ---------- - inputExposure : `~lsst.afw.image.ExposureF` - The image from which bright star stamps should be extracted. - inputBackground : `~lsst.afw.image.Background` - The background model for the input exposure. - refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional - Loader to find objects within a reference catalog. - dataId : `dict` or `~lsst.daf.butler.DataCoordinate` - The dataId of the exposure (including detector) that bright stars - should be extracted from. - - Returns - ------- - brightStarResults : `~lsst.pipe.base.Struct` - Results as a struct with attributes: - - ``brightStarStamps`` - (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`) - """ - stack_type_property = stringToStatisticsProperty(self.config.stack_type) - statistics_control = StatisticsControl( - numSigmaClip=self.config.stack_num_sigma_clip, - numIter=self.config.stack_num_iter, - ) - - mag_bins_dict = {} - subset_stampMIs = {} - self.metadata["psf_star_count"] = {} - self.metadata["psf_star_count"]["all"] = 0 - for i in range(len(self.config.subset_stamp_number)): - self.metadata["psf_star_count"][str(self.config.magnitude_bins[i + 1])] = 0 - for stampsDDH in brightStarStamps: - stamps = stampsDDH.get() - self.metadata["psf_star_count"]["all"] += len(stamps) - for stamp in stamps: - if stamp.ref_mag >= self.config.bright_star_threshold: - global_reduced_chi_squared_threshold = self.config.global_reduced_chi_squared_threshold - psf_reduced_chi_squared_threshold = self.config.psf_reduced_chi_squared_threshold - else: - global_reduced_chi_squared_threshold = ( - self.config.bright_global_reduced_chi_squared_threshold - ) - psf_reduced_chi_squared_threshold = self.config.psf_bright_reduced_chi_squared_threshold - for i in range(len(self.config.subset_stamp_number)): - if ( - stamp.global_reduced_chi_squared > global_reduced_chi_squared_threshold - or stamp.psf_reduced_chi_squared > psf_reduced_chi_squared_threshold - or stamp.focal_plane_radius < self.config.min_focal_plane_radius - or stamp.focal_plane_radius > self.config.max_focal_plane_radius - ): - continue - - if ( - stamp.ref_mag < self.config.magnitude_bins[i] - and stamp.ref_mag > self.config.magnitude_bins[i + 1] - ): - self._applyStampFit(stamp) - if not self.config.magnitude_bins[i + 1] in mag_bins_dict.keys(): - mag_bins_dict[self.config.magnitude_bins[i + 1]] = [] - stampMI = stamp.stamp_im - mag_bins_dict[self.config.magnitude_bins[i + 1]].append(stampMI) - bad_mask_bit_mask = stampMI.mask.getPlaneBitMask(self.config.bad_mask_planes) - statistics_control.setAndMask(bad_mask_bit_mask) - if ( - len(mag_bins_dict[self.config.magnitude_bins[i + 1]]) - == self.config.subset_stamp_number[i] - ): - if self.config.magnitude_bins[i + 1] not in subset_stampMIs.keys(): - subset_stampMIs[self.config.magnitude_bins[i + 1]] = [] - subset_stampMIs[self.config.magnitude_bins[i + 1]].append( - statisticsStack( - mag_bins_dict[self.config.magnitude_bins[i + 1]], - stack_type_property, - statistics_control, - ) - ) - self.metadata["psf_star_count"][str(self.config.magnitude_bins[i + 1])] += len( - mag_bins_dict[self.config.magnitude_bins[i + 1]] - ) - mag_bins_dict[self.config.magnitude_bins[i + 1]] = [] - - for key in mag_bins_dict.keys(): - if key not in subset_stampMIs.keys(): - subset_stampMIs[key] = [] - subset_stampMIs[key].append( - statisticsStack(mag_bins_dict[key], stack_type_property, statistics_control) - ) - self.metadata["psf_star_count"][str(key)] += len(mag_bins_dict[key]) - - final_subset_stampMIs = [] - for key in subset_stampMIs.keys(): - final_subset_stampMIs.extend(subset_stampMIs[key]) - bad_mask_bit_mask = final_subset_stampMIs[0].mask.getPlaneBitMask(self.config.bad_mask_planes) - statistics_control.setAndMask(bad_mask_bit_mask) - extendedPsfMI = statisticsStack(final_subset_stampMIs, stack_type_property, statistics_control) - - extendedPsfExtent = extendedPsfMI.getBBox().getDimensions() - extendedPsfOrigin = Point2I(-1 * (extendedPsfExtent.x // 2), -1 * (extendedPsfExtent.y // 2)) - extendedPsfMI.setXY0(extendedPsfOrigin) - - return Struct(extendedPsf=extendedPsfMI) diff --git a/tests/test_brightStarStack.py b/tests/test_brightStarStack.py deleted file mode 100644 index 073464013..000000000 --- a/tests/test_brightStarStack.py +++ /dev/null @@ -1,226 +0,0 @@ -# This file is part of pipe_tasks. -# -# Developed for the LSST Data Management System. -# This product includes software developed by the LSST Project -# (https://www.lsst.org). -# See the COPYRIGHT file at the top-level directory of this distribution -# for details of code ownership. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . - -import unittest - -import lsst.afw.image -import lsst.utils.tests -import numpy as np -from lsst.afw.image import ImageF, MaskedImageF -from lsst.pipe.tasks.brightStarSubtraction import BrightStarStackConfig, BrightStarStackTask - - -# Mock class to simulate the DeferredDatasetHandle (DDH) behavior -class MockHandle: - def __init__(self, content): - self.content = content - - def get(self): - return self.content - - -# Mock class to simulate a single BrightStarStamp -class MockStamp: - def __init__(self, stamp_im, mag, fit_params, stats): - self.stamp_im = stamp_im - self.ref_mag = mag - - # Unpack fit parameters - self.scale = fit_params.get("scale", 1.0) - self.pedestal = fit_params.get("pedestal", 0.0) - self.gradient_x = fit_params.get("gradient_x", 0.0) - self.gradient_y = fit_params.get("gradient_y", 0.0) - self.curvature_x = fit_params.get("curvature_x", 0.0) - self.curvature_y = fit_params.get("curvature_y", 0.0) - self.curvature_xy = fit_params.get("curvature_xy", 0.0) - - # Unpack statistics for filtering - self.global_reduced_chi_squared = stats.get("global_chi2", 1.0) - self.psf_reduced_chi_squared = stats.get("psf_chi2", 1.0) - self.bright_global_reduced_chi_squared = stats.get("bright_global_chi2", 1.0) - self.psf_bright_reduced_chi_squared = stats.get("psf_bright_chi2", 1.0) - self.bright_star_threshold = stats.get("brights_threshold", 100.0) - self.focal_plane_radius = stats.get("fp_radius", 100.0) - - -class BrightStarStackTestCase(lsst.utils.tests.TestCase): - def setUp(self): - # Define fit values - self.scale = 10.0 - self.pedestal = 50.0 - self.x_gradient = 0.5 - self.y_gradient = -0.5 - self.curvature_x = 0.01 - self.curvature_y = 0.01 - self.curvature_xy = 0.005 - - self.fit_params = { - "scale": self.scale, - "pedestal": self.pedestal, - "gradient_x": self.x_gradient, - "gradient_y": self.y_gradient, - "curvature_x": self.curvature_x, - "curvature_y": self.curvature_y, - "curvature_xy": self.curvature_xy, - } - - # Create the "Clean" PSF (a simple Gaussian) - self.dim = 51 - x_coords = np.linspace(-25, 25, self.dim) - y_coords = np.linspace(-25, 25, self.dim) - x_grid, y_grid = np.meshgrid(x_coords, y_coords) - - sigma = 5.0 - dist_sq = x_grid**2 + y_grid**2 - self.clean_array = np.exp(-dist_sq / (2 * sigma**2)) - - # Create the "star" Image (What the task receives) - # Apply scaling - star_array = self.clean_array * self.scale - - # Add background terms - x_indices, y_indices = np.meshgrid(np.arange(self.dim), np.arange(self.dim)) - - star_array += self.pedestal - star_array += x_indices * self.x_gradient - star_array += y_indices * self.y_gradient - star_array += (x_indices**2) * self.curvature_x - star_array += (y_indices**2) * self.curvature_y - star_array += (x_indices * y_indices) * self.curvature_xy - - # Create MaskedImage - stampIm = ImageF(star_array.astype(np.float32)) - stampVa = ImageF(stampIm.getBBox(), 1.0) - self.stampMI = MaskedImageF(image=stampIm, variance=stampVa) - - # Initialize the mask planes required - badMaskPlanes = [ - "BAD", - "CR", - "CROSSTALK", - "EDGE", - "NO_DATA", - "SAT", - "SUSPECT", - "UNMASKEDNAN", - "NEIGHBOR", - ] - _ = [self.stampMI.mask.addMaskPlane(mask) for mask in badMaskPlanes] - - def test_applyStampFit(self): - """Test that _applyStampFit correctly removes background and normalizes.""" - config = BrightStarStackConfig() - task = BrightStarStackTask(config=config) - - # Create a mock stamp - stamp_mi_copy = self.stampMI.clone() - mock_stamp = MockStamp(stamp_mi_copy, mag=10.0, fit_params=self.fit_params, stats={}) - - # Run the method - task._applyStampFit(mock_stamp) - - # The result should be the clean array (normalized to scale 1.0) - result_array = mock_stamp.stamp_im.image.array - - # Allow for small floating point discrepancies - np.testing.assert_allclose(result_array, self.clean_array, atol=1e-5) - - def test_run(self): - """Test the full run method: filtering, binning, and stacking.""" - config = BrightStarStackConfig() - # Set config to ensure our test stamps are included - config.magnitude_bins = [11, 9] - config.subset_stamp_number = [1] - config.stack_type = "MEDIAN" - - task = BrightStarStackTask(config=config) - - valid_stats = {"global_chi2": 1.0, "psf_chi2": 1.0, "fp_radius": 100.0} - invalid_stats = {"global_chi2": 1e9, "psf_chi2": 1e9, "fp_radius": 100.0} - - stamp1 = MockStamp(self.stampMI.clone(), mag=10.0, fit_params=self.fit_params, stats=valid_stats) - stamp2 = MockStamp(self.stampMI.clone(), mag=10.0, fit_params=self.fit_params, stats=valid_stats) - - # This stamp should be ignored - bad_stamp = MockStamp(self.stampMI.clone(), mag=10.0, fit_params=self.fit_params, stats=invalid_stats) - - # Create mock input structure - # brightStarStamps is a list of handles - input_stamps = [MockHandle([stamp1, bad_stamp]), MockHandle([stamp2])] - - result = task.run(brightStarStamps=input_stamps) - - # Verify output exists - self.assertIsNotNone(result.extendedPsf) - - # Verify output dimensions match input - self.assertEqual(result.extendedPsf.getDimensions(), self.stampMI.getDimensions()) - - # Verify the calculation - # Since we stacked identical "clean" stamps (after fit application), - # the result should match self.clean_array - result_array = result.extendedPsf.image.array - np.testing.assert_allclose(result_array, self.clean_array, atol=1e-5) - - def test_filtering_logic(self): - """Test that stamps outside focal plane radius or thresholds are skipped.""" - config = BrightStarStackConfig() - config.min_focal_plane_radius = 50.0 - config.max_focal_plane_radius = 150.0 - config.global_reduced_chi_squared_threshold = 5.0 - config.magnitude_bins = [15, 11, 9] - config.subset_stamp_number = [100, 1] - - task = BrightStarStackTask(config=config) - - good_stats = {"global_chi2": 1.0, "psf_chi2": 1.0, "fp_radius": 100.0} - bad_radius_low = {"global_chi2": 1.0, "psf_chi2": 1.0, "fp_radius": 10.0} - bad_radius_high = {"global_chi2": 1.0, "psf_chi2": 1.0, "fp_radius": 2000.0} - bad_chi2 = {"global_chi2": 100.0, "psf_chi2": 1.0, "fp_radius": 100.0} - - stamps = [ - MockStamp(self.stampMI.clone(), 10.0, self.fit_params, good_stats), - MockStamp(self.stampMI.clone(), 10.0, self.fit_params, bad_radius_low), - MockStamp(self.stampMI.clone(), 10.0, self.fit_params, bad_radius_high), - MockStamp(self.stampMI.clone(), 10.0, self.fit_params, bad_chi2), - ] - - input_stamps = [MockHandle(stamps)] - task.run(brightStarStamps=input_stamps) - - bin_key = "9" # Based on config magnitude_bins=[11, 9], the lower bound is 9 - - self.assertEqual(task.metadata["psf_star_count"]["all"], 4) - - self.assertEqual(task.metadata["psf_star_count"][bin_key], 2) - - -def setup_module(module): - lsst.utils.tests.init() - - -class MemoryTestCase(lsst.utils.tests.MemoryTestCase): - pass - - -if __name__ == "__main__": - lsst.utils.tests.init() - unittest.main()