From 22f643a41dfbd3f21807b24fa0996ccde723cb90 Mon Sep 17 00:00:00 2001 From: Lee Kelvin Date: Mon, 19 Feb 2024 13:08:55 -0800 Subject: [PATCH 1/5] Refactor bright star subtraction framework --- .../tasks/brightStarSubtraction/__init__.py | 2 + .../brightStarSubtraction/brightStarCutout.py | 661 ++++++++++++++++++ .../brightStarSubtraction/brightStarStack.py | 226 ++++++ .../brightStarSubtract.py | 0 tests/test_brightStarCutout.py | 102 +++ 5 files changed, 991 insertions(+) create mode 100644 python/lsst/pipe/tasks/brightStarSubtraction/__init__.py create mode 100644 python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py create mode 100644 python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py create mode 100644 python/lsst/pipe/tasks/brightStarSubtraction/brightStarSubtract.py create mode 100644 tests/test_brightStarCutout.py diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py b/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py new file mode 100644 index 000000000..fe5088369 --- /dev/null +++ b/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py @@ -0,0 +1,2 @@ +from .brightStarCutout import * +from .brightStarStack import * diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py new file mode 100644 index 000000000..a9d58a762 --- /dev/null +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py @@ -0,0 +1,661 @@ +# This file is part of pipe_tasks. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +"""Extract bright star cutouts; normalize and warp, optionally fit the PSF.""" + +__all__ = ["BrightStarCutoutConnections", "BrightStarCutoutConfig", "BrightStarCutoutTask"] + +from typing import Any, Iterable, cast + +import astropy.units as u +import numpy as np +from astropy.coordinates import SkyCoord +from astropy.table import Table +from lsst.afw.cameraGeom import FIELD_ANGLE, PIXELS +from lsst.afw.detection import Footprint, FootprintSet, Threshold +from lsst.afw.geom import SkyWcs, SpanSet, makeModifiedWcs +from lsst.afw.geom.transformFactory import makeTransform +from lsst.afw.image import ExposureF, ImageD, ImageF, MaskedImageF +from lsst.afw.math import BackgroundList, FixedKernel, WarpingControl, warpImage +from lsst.daf.butler import DataCoordinate +from lsst.geom import ( + AffineTransform, + Box2I, + Extent2D, + Extent2I, + Point2D, + Point2I, + SpherePoint, + arcseconds, + floor, + radians, +) +from lsst.meas.algorithms import ( + BrightStarStamp, + BrightStarStamps, + KernelPsf, + LoadReferenceObjectsConfig, + ReferenceObjectLoader, + WarpedPsf, +) +from lsst.pex.config import ChoiceField, ConfigField, Field, ListField +from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct +from lsst.pipe.base.connectionTypes import Input, Output, PrerequisiteInput +from lsst.utils.timer import timeMethod + +NEIGHBOR_MASK_PLANE = "NEIGHBOR" + + +class BrightStarCutoutConnections( + PipelineTaskConnections, + dimensions=("instrument", "visit", "detector"), +): + """Connections for BrightStarCutoutTask.""" + + refCat = PrerequisiteInput( + name="gaia_dr3_20230707", + storageClass="SimpleCatalog", + doc="Reference catalog that contains bright star positions.", + dimensions=("skypix",), + multiple=True, + deferLoad=True, + ) + inputExposure = Input( + name="calexp", + storageClass="ExposureF", + doc="Background-subtracted input exposure from which to extract bright star stamp cutouts.", + dimensions=("visit", "detector"), + ) + inputBackground = Input( + name="calexpBackground", + storageClass="Background", + doc="Background model for the input exposure, to be added back on during processing.", + dimensions=("visit", "detector"), + ) + extendedPsf = Input( + name="extendedPsf2", + storageClass="ImageF", + doc="Extended PSF model, built from stacking bright star cutouts.", + dimensions=("band",), + ) + brightStarStamps = Output( + name="brightStarStamps", + storageClass="BrightStarStamps", + doc="Set of preprocessed postage stamp cutouts, each centered on a single bright star.", + dimensions=("visit", "detector"), + ) + + def __init__(self, *, config: "BrightStarCutoutConfig | None" = None): + super().__init__(config=config) + assert config is not None + if not config.useExtendedPsf: + self.inputs.remove("extendedPsf") + + +class BrightStarCutoutConfig( + PipelineTaskConfig, + pipelineConnections=BrightStarCutoutConnections, +): + """Configuration parameters for BrightStarCutoutTask.""" + + # Star selection + magRange = ListField[float]( + doc="Magnitude range in Gaia G. Cutouts will be made for all stars in this range.", + default=[0, 18], + ) + excludeArcsecRadius = Field[float]( + doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeArcsecRadius`` are not used.", + default=5, + ) + excludeMagRange = ListField[float]( + doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeArcsecRadius`` are not used.", + default=[0, 20], + ) + minAreaFraction = Field[float]( + doc="Minimum fraction of the stamp area, post-masking, that must remain for a cutout to be retained.", + default=0.1, + ) + badMaskPlanes = ListField[str]( + doc="Mask planes that identify excluded pixels for the calculation of ``minAreaFraction`` and, " + "optionally, fitting of the PSF.", + default=[ + "BAD", + "CR", + "CROSSTALK", + "EDGE", + "NO_DATA", + "SAT", + "SUSPECT", + "UNMASKEDNAN", + NEIGHBOR_MASK_PLANE, + ], + ) + + # Cutout geometry + stampSize = ListField[int]( + doc="Size of the stamps to be extracted, in pixels.", + default=(251, 251), + ) + stampSizePadding = Field[float]( + doc="Multiplicative factor applied to the cutout stamp size, to guard against post-warp data loss.", + default=1.1, + ) + warpingKernelName = ChoiceField[str]( + doc="Warping kernel.", + default="lanczos5", + allowed={ + "bilinear": "bilinear interpolation", + "lanczos3": "Lanczos kernel of order 3", + "lanczos4": "Lanczos kernel of order 4", + "lanczos5": "Lanczos kernel of order 5", + }, + ) + maskWarpingKernelName = ChoiceField[str]( + doc="Warping kernel for mask.", + default="bilinear", + allowed={ + "bilinear": "bilinear interpolation", + "lanczos3": "Lanczos kernel of order 3", + "lanczos4": "Lanczos kernel of order 4", + "lanczos5": "Lanczos kernel of order 5", + }, + ) + + # PSF Fitting + useExtendedPsf = Field[bool]( + doc="Use the extended PSF model to normalize bright star cutouts.", + default=False, + ) + doFitPsf = Field[bool]( + doc="Fit a scaled PSF and a pedestal to each bright star cutout.", + default=True, + ) + useMedianVariance = Field[bool]( + doc="Use the median of the variance plane for PSF fitting.", + default=False, + ) + psfMaskedFluxFracThreshold = Field[float]( + doc="Maximum allowed fraction of masked PSF flux for PSF fitting to occur.", + default=0.97, + ) + + # Misc + loadReferenceObjectsConfig = ConfigField[LoadReferenceObjectsConfig]( + doc="Reference object loader for astrometric calibration.", + ) + + +class BrightStarCutoutTask(PipelineTask): + """Extract bright star cutouts; normalize and warp to the same pixel grid. + + The BrightStarCutoutTask is used to extract, process, and store small image + cutouts (or "postage stamps") around bright stars. + This task essentially consists of three principal steps. + First, it identifies bright stars within an exposure using a reference + catalog and extracts a stamp around each. + Second, it shifts and warps each stamp to remove optical distortions and + sample all stars on the same pixel grid. + Finally, it optionally fits a PSF plus plane flux model to the cutout. + This final fitting procedure may be used to normalize each bright star + stamp prior to stacking when producing extended PSF models. + """ + + ConfigClass = BrightStarCutoutConfig + _DefaultName = "brightStarCutout" + config: BrightStarCutoutConfig + + def __init__(self, initInputs=None, *args, **kwargs): + super().__init__(*args, **kwargs) + stampSize = Extent2D(*self.config.stampSize.list()) + stampRadius = floor(stampSize / 2) + self.stampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy(stampRadius) + paddedStampSize = stampSize * self.config.stampSizePadding + self.paddedStampRadius = floor(paddedStampSize / 2) + self.paddedStampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy( + self.paddedStampRadius + ) + + def runQuantum(self, butlerQC, inputRefs, outputRefs): + inputs = butlerQC.get(inputRefs) + inputs["dataId"] = butlerQC.quantum.dataId + refObjLoader = ReferenceObjectLoader( + dataIds=[ref.datasetRef.dataId for ref in inputRefs.refCat], + refCats=inputs.pop("refCat"), + name=self.config.connections.refCat, + config=self.config.loadReferenceObjectsConfig, + ) + extendedPsf = inputs.pop("extendedPsf", None) + output = self.run(**inputs, extendedPsf=extendedPsf, refObjLoader=refObjLoader) + # Only ingest Stamp if it exists; prevents ingesting an empty FITS file + if output: + butlerQC.put(output, outputRefs) + + @timeMethod + def run( + self, + inputExposure: ExposureF, + inputBackground: BackgroundList, + extendedPsf: ImageF | None, + refObjLoader: ReferenceObjectLoader, + dataId: dict[str, Any] | DataCoordinate, + ): + """Identify bright stars within an exposure using a reference catalog, + extract stamps around each, warp/shift stamps onto a common frame and + then optionally fit a PSF plus plane model. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The background-subtracted image to extract bright star stamps. + inputBackground : `~lsst.afw.math.BackgroundList` + The background model associated with the input exposure. + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional + Loader to find objects within a reference catalog. + dataId : `dict` or `~lsst.daf.butler.DataCoordinate` + The dataId of the exposure that bright stars are extracted from. + Both 'visit' and 'detector' will be persisted in the output data. + + Returns + ------- + brightStarResults : `~lsst.pipe.base.Struct` + Results as a struct with attributes: + + ``brightStarStamps`` + (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`) + """ + wcs = inputExposure.getWcs() + bbox = inputExposure.getBBox() + warpingControl = WarpingControl(self.config.warpingKernelName, self.config.maskWarpingKernelName) + + refCatBright = self._getRefCatBright(refObjLoader, wcs, bbox) + zipRaDec = zip(refCatBright["coord_ra"] * radians, refCatBright["coord_dec"] * radians) + spherePoints = [SpherePoint(ra, dec) for ra, dec in zipRaDec] + pixCoords = wcs.skyToPixel(spherePoints) + + # Restore original subtracted background + inputMI = inputExposure.getMaskedImage() + inputMI += inputBackground.getImage() + + # Set up NEIGHBOR mask plane; associate footprints with stars + inputExposure.mask.addMaskPlane(NEIGHBOR_MASK_PLANE) + allFootprints, associations = self._associateFootprints(inputExposure, pixCoords, plane="DETECTED") + + # TODO: If we eventually have better PhotoCalibs (eg FGCM), apply here + inputMI = inputExposure.getPhotoCalib().calibrateImage(inputMI, False) + + # Set up transform + detector = inputExposure.detector + pixelScale = wcs.getPixelScale().asArcseconds() * arcseconds + pixToFocalPlaneTan = detector.getTransform(PIXELS, FIELD_ANGLE).then( + makeTransform(AffineTransform.makeScaling(1 / pixelScale.asRadians())) + ) + + # Loop over each bright star + stamps, goodFracs, stamps_fitPsfResults = [], [], [] + for starIndex, (obj, pixCoord) in enumerate(zip(refCatBright, pixCoords)): # type: ignore + footprintIndex = associations.get(starIndex, None) + stampMI = MaskedImageF(self.paddedStampBBox) + + # Set NEIGHBOR footprints in the mask plane + if footprintIndex: + neighborFootprints = [fp for i, fp in enumerate(allFootprints) if i != footprintIndex] + self._setFootprints(inputExposure, neighborFootprints, NEIGHBOR_MASK_PLANE) + else: + self._setFootprints(inputExposure, allFootprints, NEIGHBOR_MASK_PLANE) + + # Define linear shifting to recenter stamps + coordFocalPlaneTan = pixToFocalPlaneTan.applyForward(pixCoord) # center of warped star + shift = makeTransform(AffineTransform(Point2D(0, 0) - coordFocalPlaneTan)) + angle = np.arctan2(coordFocalPlaneTan.getY(), coordFocalPlaneTan.getX()) * radians + rotation = makeTransform(AffineTransform.makeRotation(-angle)) + pixToPolar = pixToFocalPlaneTan.then(shift).then(rotation) + + # Apply the warp to the star stamp (in-place) + warpImage(stampMI, inputExposure.maskedImage, pixToPolar, warpingControl) + + # Trim to the base stamp size, check mask coverage, update metadata + stampMI = stampMI[self.stampBBox] + badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes) + goodFrac = np.sum(stampMI.mask.array & badMaskBitMask == 0) / stampMI.mask.array.size + goodFracs.append(goodFrac) + if goodFrac < self.config.minAreaFraction: + continue + + # Fit a scaled PSF and a pedestal to each bright star cutout + psf = WarpedPsf(inputExposure.getPsf(), pixToPolar, warpingControl) + constantPsf = KernelPsf(FixedKernel(psf.computeKernelImage(Point2D(0, 0)))) + if self.config.useExtendedPsf: + psfImage = extendedPsf # Assumed to be warped, center at [0,0] + else: + psfImage = constantPsf.computeKernelImage(constantPsf.getAveragePosition()) + fitPsfResults = {} + if self.config.doFitPsf: + fitPsfResults = self._fitPsf(stampMI, psfImage) + stamps_fitPsfResults.append(fitPsfResults) + + # Save the stamp if the PSF fit was successful or no fit requested + if fitPsfResults or not self.config.doFitPsf: + stamp = BrightStarStamp( + maskedImage=stampMI, + # TODO: what to do about this PSF? + psf=constantPsf, + wcs=makeModifiedWcs(pixToPolar, wcs, False), + visit=cast(int, dataId["visit"]), + detector=cast(int, dataId["detector"]), + refId=obj["id"], + refMag=obj["mag"], + position=pixCoord, + scale=fitPsfResults.get("scale", None), + scaleErr=fitPsfResults.get("scaleErr", None), + pedestal=fitPsfResults.get("pedestal", None), + pedestalErr=fitPsfResults.get("pedestalErr", None), + pedestalScaleCov=fitPsfResults.get("pedestalScaleCov", None), + xGradient=fitPsfResults.get("xGradient", None), + yGradient=fitPsfResults.get("yGradient", None), + globalReducedChiSquared=fitPsfResults.get("globalReducedChiSquared", None), + globalDegreesOfFreedom=fitPsfResults.get("globalDegreesOfFreedom", None), + psfReducedChiSquared=fitPsfResults.get("psfReducedChiSquared", None), + psfDegreesOfFreedom=fitPsfResults.get("psfDegreesOfFreedom", None), + psfMaskedFluxFrac=fitPsfResults.get("psfMaskedFluxFrac", None), + ) + stamps.append(stamp) + + self.log.info( + "Extracted %i bright star stamp%s. " + "Excluded %i star%s: insufficient area (%i), PSF fit failure (%i).", + len(stamps), + "" if len(stamps) == 1 else "s", + len(refCatBright) - len(stamps), + "" if len(refCatBright) - len(stamps) == 1 else "s", + np.sum(np.array(goodFracs) < self.config.minAreaFraction), + ( + np.sum(np.isnan([x.get("pedestal", np.nan) for x in stamps_fitPsfResults])) + if self.config.doFitPsf + else 0 + ), + ) + brightStarStamps = BrightStarStamps(stamps) + return Struct(brightStarStamps=brightStarStamps) + + def _getRefCatBright(self, refObjLoader: ReferenceObjectLoader, wcs: SkyWcs, bbox: Box2I) -> Table: + """Get a bright star subset of the reference catalog. + + Trim the reference catalog to only those objects within the exposure + bounding box dilated by half the bright star stamp size. + This ensures all stars that overlap the exposure are included. + + Parameters + ---------- + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader` + Loader to find objects within a reference catalog. + wcs : `~lsst.afw.geom.SkyWcs` + World coordinate system. + bbox : `~lsst.geom.Box2I` + Bounding box of the exposure. + + Returns + ------- + refCatBright : `~astropy.table.Table` + Bright star subset of the reference catalog. + """ + dilatedBBox = bbox.dilatedBy(self.paddedStampRadius) + withinExposure = refObjLoader.loadPixelBox(dilatedBBox, wcs, filterName="phot_g_mean") + refCatFull = withinExposure.refCat + fluxField: str = withinExposure.fluxField + + proxFluxRange = sorted(((self.config.excludeMagRange * u.ABmag).to(u.nJy)).to_value()) + brightFluxRange = sorted(((self.config.magRange * u.ABmag).to(u.nJy)).to_value()) + + subsetStars = (refCatFull[fluxField] > np.min((proxFluxRange[0], brightFluxRange[0]))) & ( + refCatFull[fluxField] < np.max((proxFluxRange[1], brightFluxRange[1])) + ) + refCatSubset = Table(refCatFull.extract("id", "coord_ra", "coord_dec", fluxField, where=subsetStars)) + + proxStars = (refCatSubset[fluxField] >= proxFluxRange[0]) & ( + refCatSubset[fluxField] <= proxFluxRange[1] + ) + brightStars = (refCatSubset[fluxField] >= brightFluxRange[0]) & ( + refCatSubset[fluxField] <= brightFluxRange[1] + ) + + coords = SkyCoord(refCatSubset["coord_ra"], refCatSubset["coord_dec"], unit="rad") + excludeArcsecRadius = self.config.excludeArcsecRadius * u.arcsec # type: ignore + refCatBrightIsolated = [] + for coord in cast(Iterable[SkyCoord], coords[brightStars]): + neighbors = coords[proxStars] + seps = coord.separation(neighbors).to(u.arcsec) + tooClose = (seps > 0) & (seps <= excludeArcsecRadius) # not self matched + refCatBrightIsolated.append(not tooClose.any()) + + refCatBright = cast(Table, refCatSubset[brightStars][refCatBrightIsolated]) + + fluxNanojansky = refCatBright[fluxField][:] * u.nJy # type: ignore + refCatBright["mag"] = fluxNanojansky.to(u.ABmag).to_value() # AB magnitudes + + self.log.info( + "Identified %i of %i star%s which satisfy: frame overlap; in the range %s mag; no neighboring " + "stars within %s arcsec.", + len(refCatBright), + len(refCatFull), + "" if len(refCatFull) == 1 else "s", + self.config.magRange, + self.config.excludeArcsecRadius, + ) + + return refCatBright + + def _associateFootprints( + self, inputExposure: ExposureF, pixCoords: list[Point2D], plane: str + ) -> tuple[list[Footprint], dict[int, int]]: + """Associate footprints from a given mask plane with specific objects. + + Footprints from the given mask plane are associated with objects at the + coordinates provided, where possible. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The input exposure with a mask plane. + pixCoords : `list` [`~lsst.geom.Point2D`] + The pixel coordinates of the objects. + plane : `str` + The mask plane used to identify masked pixels. + + Returns + ------- + footprints : `list` [`~lsst.afw.detection.Footprint`] + The footprints from the input exposure. + associations : `dict`[int, int] + Association indices between objects (key) and footprints (value). + """ + detThreshold = Threshold(inputExposure.mask.getPlaneBitMask(plane), Threshold.BITMASK) + footprintSet = FootprintSet(inputExposure.mask, detThreshold) + footprints = footprintSet.getFootprints() + associations = {} + for starIndex, pixCoord in enumerate(pixCoords): + for footprintIndex, footprint in enumerate(footprints): + if footprint.contains(Point2I(pixCoord)): + associations[starIndex] = footprintIndex + break + self.log.debug( + "Associated %i of %i star%s to one each of the %i %s footprint%s.", + len(associations), + len(pixCoords), + "" if len(pixCoords) == 1 else "s", + len(footprints), + plane, + "" if len(footprints) == 1 else "s", + ) + return footprints, associations + + def _setFootprints(self, inputExposure: ExposureF, footprints: list, maskPlane: str): + """Set footprints in a given mask plane. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The input exposure to modify. + footprints : `list` [`~lsst.afw.detection.Footprint`] + The footprints to set in the mask plane. + maskPlane : `str` + The mask plane to set the footprints in. + + Notes + ----- + This method modifies the ``inputExposure`` object in-place. + """ + detThreshold = Threshold(inputExposure.mask.getPlaneBitMask(maskPlane), Threshold.BITMASK) + detThresholdValue = int(detThreshold.getValue()) + footprintSet = FootprintSet(inputExposure.mask, detThreshold) + + # Wipe any existing footprints in the mask plane + inputExposure.mask.clearMaskPlane(int(np.log2(detThresholdValue))) + + # Set the footprints in the mask plane + footprintSet.setFootprints(footprints) + footprintSet.setMask(inputExposure.mask, maskPlane) + + def _fitPsf(self, stampMI: MaskedImageF, psfImage: ImageD | ImageF) -> dict[str, Any]: + """Fit a scaled PSF and a pedestal to each bright star cutout. + + Parameters + ---------- + stampMI : `~lsst.afw.image.MaskedImageF` + The masked image of the bright star cutout. + psfImage : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` + The PSF model to fit. + + Returns + ------- + fitPsfResults : `dict`[`str`, `float`] + The result of the PSF fitting, with keys: + + ``scale`` : `float` + The scale factor. + ``scaleErr`` : `float` + The error on the scale factor. + ``pedestal`` : `float` + The pedestal value. + ``pedestalErr`` : `float` + The error on the pedestal value. + ``pedestalScaleCov`` : `float` + The covariance between the pedestal and scale factor. + ``xGradient`` : `float` + The gradient in the x-direction. + ``yGradient`` : `float` + The gradient in the y-direction. + ``globalReducedChiSquared`` : `float` + The global reduced chi-squared goodness-of-fit. + ``globalDegreesOfFreedom`` : `int` + The global number of degrees of freedom. + ``psfReducedChiSquared`` : `float` + The PSF BBox reduced chi-squared goodness-of-fit. + ``psfDegreesOfFreedom`` : `int` + The PSF BBox number of degrees of freedom. + ``psfMaskedFluxFrac`` : `float` + The fraction of the PSF image flux masked by bad pixels. + """ + badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes) + + # Calculate the fraction of the PSF image flux masked by bad pixels + psfMaskedPixels = ImageF(psfImage.getBBox()) + psfMaskedPixels.array[:, :] = (stampMI.mask[psfImage.getBBox()].array & badMaskBitMask).astype(bool) + # TODO: This is np.float64, else FITS metadata serialization fails + psfMaskedFluxFrac = np.dot(psfImage.array.flat, psfMaskedPixels.array.flat).astype(np.float64) + if psfMaskedFluxFrac > self.config.psfMaskedFluxFracThreshold: + return {} # Handle cases where the PSF image is mostly masked + + # Create a padded version of the input constant PSF image + paddedPsfImage = ImageF(stampMI.getBBox()) + paddedPsfImage[psfImage.getBBox()] = psfImage.convertF() + + # Create consistently masked data + badSpans = SpanSet.fromMask(stampMI.mask, badMaskBitMask) + goodSpans = SpanSet(stampMI.getBBox()).intersectNot(badSpans) + varianceData = goodSpans.flatten(stampMI.variance.array, stampMI.getXY0()) + if self.config.useMedianVariance: + varianceData = np.median(varianceData) + sigmaData = np.sqrt(varianceData) + imageData = goodSpans.flatten(stampMI.image.array, stampMI.getXY0()) # B + imageData /= sigmaData + psfData = goodSpans.flatten(paddedPsfImage.array, paddedPsfImage.getXY0()) + psfData /= sigmaData + + # Fit the PSF scale factor and global pedestal + nData = len(imageData) + coefficientMatrix = np.ones((nData, 4), dtype=float) # A + coefficientMatrix[:, 0] = psfData + coefficientMatrix[:, 1] /= sigmaData + coefficientMatrix[:, 2:] = goodSpans.indices().T + coefficientMatrix[:, 2] /= sigmaData + coefficientMatrix[:, 3] /= sigmaData + try: + solutions, sumSquaredResiduals, *_ = np.linalg.lstsq(coefficientMatrix, imageData, rcond=None) + covarianceMatrix = np.linalg.inv(np.dot(coefficientMatrix.transpose(), coefficientMatrix)) # C + except np.linalg.LinAlgError: + return {} # Handle singular matrix errors + if sumSquaredResiduals.size == 0: + return {} # Handle cases where sum of the squared residuals are empty + scale = solutions[0] + if scale <= 0: + return {} # Handle cases where the PSF scale fit has failed + scaleErr = np.sqrt(covarianceMatrix[0, 0]) + pedestal = solutions[1] + pedestalErr = np.sqrt(covarianceMatrix[1, 1]) + scalePedestalCov = covarianceMatrix[0, 1] + xGradient = solutions[3] + yGradient = solutions[2] + + # Calculate global (whole image) reduced chi-squared + globalChiSquared = np.sum(sumSquaredResiduals) + globalDegreesOfFreedom = nData - 4 + globalReducedChiSquared = globalChiSquared / globalDegreesOfFreedom + + # Calculate PSF BBox reduced chi-squared + psfBBoxGoodSpans = goodSpans.clippedTo(psfImage.getBBox()) + psfBBoxGoodSpansX, psfBBoxGoodSpansY = psfBBoxGoodSpans.indices() + psfBBoxData = psfBBoxGoodSpans.flatten(stampMI.image.array, stampMI.getXY0()) + psfBBoxModel = ( + psfBBoxGoodSpans.flatten(paddedPsfImage.array, stampMI.getXY0()) * scale + + pedestal + + psfBBoxGoodSpansX * xGradient + + psfBBoxGoodSpansY * yGradient + ) + psfBBoxVariance = psfBBoxGoodSpans.flatten(stampMI.variance.array, stampMI.getXY0()) + psfBBoxResiduals = (psfBBoxData - psfBBoxModel) ** 2 / psfBBoxVariance + psfBBoxChiSquared = np.sum(psfBBoxResiduals) + psfBBoxDegreesOfFreedom = len(psfBBoxData) - 4 + psfBBoxReducedChiSquared = psfBBoxChiSquared / psfBBoxDegreesOfFreedom + + return dict( + scale=scale, + scaleErr=scaleErr, + pedestal=pedestal, + pedestalErr=pedestalErr, + xGradient=xGradient, + yGradient=yGradient, + pedestalScaleCov=scalePedestalCov, + globalReducedChiSquared=globalReducedChiSquared, + globalDegreesOfFreedom=globalDegreesOfFreedom, + psfReducedChiSquared=psfBBoxReducedChiSquared, + psfDegreesOfFreedom=psfBBoxDegreesOfFreedom, + psfMaskedFluxFrac=psfMaskedFluxFrac, + ) diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py new file mode 100644 index 000000000..fcc108d45 --- /dev/null +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py @@ -0,0 +1,226 @@ +# This file is part of pipe_tasks. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +"""Stack bright star postage stamp cutouts to produce an extended PSF model.""" + +__all__ = ["BrightStarStackConnections", "BrightStarStackConfig", "BrightStarStackTask"] + +import numpy as np +from lsst.afw.image import ImageF +from lsst.afw.math import StatisticsControl, statisticsStack, stringToStatisticsProperty +from lsst.geom import Box2I, Extent2I, Point2I +from lsst.meas.algorithms import BrightStarStamps +from lsst.pex.config import Field, ListField +from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct +from lsst.pipe.base.connectionTypes import Input, Output +from lsst.utils.timer import timeMethod + +NEIGHBOR_MASK_PLANE = "NEIGHBOR" + + +class BrightStarStackConnections( + PipelineTaskConnections, + dimensions=("instrument", "detector"), +): + """Connections for BrightStarStackTask.""" + + brightStarStamps = Input( + name="brightStarStamps", + storageClass="BrightStarStamps", + doc="Set of preprocessed postage stamp cutouts, each centered on a single bright star.", + dimensions=("visit", "detector"), + multiple=True, + deferLoad=True, + ) + extendedPsf = Output( + name="extendedPsf2", # extendedPsfDetector ??? + storageClass="ImageF", # MaskedImageF + doc="Extended PSF model, built from stacking bright star cutouts.", + dimensions=("band",), + ) + + +class BrightStarStackConfig( + PipelineTaskConfig, + pipelineConnections=BrightStarStackConnections, +): + """Configuration parameters for BrightStarStackTask.""" + + subsetStampNumber = Field[int]( + doc="Number of stamps per subset to generate stacked images for.", + default=20, + ) + globalReducedChiSquaredThreshold = Field[float]( + doc="Threshold for global reduced chi-squared for bright star stamps.", + default=5.0, + ) + psfReducedChiSquaredThreshold = Field[float]( + doc="Threshold for PSF reduced chi-squared for bright star stamps.", + default=50.0, + ) + + badMaskPlanes = ListField[str]( + doc="Mask planes that identify excluded (masked) pixels.", + default=[ + "BAD", + "CR", + "CROSSTALK", + "EDGE", + "NO_DATA", + # "SAT", + # "SUSPECT", + "UNMASKEDNAN", + NEIGHBOR_MASK_PLANE, + ], + ) + stackType = Field[str]( + default="MEANCLIP", + doc="Statistic name to use for stacking (from `~lsst.afw.math.Property`)", + ) + stackNumSigmaClip = Field[float]( + doc="Number of sigma to use for clipping when stacking.", + default=3.0, + ) + stackNumIter = Field[int]( + doc="Number of iterations to use for clipping when stacking.", + default=5, + ) + + +class BrightStarStackTask(PipelineTask): + """Stack bright star postage stamps to produce an extended PSF model.""" + + ConfigClass = BrightStarStackConfig + _DefaultName = "brightStarStack" + config: BrightStarStackConfig + + def __init__(self, initInputs=None, *args, **kwargs): + super().__init__(*args, **kwargs) + + def runQuantum(self, butlerQC, inputRefs, outputRefs): + inputs = butlerQC.get(inputRefs) + output = self.run(**inputs) + butlerQC.put(output, outputRefs) + + def _applyStampFit(self, stamp): + """Apply fitted stamp components to a single bright star stamp.""" + stampMI = stamp.maskedImage + stampBBox = stampMI.getBBox() + xGrid, yGrid = np.meshgrid(stampBBox.getX().arange(), stampBBox.getY().arange()) + xPlane = ImageF((xGrid * stamp.xGradient).astype(np.float32), xy0=stampMI.getXY0()) + yPlane = ImageF((yGrid * stamp.yGradient).astype(np.float32), xy0=stampMI.getXY0()) + stampMI -= stamp.pedestal + stampMI -= xPlane + stampMI -= yPlane + stampMI /= stamp.scale + + @timeMethod + def run( + self, + brightStarStamps: BrightStarStamps, + ): + """Identify bright stars within an exposure using a reference catalog, + extract stamps around each, then preprocess them. + + Bright star preprocessing steps are: shifting, warping and potentially + rotating them to the same pixel grid; computing their annular flux, + and; normalizing them. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The image from which bright star stamps should be extracted. + inputBackground : `~lsst.afw.image.Background` + The background model for the input exposure. + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional + Loader to find objects within a reference catalog. + dataId : `dict` or `~lsst.daf.butler.DataCoordinate` + The dataId of the exposure (including detector) that bright stars + should be extracted from. + + Returns + ------- + brightStarResults : `~lsst.pipe.base.Struct` + Results as a struct with attributes: + + ``brightStarStamps`` + (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`) + """ + stackTypeProperty = stringToStatisticsProperty(self.config.stackType) + statisticsControl = StatisticsControl( + numSigmaClip=self.config.stackNumSigmaClip, + numIter=self.config.stackNumIter, + ) + + subsetStampMIs = [] + tempStampMIs = [] + for stampsDDH in brightStarStamps: + stamps = stampsDDH.get() + for stamp in stamps: + if ( + stamp.globalReducedChiSquared > self.config.globalReducedChiSquaredThreshold + or stamp.psfReducedChiSquared > self.config.psfReducedChiSquaredThreshold + ): + continue + stampMI = stamp.maskedImage + self._applyStampFit(stamp) + tempStampMIs.append(stampMI) + + badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes) + statisticsControl.setAndMask(badMaskBitMask) + + if len(tempStampMIs) == self.config.subsetStampNumber: + subsetStampMIs.append(statisticsStack(tempStampMIs, stackTypeProperty, statisticsControl)) + # TODO: what to do with remaining temp stamps? + tempStampMIs = [] + + # TODO: which stamp mask plane to use here? + badMaskBitMask = subsetStampMIs[0].mask.getPlaneBitMask(self.config.badMaskPlanes) + statisticsControl.setAndMask(badMaskBitMask) + extendedPsfMI = statisticsStack(subsetStampMIs, stackTypeProperty, statisticsControl) + + extendedPsfExtent = extendedPsfMI.getBBox().getDimensions() + extendedPsfOrigin = Point2I(-1 * (extendedPsfExtent.x // 2), -1 * (extendedPsfExtent.y // 2)) + extendedPsfMI.setXY0(extendedPsfOrigin) + + return Struct(extendedPsf=extendedPsfMI.getImage()) + + # stack = [] + # chiStack = [] + # for loop over all groups: + # load up all visits for this detector + # drop all with GOF > thresh + # sigma-clip mean stack the rest + # append to stack + # compute the scatter (MAD/sigma-clipped var, etc) of the rest + # divide by sqrt(var plane), and append to chiStack + # after for-loop, combine images in median stack for final result + # also combine chi-images, save separately + + # idea: run with two different thresholds, and compare the results + + # medianStack = [] + # for loop over all groups: + # load up all visits for this detector + # drop all with GOF > thresh + # median/sigma-clip stack the rest + # append to medianStack + # after for-loop, combine images in median stack for final result diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarSubtract.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarSubtract.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_brightStarCutout.py b/tests/test_brightStarCutout.py new file mode 100644 index 000000000..67b88d02f --- /dev/null +++ b/tests/test_brightStarCutout.py @@ -0,0 +1,102 @@ +# This file is part of pipe_tasks. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import unittest + +import lsst.afw.cameraGeom.testUtils +import lsst.afw.image +import lsst.utils.tests +import numpy as np +from lsst.afw.image import ImageD, ImageF, MaskedImageF +from lsst.afw.math import FixedKernel +from lsst.geom import Point2I +from lsst.meas.algorithms import KernelPsf +from lsst.pipe.tasks.brightStarSubtraction import BrightStarCutoutConfig, BrightStarCutoutTask + + +class BrightStarCutoutTestCase(lsst.utils.tests.TestCase): + def setUp(self): + # Fit values + self.scale = 2.34e5 + self.pedestal = 3210.1 + self.xGradient = 5.432 + self.yGradient = 10.987 + + # Create a pedestal + 2D plane + xCoords = np.linspace(-50, 50, 101) + yCoords = np.linspace(-50, 50, 101) + xPlane, yPlane = np.meshgrid(xCoords, yCoords) + pedestal = np.ones_like(xPlane) * self.pedestal + + # Create a pseudo-PSF + dist_from_center = np.sqrt(xPlane**2 + yPlane**2) + psfArray = np.exp(-dist_from_center / 5) + psfArray /= np.sum(psfArray) + fixedKernel = FixedKernel(ImageD(psfArray)) + self.psf = KernelPsf(fixedKernel) + + # Bring everything together to construct a stamp masked image + stampArray = psfArray * self.scale + pedestal + xPlane * self.xGradient + yPlane * self.yGradient + stampIm = ImageF((stampArray).astype(np.float32)) + stampVa = ImageF(stampIm.getBBox(), 654.321) + self.stampMI = MaskedImageF(image=stampIm, variance=stampVa) + self.stampMI.setXY0(Point2I(-50, -50)) + + # Ensure that all mask planes required by the task are in-place; + # new mask plane entries will be created as necessary + badMaskPlanes = [ + "BAD", + "CR", + "CROSSTALK", + "EDGE", + "NO_DATA", + "SAT", + "SUSPECT", + "UNMASKEDNAN", + "NEIGHBOR", + ] + _ = [self.stampMI.mask.addMaskPlane(mask) for mask in badMaskPlanes] + + def test_fitPsf(self): + """Test the PSF fitting method.""" + brightStarCutoutConfig = BrightStarCutoutConfig() + brightStarCutoutTask = BrightStarCutoutTask(config=brightStarCutoutConfig) + fitPsfResults = brightStarCutoutTask._fitPsf( + self.stampMI, + self.psf, + ) + self.assertAlmostEqual(fitPsfResults["scale"], self.scale, delta=1e-3) + self.assertAlmostEqual(fitPsfResults["pedestal"], self.pedestal, delta=1e-5) + self.assertAlmostEqual(fitPsfResults["xGradient"], self.xGradient, delta=1e-7) + self.assertAlmostEqual(fitPsfResults["yGradient"], self.yGradient, delta=1e-7) + + +def setup_module(module): + lsst.utils.tests.init() + + +class MemoryTestCase(lsst.utils.tests.MemoryTestCase): + pass + + +if __name__ == "__main__": + lsst.utils.tests.init() + unittest.main() From 7bb272bb722933abf464ffda4b1f923969bec4e7 Mon Sep 17 00:00:00 2001 From: "Amir E. Bazkiaei" Date: Fri, 4 Apr 2025 16:05:28 +1100 Subject: [PATCH 2/5] tickets/DM-48377 draft --- .../brightStarSubtraction/brightStarCutout.py | 25 +- .../brightStarSubtraction/brightStarStack.py | 7 +- .../subtractBrightStar.py | 814 ++++++++++++++++++ 3 files changed, 840 insertions(+), 6 deletions(-) create mode 100644 python/lsst/pipe/tasks/brightStarSubtraction/subtractBrightStar.py diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py index a9d58a762..3943de566 100644 --- a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py @@ -60,6 +60,7 @@ from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct from lsst.pipe.base.connectionTypes import Input, Output, PrerequisiteInput from lsst.utils.timer import timeMethod +from copy import deepcopy NEIGHBOR_MASK_PLANE = "NEIGHBOR" @@ -293,6 +294,8 @@ def run( # Restore original subtracted background inputMI = inputExposure.getMaskedImage() inputMI += inputBackground.getImage() + # Amir: the above addition to inputMI, also adds to the inputExposure. + # Amir: but the calibration, three lines later, only is applied to the inputMI. # Set up NEIGHBOR mask plane; associate footprints with stars inputExposure.mask.addMaskPlane(NEIGHBOR_MASK_PLANE) @@ -317,9 +320,11 @@ def run( # Set NEIGHBOR footprints in the mask plane if footprintIndex: neighborFootprints = [fp for i, fp in enumerate(allFootprints) if i != footprintIndex] - self._setFootprints(inputExposure, neighborFootprints, NEIGHBOR_MASK_PLANE) + # self._setFootprints(inputExposure, neighborFootprints, NEIGHBOR_MASK_PLANE) + self._setFootprints(inputMI, neighborFootprints, NEIGHBOR_MASK_PLANE) else: - self._setFootprints(inputExposure, allFootprints, NEIGHBOR_MASK_PLANE) + # self._setFootprints(inputExposure, allFootprints, NEIGHBOR_MASK_PLANE) + self._setFootprints(inputMI, allFootprints, NEIGHBOR_MASK_PLANE) # Define linear shifting to recenter stamps coordFocalPlaneTan = pixToFocalPlaneTan.applyForward(pixCoord) # center of warped star @@ -329,7 +334,8 @@ def run( pixToPolar = pixToFocalPlaneTan.then(shift).then(rotation) # Apply the warp to the star stamp (in-place) - warpImage(stampMI, inputExposure.maskedImage, pixToPolar, warpingControl) + # warpImage(stampMI, inputExposure.maskedImage, pixToPolar, warpingControl) + warpImage(stampMI, inputMI, pixToPolar, warpingControl) # Trim to the base stamp size, check mask coverage, update metadata stampMI = stampMI[self.stampBBox] @@ -579,6 +585,7 @@ def _fitPsf(self, stampMI: MaskedImageF, psfImage: ImageD | ImageF) -> dict[str, psfMaskedPixels = ImageF(psfImage.getBBox()) psfMaskedPixels.array[:, :] = (stampMI.mask[psfImage.getBBox()].array & badMaskBitMask).astype(bool) # TODO: This is np.float64, else FITS metadata serialization fails + # Amir: the following tries to find the fraction of the psf flux in the masked area of the psf image. psfMaskedFluxFrac = np.dot(psfImage.array.flat, psfMaskedPixels.array.flat).astype(np.float64) if psfMaskedFluxFrac > self.config.psfMaskedFluxFracThreshold: return {} # Handle cases where the PSF image is mostly masked @@ -587,8 +594,10 @@ def _fitPsf(self, stampMI: MaskedImageF, psfImage: ImageD | ImageF) -> dict[str, paddedPsfImage = ImageF(stampMI.getBBox()) paddedPsfImage[psfImage.getBBox()] = psfImage.convertF() + mask = self.add_psf_mask(paddedPsfImage, stampMI) # Create consistently masked data - badSpans = SpanSet.fromMask(stampMI.mask, badMaskBitMask) + # badSpans = SpanSet.fromMask(stampMI.mask, badMaskBitMask) + badSpans = SpanSet.fromMask(mask, badMaskBitMask) goodSpans = SpanSet(stampMI.getBBox()).intersectNot(badSpans) varianceData = goodSpans.flatten(stampMI.variance.array, stampMI.getXY0()) if self.config.useMedianVariance: @@ -598,7 +607,6 @@ def _fitPsf(self, stampMI: MaskedImageF, psfImage: ImageD | ImageF) -> dict[str, imageData /= sigmaData psfData = goodSpans.flatten(paddedPsfImage.array, paddedPsfImage.getXY0()) psfData /= sigmaData - # Fit the PSF scale factor and global pedestal nData = len(imageData) coefficientMatrix = np.ones((nData, 4), dtype=float) # A @@ -659,3 +667,10 @@ def _fitPsf(self, stampMI: MaskedImageF, psfImage: ImageD | ImageF) -> dict[str, psfDegreesOfFreedom=psfBBoxDegreesOfFreedom, psfMaskedFluxFrac=psfMaskedFluxFrac, ) + + def add_psf_mask(self, psfImage, stampMI): + cond = np.isnan(psfImage.array) + cond |= psfImage.array < 0 + mask = deepcopy(stampMI.mask) + mask.array[cond] = np.bitwise_or(mask.array[cond], 1) + return mask diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py index fcc108d45..2aa40c769 100644 --- a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py @@ -24,7 +24,7 @@ __all__ = ["BrightStarStackConnections", "BrightStarStackConfig", "BrightStarStackTask"] import numpy as np -from lsst.afw.image import ImageF +from lsst.afw.image import ImageF, MaskedImageF from lsst.afw.math import StatisticsControl, statisticsStack, stringToStatisticsProperty from lsst.geom import Box2I, Extent2I, Point2I from lsst.meas.algorithms import BrightStarStamps @@ -187,12 +187,16 @@ def run( badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes) statisticsControl.setAndMask(badMaskBitMask) + # Amir: In case the total number of stamps is less than 20, the following will result in an + # empty subsetStampMIs list. if len(tempStampMIs) == self.config.subsetStampNumber: subsetStampMIs.append(statisticsStack(tempStampMIs, stackTypeProperty, statisticsControl)) # TODO: what to do with remaining temp stamps? tempStampMIs = [] # TODO: which stamp mask plane to use here? + # TODO: Amir: there might be cases where subsetStampMIs is an empty list. What do we want to do then? + # Currently, we get an "IndexError: list index out of range" badMaskBitMask = subsetStampMIs[0].mask.getPlaneBitMask(self.config.badMaskPlanes) statisticsControl.setAndMask(badMaskBitMask) extendedPsfMI = statisticsStack(subsetStampMIs, stackTypeProperty, statisticsControl) @@ -200,6 +204,7 @@ def run( extendedPsfExtent = extendedPsfMI.getBBox().getDimensions() extendedPsfOrigin = Point2I(-1 * (extendedPsfExtent.x // 2), -1 * (extendedPsfExtent.y // 2)) extendedPsfMI.setXY0(extendedPsfOrigin) + # return Struct(extendedPsf=[extendedPsfMI]) return Struct(extendedPsf=extendedPsfMI.getImage()) diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/subtractBrightStar.py b/python/lsst/pipe/tasks/brightStarSubtraction/subtractBrightStar.py new file mode 100644 index 000000000..6cb224f4a --- /dev/null +++ b/python/lsst/pipe/tasks/brightStarSubtraction/subtractBrightStar.py @@ -0,0 +1,814 @@ +# This file is part of pipe_tasks. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +"""Retrieve extended PSF model and subtract bright stars at visit level.""" + +__all__ = ["BrightStarSubtractConnections", "BrightStarSubtractConfig", "BrightStarSubtractTask"] + +import logging +from typing import Any +import astropy.units as u +import numpy as np +from astropy.coordinates import SkyCoord +from astropy.table import Table, Column +from lsst.afw.cameraGeom import FIELD_ANGLE, PIXELS, TAN_PIXELS +from lsst.afw.detection import Footprint, FootprintSet, Threshold +from lsst.afw.geom import SkyWcs, SpanSet, makeModifiedWcs +from lsst.afw.geom.transformFactory import makeTransform +from lsst.afw.image import ExposureF, ImageD, ImageF, MaskedImageF +from lsst.afw.math import BackgroundList, FixedKernel, WarpingControl, warpImage +from lsst.daf.butler import DataCoordinate +from lsst.geom import ( + AffineTransform, + Box2I, + Extent2D, + Extent2I, + Point2D, + Point2I, + SpherePoint, + arcseconds, + floor, + radians, +) +from lsst.meas.algorithms import ( + LoadReferenceObjectsConfig, + ReferenceObjectLoader, +) +from lsst.pex.config import ChoiceField, ConfigField, Field, ListField +from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct +from lsst.pipe.base.connectionTypes import Input, Output, PrerequisiteInput +from lsst.utils.timer import timeMethod +from copy import deepcopy + +NEIGHBOR_MASK_PLANE = "NEIGHBOR" + +logger = logging.getLogger(__name__) + + +class BrightStarSubtractConnections( + PipelineTaskConnections, + dimensions=("instrument", "visit", "detector"), + defaultTemplates={ + # "outputExposureName": "brightStar_subtracted", + "outputExposureName": "postISRCCD", + "outputBackgroundName": "brightStars", + "badStampsName": "brightStars", + }, +): + inputCalexp = Input( + name="calexp", + storageClass="ExposureF", + doc="Background-subtracted input exposure from which to extract bright star stamp cutouts.", + dimensions=("visit", "detector"), + ) + inputBackground = Input( + name="calexpBackground", + storageClass="Background", + doc="Background model for the input exposure, to be added back on during processing.", + dimensions=("visit", "detector"), + ) + inputExposure = Input( + doc="Input exposure from which to subtract bright star stamps.", + name="postISRCCD", + storageClass="Exposure", + dimensions=( + "exposure", + "detector", + ), + ) + inputExtendedPsf = Input( + name="extendedPsf2", # extendedPsfDetector ??? + storageClass="ImageF", # MaskedImageF + doc="Extended PSF model, built from stacking bright star cutouts.", + dimensions=("band",), + ) + refCat = PrerequisiteInput( + doc="Reference catalog that contains bright star positions", + name="gaia_dr3_20230707", + storageClass="SimpleCatalog", + dimensions=("skypix",), + multiple=True, + deferLoad=True, + ) + # outputBadStamps = Output( + # doc="The stamps that are not normalized and consequently not subtracted from the exposure.", + # name="{badStampsName}_unsubtracted_stamps", + # storageClass="BrightStarStamps", + # dimensions=( + # "visit", + # "detector", + # ), + # ) + + outputExposure = Output( + doc="Exposure with bright stars subtracted.", + name="{outputExposureName}_subtracted", + storageClass="ExposureF", + dimensions=( + "exposure", + "detector", + ), + ) + outputBackgroundExposure = Output( + doc="Exposure containing only the modelled bright stars.", + name="{outputBackgroundName}_background", + storageClass="ExposureF", + dimensions=( + "visit", + "detector", + ), + ) + # scaledModels = Output( + # doc="Stamps containing models scaled to the level of stars", + # name="scaledModels", + # storageClass="BrightStarStamps", + # dimensions=( + # "visit", + # "detector", + # ), + # ) + + +class BrightStarSubtractConfig(PipelineTaskConfig, pipelineConnections=BrightStarSubtractConnections): + """Configuration parameters for BrightStarSubtractTask""" + + doWriteSubtractor = Field[bool]( + doc="Should an exposure containing all bright star models be written to disk?", + default=True, + ) + doWriteSubtractedExposure = Field[bool]( + doc="Should an exposure with bright stars subtracted be written to disk?", + default=True, + ) + magLimit = Field[float]( + doc="Magnitude limit, in Gaia G; all stars brighter than this value will be subtracted", + default=18, + ) + minValidAnnulusFraction = Field[float]( + doc="Minimum number of valid pixels that must fall within the annulus for the bright star to be " + "saved for subsequent generation of a PSF.", + default=0.0, + ) + numSigmaClip = Field[float]( + doc="Sigma for outlier rejection; ignored if annularFluxStatistic != 'MEANCLIP'.", + default=4, + ) + numIter = Field[int]( + doc="Number of iterations of outlier rejection; ignored if annularFluxStatistic != 'MEANCLIP'.", + default=3, + ) + warpingKernelName = ChoiceField[str]( + doc="Warping kernel", + default="lanczos5", + allowed={ + "bilinear": "bilinear interpolation", + "lanczos3": "Lanczos kernel of order 3", + "lanczos4": "Lanczos kernel of order 4", + "lanczos5": "Lanczos kernel of order 5", + "lanczos6": "Lanczos kernel of order 6", + "lanczos7": "Lanczos kernel of order 7", + }, + ) + scalingType = ChoiceField[str]( + doc="How the model should be scaled to each bright star; implemented options are " + "`annularFlux` to reuse the annular flux of each stamp, or `leastSquare` to perform " + "least square fitting on each pixel with no bad mask plane set.", + default="leastSquare", + allowed={ + "annularFlux": "reuse BrightStarStamp annular flux measurement", + "leastSquare": "find least square scaling factor", + }, + ) + annularFluxStatistic = ChoiceField[str]( + doc="Type of statistic to use to compute annular flux.", + default="MEANCLIP", + allowed={ + "MEAN": "mean", + "MEDIAN": "median", + "MEANCLIP": "clipped mean", + }, + ) + badMaskPlanes = ListField[str]( + doc="Mask planes that, if set, lead to associated pixels not being included in the computation of " + "the scaling factor (`BAD` should always be included). Ignored if scalingType is `annularFlux`, " + "as the stamps are expected to already be normalized.", + # Note that `BAD` should always be included, as secondary detected + # sources (i.e., detected sources other than the primary source of + # interest) also get set to `BAD`. + # Lee: find out the value of "BAD" and set the nan values into that number in the mask plane(?) + default=("BAD", "CR", "CROSSTALK", "EDGE", "NO_DATA", "SAT", "SUSPECT", "UNMASKEDNAN"), + ) + subtractionBox = ListField[int]( + doc="Size of the stamps to be extracted, in pixels.", + default=(250, 250), + ) + subtractionBoxBuffer = Field[float]( + doc=( + "'Buffer' (multiplicative) factor to be applied to determine the size of the stamp the " + "processed stars will be saved in. This is also the size of the extended PSF model. The buffer " + "region is masked and contain no data and subtractionBox determines the region where contains " + "the data." + ), + default=1.1, + ) + doApplySkyCorr = Field[bool]( + doc="Apply full focal plane sky correction before extracting stars?", + default=True, + ) + min_iterations = Field[int]( + doc="Minimum number of iterations to complete before evaluating changes in each iteration.", + default=3, + ) + refObjLoader = ConfigField[LoadReferenceObjectsConfig]( + doc="Reference object loader for astrometric calibration.", + ) + maskWarpingKernelName = ChoiceField[str]( + doc="Warping kernel for mask.", + default="bilinear", + allowed={ + "bilinear": "bilinear interpolation", + "lanczos3": "Lanczos kernel of order 3", + "lanczos4": "Lanczos kernel of order 4", + "lanczos5": "Lanczos kernel of order 5", + }, + ) + # Cutout geometry + stampSize = ListField[int]( + doc="Size of the stamps to be extracted, in pixels.", + default=(251, 251), + ) + stampSizePadding = Field[float]( + doc="Multiplicative factor applied to the cutout stamp size, to guard against post-warp data loss.", + default=1.1, + ) + # Star selection + magRange = ListField[float]( + doc="Magnitude range in Gaia G. Cutouts will be made for all stars in this range.", + default=[0, 18], + ) + # excludeArcsecRadius = Field[float]( + # doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeArcsecRadius`` are not used.", + # default=5, + # ) + # excludeMagRange = ListField[float]( + # doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeArcsecRadius`` are not used.", + # default=[0, 20], + # ) + minAreaFraction = Field[float]( + doc="Minimum fraction of the stamp area, post-masking, that must remain for a cutout to be retained.", + default=0.1, + ) + # PSF Fitting + # useExtendedPsf = Field[bool]( + # doc="Use the extended PSF model to normalize bright star cutouts.", + # default=False, + # ) + # doFitPsf = Field[bool]( + # doc="Fit a scaled PSF and a pedestal to each bright star cutout.", + # default=True, + # ) + useMedianVariance = Field[bool]( + doc="Use the median of the variance plane for PSF fitting.", + default=False, + ) + psfMaskedFluxFracThreshold = Field[float]( + doc="Maximum allowed fraction of masked PSF flux for PSF fitting to occur.", + default=0.97, + ) + + # # Misc + # loadReferenceObjectsConfig = ConfigField[LoadReferenceObjectsConfig]( + # doc="Reference object loader for astrometric calibration.", + # ) + + +class BrightStarSubtractTask(PipelineTask): + """Use an extended PSF model to subtract bright stars from a calibrated + exposure (i.e. at single-visit level). + + This task uses both a set of bright star stamps produced by + `~lsst.pipe.tasks.processBrightStars.ProcessBrightStarsTask` + and an extended PSF model produced by + `~lsst.pipe.tasks.extended_psf.MeasureExtendedPsfTask`. + """ + + ConfigClass = BrightStarSubtractConfig + _DefaultName = "subtractBrightStars" + + def __init__(self, *args, **kwargs): + # super().__init__(*args, **kwargs) + # # Placeholders to set up Statistics if scalingType is leastSquare. + # self.statsControl, self.statsFlag = None, None + # # Warping control; only contains shiftingALg provided in config. + + super().__init__(*args, **kwargs) + stampSize = Extent2D(*self.config.stampSize.list()) + stampRadius = floor(stampSize / 2) + self.stampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy(stampRadius) + paddedStampSize = stampSize #* self.config.stampSizePadding + self.paddedStampRadius = floor(paddedStampSize / 2) + self.paddedStampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy( + self.paddedStampRadius + ) + + def runQuantum(self, butlerQC, inputRefs, outputRefs): + # Docstring inherited. + inputs = butlerQC.get(inputRefs) + dataId = butlerQC.quantum.dataId + refObjLoader = ReferenceObjectLoader( + dataIds=[ref.datasetRef.dataId for ref in inputRefs.refCat], + refCats=inputs.pop("refCat"), + name=self.config.connections.refCat, + config=self.config.refObjLoader, + ) + # TODO: include the un-subtracted star here! + subtractor = self.run(**inputs, dataId=dataId, refObjLoader=refObjLoader) + if self.config.doWriteSubtractedExposure: + outputExposure = inputs["inputExposure"].clone() + outputExposure.image -= subtractor.image + else: + outputExposure = None + outputBackgroundExposure = ExposureF(subtractor) if self.config.doWriteSubtractor else None + output = Struct( + outputExposure=outputExposure, + outputBackgroundExposure=outputBackgroundExposure, + ) + butlerQC.put(output, outputRefs) + + @timeMethod + def run( + self, + inputExposure: ExposureF, + inputCalexp: ExposureF, + inputBackground: BackgroundList, + # inputBrightStarStamps, #next plan is to use stamps for pedestal and gradients? + inputExtendedPsf: ImageF, + dataId: dict[str, Any] | DataCoordinate, + # inputBackground: BackgroundList, + refObjLoader: ReferenceObjectLoader, + ): + """Identify bright stars within an exposure using a reference catalog, + extract stamps around each, warp/shift stamps onto a common frame and + then optionally fit a PSF plus plane model. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The background-subtracted image to extract bright star stamps. + inputBackground : `~lsst.afw.math.BackgroundList` + The background model associated with the input exposure. + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional + Loader to find objects within a reference catalog. + dataId : `dict` or `~lsst.daf.butler.DataCoordinate` + The dataId of the exposure that bright stars are extracted from. + Both 'visit' and 'detector' will be persisted in the output data. + + Returns + ------- + brightStarResults : `~lsst.pipe.base.Struct` + Results as a struct with attributes: + + ``brightStarStamps`` + (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`) + """ + wcs = inputCalexp.getWcs() + bbox = inputCalexp.getBBox() + warpingControl = WarpingControl(self.config.warpingKernelName, self.config.maskWarpingKernelName) + + refCatBright = self._getRefCatBright(refObjLoader, wcs, bbox) + refCatBright.sort("mag") + zipRaDec = zip(refCatBright["coord_ra"] * radians, refCatBright["coord_dec"] * radians) + spherePoints = [SpherePoint(ra, dec) for ra, dec in zipRaDec] + pixCoords = wcs.skyToPixel(spherePoints) + + inputFixed = inputCalexp.getMaskedImage() + inputFixed += inputBackground.getImage() + inputCalexp.mask.addMaskPlane(NEIGHBOR_MASK_PLANE) + allFootprints, associations = self._associateFootprints(inputCalexp, pixCoords, plane="DETECTED") + + subtractorExp = ExposureF(bbox=bbox) + templateSubtractor = subtractorExp.maskedImage + + detector = inputCalexp.detector + pixelScale = wcs.getPixelScale().asArcseconds() * arcseconds + pixToTan = detector.getTransform(PIXELS, TAN_PIXELS) + pixToFocalPlaneTan = detector.getTransform(PIXELS, FIELD_ANGLE).then( + makeTransform(AffineTransform.makeScaling(1 / pixelScale.asRadians())) + ) + + self.warpedDataDict = {} + removalIndices = [] + for j in range(self.config.min_iterations): + scaleList = [] + for starIndex, (obj, pixCoord) in enumerate(zip(refCatBright, pixCoords)): # type: ignore + inputMI = deepcopy(inputFixed) + restSubtractor = deepcopy(templateSubtractor) + myNumber = 0 + for key in self.warpedDataDict.keys(): + if self.warpedDataDict[key]["subtractor"] is not None and key != obj['id']: + restSubtractor.image += self.warpedDataDict[key]["subtractor"].image + myNumber += 1 + self.log.debug(f"Number of stars subtracted before finding the scale factor for {obj['id']}: ", myNumber) + inputMI.image -= restSubtractor.image + + footprintIndex = associations.get(starIndex, None) + + if footprintIndex: + neighborFootprints = [fp for i, fp in enumerate(allFootprints) if i != footprintIndex] + self._setFootprints(inputMI, neighborFootprints, NEIGHBOR_MASK_PLANE) + else: + self._setFootprints(inputMI, allFootprints, NEIGHBOR_MASK_PLANE) + # Define linear shifting to recenter stamps + coordFocalPlaneTan = pixToFocalPlaneTan.applyForward(pixCoord) # center of warped star + shift = makeTransform(AffineTransform(Point2D(0, 0) - coordFocalPlaneTan)) + angle = np.arctan2(coordFocalPlaneTan.getY(), coordFocalPlaneTan.getX()) * radians + rotation = makeTransform(AffineTransform.makeRotation(-angle)) + pixToPolar = pixToFocalPlaneTan.then(shift).then(rotation) + rawStamp= self._getCutout(inputExposure=inputMI, coordPix=pixCoord, stampSize=self.config.stampSize.list()) + if rawStamp is None: + self.log.debug(f"No stamp for star with refID {obj['id']}") + removalIndices.append(starIndex) + continue + warpedStamp = self._warpRawStamp(obj["id"], obj["mag"], rawStamp, warpingControl, pixToTan, pixCoord) + warpedModel = ImageF(warpedStamp.getBBox()) + inputExtendedPsfGeneral = deepcopy(inputExtendedPsf) + good_pixels = warpImage(warpedModel, inputExtendedPsfGeneral, pixToPolar.inverted(), warpingControl) + self.warpedDataDict[obj["id"]] = {"stamp": warpedStamp, "model": warpedModel, "starIndex": starIndex, "pixCoord": pixCoord} + if j == 0: + self.warpedDataDict[obj["id"]]["scale"] = None + self.warpedDataDict[obj["id"]]["subtractor"] = None + fitPsfResults = {} + fitPsfResults = self._fitPsf( warpedStamp, warpedModel) + if fitPsfResults: + scaleList.append(fitPsfResults["scale"]) + self.warpedDataDict[obj["id"]]["scale"] = fitPsfResults["scale"] + + + cond = np.isnan(warpedModel.array) + warpedModel.array[cond] = 0 + warpedModel.array *= fitPsfResults["scale"] + overlapBBox = Box2I(warpedStamp.getBBox()) + overlapBBox.clip(inputCalexp.getBBox()) + + subtractor = deepcopy(templateSubtractor) + subtractor[overlapBBox] += warpedModel[overlapBBox] + self.warpedDataDict[obj["id"]]["subtractor"] = subtractor + + + else: + scaleList.append(np.nan) + if j == 0: + refCatBright.remove_rows(removalIndices) + updatedPixCoords = [item for i, item in enumerate(pixCoords) if i not in removalIndices] + pixCoords = updatedPixCoords + new_scale_column = Column(scaleList, name=f'scale_0{j}') + # The following is handy when developing, not sure if we want to do that in the final version! + refCatBright.add_columns([new_scale_column]) + + subtractor = deepcopy(templateSubtractor) + for key in self.warpedDataDict.keys(): + if self.warpedDataDict[key]["scale"] is not None: + subtractor.image.array += self.warpedDataDict[key]["subtractor"].image.array + return subtractor + + def _getRefCatBright(self, refObjLoader: ReferenceObjectLoader, wcs: SkyWcs, bbox: Box2I) -> Table: + """Get a bright star subset of the reference catalog. + + Trim the reference catalog to only those objects within the exposure + bounding box dilated by half the bright star stamp size. + This ensures all stars that overlap the exposure are included. + + Parameters + ---------- + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader` + Loader to find objects within a reference catalog. + wcs : `~lsst.afw.geom.SkyWcs` + World coordinate system. + bbox : `~lsst.geom.Box2I` + Bounding box of the exposure. + + Returns + ------- + refCatBright : `~astropy.table.Table` + Bright star subset of the reference catalog. + """ + dilatedBBox = bbox.dilatedBy(self.paddedStampRadius) + withinExposure = refObjLoader.loadPixelBox(dilatedBBox, wcs, filterName="phot_g_mean") + refCatFull = withinExposure.refCat + fluxField: str = withinExposure.fluxField + + brightFluxRange = sorted(((self.config.magRange * u.ABmag).to(u.nJy)).to_value()) + + subsetStars = (refCatFull[fluxField] > brightFluxRange[0]) & ( + refCatFull[fluxField] < brightFluxRange[1] + ) + refCatSubset = Table(refCatFull.extract("id", "coord_ra", "coord_dec", fluxField, where=subsetStars)) + fluxNanojansky = refCatSubset[fluxField][:] * u.nJy # type: ignore + refCatSubset["mag"] = fluxNanojansky.to(u.ABmag).to_value() # AB magnitudes + return refCatSubset + + def _associateFootprints( + self, inputExposure: ExposureF, pixCoords: list[Point2D], plane: str + ) -> tuple[list[Footprint], dict[int, int]]: + """Associate footprints from a given mask plane with specific objects. + + Footprints from the given mask plane are associated with objects at the + coordinates provided, where possible. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The input exposure with a mask plane. + pixCoords : `list` [`~lsst.geom.Point2D`] + The pixel coordinates of the objects. + plane : `str` + The mask plane used to identify masked pixels. + + Returns + ------- + footprints : `list` [`~lsst.afw.detection.Footprint`] + The footprints from the input exposure. + associations : `dict`[int, int] + Association indices between objects (key) and footprints (value). + """ + detThreshold = Threshold(inputExposure.mask.getPlaneBitMask(plane), Threshold.BITMASK) + footprintSet = FootprintSet(inputExposure.mask, detThreshold) + footprints = footprintSet.getFootprints() + associations = {} + for starIndex, pixCoord in enumerate(pixCoords): + for footprintIndex, footprint in enumerate(footprints): + if footprint.contains(Point2I(pixCoord)): + associations[starIndex] = footprintIndex + break + self.log.debug( + "Associated %i of %i star%s to one each of the %i %s footprint%s.", + len(associations), + len(pixCoords), + "" if len(pixCoords) == 1 else "s", + len(footprints), + plane, + "" if len(footprints) == 1 else "s", + ) + return footprints, associations + + def _setFootprints(self, inputExposure: ExposureF, footprints: list, maskPlane: str): + """Set footprints in a given mask plane. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The input exposure to modify. + footprints : `list` [`~lsst.afw.detection.Footprint`] + The footprints to set in the mask plane. + maskPlane : `str` + The mask plane to set the footprints in. + + Notes + ----- + This method modifies the ``inputExposure`` object in-place. + """ + detThreshold = Threshold(inputExposure.mask.getPlaneBitMask(maskPlane), Threshold.BITMASK) + detThresholdValue = int(detThreshold.getValue()) + footprintSet = FootprintSet(inputExposure.mask, detThreshold) + + # Wipe any existing footprints in the mask plane + inputExposure.mask.clearMaskPlane(int(np.log2(detThresholdValue))) + + # Set the footprints in the mask plane + footprintSet.setFootprints(footprints) + footprintSet.setMask(inputExposure.mask, maskPlane) + + def _fitPsf(self, stampMI: MaskedImageF, psfImage: ImageD | ImageF) -> dict[str, Any]: + """Fit a scaled PSF and a pedestal to each bright star cutout. + + Parameters + ---------- + stampMI : `~lsst.afw.image.MaskedImageF` + The masked image of the bright star cutout. + psfImage : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` + The PSF model to fit. + + Returns + ------- + fitPsfResults : `dict`[`str`, `float`] + The result of the PSF fitting, with keys: + + ``scale`` : `float` + The scale factor. + ``scaleErr`` : `float` + The error on the scale factor. + ``pedestal`` : `float` + The pedestal value. + ``pedestalErr`` : `float` + The error on the pedestal value. + ``pedestalScaleCov`` : `float` + The covariance between the pedestal and scale factor. + ``xGradient`` : `float` + The gradient in the x-direction. + ``yGradient`` : `float` + The gradient in the y-direction. + ``globalReducedChiSquared`` : `float` + The global reduced chi-squared goodness-of-fit. + ``globalDegreesOfFreedom`` : `int` + The global number of degrees of freedom. + ``psfReducedChiSquared`` : `float` + The PSF BBox reduced chi-squared goodness-of-fit. + ``psfDegreesOfFreedom`` : `int` + The PSF BBox number of degrees of freedom. + ``psfMaskedFluxFrac`` : `float` + The fraction of the PSF image flux masked by bad pixels. + """ + badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes) + + # Calculate the fraction of the PSF image flux masked by bad pixels + psfMaskedPixels = ImageF(psfImage.getBBox()) + psfMaskedPixels.array[:, :] = (stampMI.mask[psfImage.getBBox()].array & badMaskBitMask).astype(bool) + # TODO: This is np.float64, else FITS metadata serialization fails + # Amir: what do we want to do for subtraction? we do not have the luxury of removing the star from the process here! + psfMaskedFluxFrac = np.dot(psfImage.array.flat, psfMaskedPixels.array.flat).astype(np.float64) + # if psfMaskedFluxFrac > self.config.psfMaskedFluxFracThreshold: + # return {} # Handle cases where the PSF image is mostly masked + + # Create a padded version of the input constant PSF image + paddedPsfImage = ImageF(stampMI.getBBox()) + paddedPsfImage[psfImage.getBBox()] = psfImage.convertF() + + # Create consistently masked data + mask = self.add_psf_mask(psfImage, stampMI) + badSpans = SpanSet.fromMask(mask, badMaskBitMask) + goodSpans = SpanSet(stampMI.getBBox()).intersectNot(badSpans) + varianceData = goodSpans.flatten(stampMI.variance.array, stampMI.getXY0()) + if self.config.useMedianVariance: + varianceData = np.median(varianceData) + sigmaData = np.sqrt(varianceData) + imageData = goodSpans.flatten(stampMI.image.array, stampMI.getXY0()) # B + imageData /= sigmaData + psfData = goodSpans.flatten(paddedPsfImage.array, paddedPsfImage.getXY0()) + psfData /= sigmaData + + # Fit the PSF scale factor and global pedestal + nData = len(imageData) + coefficientMatrix = np.ones((nData, 4), dtype=float) # A + coefficientMatrix[:, 0] = psfData + coefficientMatrix[:, 1] /= sigmaData + coefficientMatrix[:, 2:] = goodSpans.indices().T + coefficientMatrix[:, 2] /= sigmaData + coefficientMatrix[:, 3] /= sigmaData + try: + solutions, sumSquaredResiduals, *_ = np.linalg.lstsq(coefficientMatrix, imageData, rcond=None) + covarianceMatrix = np.linalg.inv(np.dot(coefficientMatrix.transpose(), coefficientMatrix)) # C + except np.linalg.LinAlgError: + return {} # Handle singular matrix errors + if sumSquaredResiduals.size == 0: + return {} + scale = solutions[0] + if scale <= 0: + return {} # Handle cases where the PSF scale fit has failed + scaleErr = np.sqrt(covarianceMatrix[0, 0]) + pedestal = solutions[1] + pedestalErr = np.sqrt(covarianceMatrix[1, 1]) + scalePedestalCov = covarianceMatrix[0, 1] + xGradient = solutions[3] + yGradient = solutions[2] + + # Calculate global (whole image) reduced chi-squared + globalChiSquared = np.sum(sumSquaredResiduals) + globalDegreesOfFreedom = nData - 4 + globalReducedChiSquared = globalChiSquared / globalDegreesOfFreedom + + # Calculate PSF BBox reduced chi-squared + psfBBoxGoodSpans = goodSpans.clippedTo(psfImage.getBBox()) + psfBBoxGoodSpansX, psfBBoxGoodSpansY = psfBBoxGoodSpans.indices() + psfBBoxData = psfBBoxGoodSpans.flatten(stampMI.image.array, stampMI.getXY0()) + psfBBoxModel = ( + psfBBoxGoodSpans.flatten(paddedPsfImage.array, stampMI.getXY0()) * scale + + pedestal + + psfBBoxGoodSpansX * xGradient + + psfBBoxGoodSpansY * yGradient + ) + psfBBoxVariance = psfBBoxGoodSpans.flatten(stampMI.variance.array, stampMI.getXY0()) + psfBBoxResiduals = (psfBBoxData - psfBBoxModel) ** 2 / psfBBoxVariance + psfBBoxChiSquared = np.sum(psfBBoxResiduals) + psfBBoxDegreesOfFreedom = len(psfBBoxData) - 4 + psfBBoxReducedChiSquared = psfBBoxChiSquared / psfBBoxDegreesOfFreedom + + return dict( + scale=scale, + scaleErr=scaleErr, + pedestal=pedestal, + pedestalErr=pedestalErr, + xGradient=xGradient, + yGradient=yGradient, + pedestalScaleCov=scalePedestalCov, + globalReducedChiSquared=globalReducedChiSquared, + globalDegreesOfFreedom=globalDegreesOfFreedom, + psfReducedChiSquared=psfBBoxReducedChiSquared, + psfDegreesOfFreedom=psfBBoxDegreesOfFreedom, + psfMaskedFluxFrac=psfMaskedFluxFrac, + ) + + def add_psf_mask(self, psfImage, stampMI): + """Add psf frame mask into the stamp's mask. + + Args: + psfImage (`~lsst.afw.image.ImageF`): PSF data + stampMI (`~lsst.afw.image.MaskedImageF`): the stamp of the star that being fitted. + + Returns: + `~lsst.afw.image.MaskX``: the mask frame containing the stamp plus the psf model mask. + """ + cond = np.isnan(psfImage.array) + cond |= psfImage.array < 0 + mask = deepcopy(stampMI.mask) + mask.array[cond] = np.bitwise_or(mask.array[cond], 1) + return mask + + + def _getCutout(self, inputExposure, coordPix: Point2D, stampSize: list[int]): + """Get a cutout from an input exposure, handling edge cases. + + Generate a cutout from an input exposure centered on a given position + and with a given size. + If any part of the cutout is outside the input exposure bounding box, + the cutout is padded with NaNs. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The image to extract bright star stamps from. + coordPix : `~lsst.geom.Point2D` + Center of the cutout in pixel space. + stampSize : `list` [`int`] + Size of the cutout, in pixels. + + Returns + ------- + stamp : `~lsst.afw.image.ExposureF` or `None` + The cutout, or `None` if the cutout is entirely outside the input + exposure bounding box. + + Notes + ----- + This method is a short-term workaround until DM-40042 is implemented. + At that point, it should be replaced by a call to the Exposure method + ``getCutout``, which will handle edge cases automatically. + """ + corner = Point2I(np.array(coordPix) - np.array(stampSize) / 2) + dimensions = Extent2I(stampSize) + stampBBox = Box2I(corner, dimensions) + overlapBBox = Box2I(stampBBox) + overlapBBox.clip(inputExposure.getBBox()) + if overlapBBox.getArea() > 0: + # Create full-sized stamp with pixels initially flagged as NO_DATA. + stamp = ExposureF(bbox=stampBBox) + stamp.image[:] = np.nan + stamp.mask.set(inputExposure.mask.getPlaneBitMask("NO_DATA")) + # # Restore pixels which overlap the input exposure. + overlap = inputExposure.Factory(inputExposure, overlapBBox) + stamp.maskedImage[overlapBBox] = overlap + else: + stamp = None + return stamp + + def _warpRawStamp(self, obj, mag, rawStamp, warpingControl, pixToTan, pixCoord): + destImage = MaskedImageF(*self.config.stampSize) + bottomLeft = Point2D(rawStamp.getXY0()) + newBottomLeft = pixToTan.applyForward(bottomLeft) + newBottomLeft = Point2I(newBottomLeft) + destImage.setXY0(newBottomLeft) + # Define linear shifting to recenter stamps + newCenter = pixToTan.applyForward(pixCoord) + self.modelCenter = self.config.stampSize[0] // 2, self.config.stampSize[1] // 2 + shift = (self.modelCenter[0] + newBottomLeft[0] - newCenter[0], self.modelCenter[1] + newBottomLeft[1] - newCenter[1]) + affineShift = AffineTransform(shift) + shiftTransform = makeTransform(affineShift) + + # Define full transform (warp and shift) + starWarper = pixToTan.then(shiftTransform) + + # Apply it + goodPix = warpImage(destImage, rawStamp.getMaskedImage(), starWarper, warpingControl) + if not goodPix: + return None + return destImage + + # # Arbitrarily set origin of shifted star to 0 + # destImage.setXY0(0, 0) \ No newline at end of file From 7ca2be4a6c34b3f96df79b377abf2674dbef67a4 Mon Sep 17 00:00:00 2001 From: "Amir E. Bazkiaei" Date: Fri, 2 May 2025 17:01:13 +1000 Subject: [PATCH 3/5] Scaling psf model before fitting This is controled by a new config parameter `scalePsfModel` added to `BrightStarCutoutTask` and `BrightStarSubtractTask`. --- .../brightStarSubtraction/brightStarCutout.py | 25 ++- .../brightStarSubtraction/brightStarStack.py | 15 +- .../brightStarSubtract.py | 0 .../subtractBrightStar.py | 146 ++++++++++++------ 4 files changed, 127 insertions(+), 59 deletions(-) delete mode 100644 python/lsst/pipe/tasks/brightStarSubtraction/brightStarSubtract.py diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py index 3943de566..3be878eb0 100644 --- a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py @@ -180,6 +180,11 @@ class BrightStarCutoutConfig( }, ) + scalePsfModel = Field[bool]( + doc="If True, uses a scale factor to bring the PSF model data to the same level of the star data.", + default=True, + ) + # PSF Fitting useExtendedPsf = Field[bool]( doc="Use the extended PSF model to normalize bright star cutouts.", @@ -348,10 +353,18 @@ def run( # Fit a scaled PSF and a pedestal to each bright star cutout psf = WarpedPsf(inputExposure.getPsf(), pixToPolar, warpingControl) constantPsf = KernelPsf(FixedKernel(psf.computeKernelImage(Point2D(0, 0)))) + # TODO: discuss with Lee whether we should warp the psf here as well? if self.config.useExtendedPsf: psfImage = extendedPsf # Assumed to be warped, center at [0,0] else: psfImage = constantPsf.computeKernelImage(constantPsf.getAveragePosition()) + if self.config.scalePsfModel: + psfNeg = psfImage.array < 0 + self.modelScale = np.nanmean(stampMI.image.array) / np.nanmean(psfImage.array[~psfNeg]) + psfImage.array *= self.modelScale ######## model scale correction ######## + else: + self.modelScale = 1 + fitPsfResults = {} if self.config.doFitPsf: fitPsfResults = self._fitPsf(stampMI, psfImage) @@ -586,7 +599,8 @@ def _fitPsf(self, stampMI: MaskedImageF, psfImage: ImageD | ImageF) -> dict[str, psfMaskedPixels.array[:, :] = (stampMI.mask[psfImage.getBBox()].array & badMaskBitMask).astype(bool) # TODO: This is np.float64, else FITS metadata serialization fails # Amir: the following tries to find the fraction of the psf flux in the masked area of the psf image. - psfMaskedFluxFrac = np.dot(psfImage.array.flat, psfMaskedPixels.array.flat).astype(np.float64) + psfMaskedFluxFrac = np.dot(psfImage.array.flat, psfMaskedPixels.array.flat).astype(np.float64) / psfImage.array.sum() + # psfMaskedFluxFrac = np.dot(psfImage.array.astype(bool).flat, psfMaskedPixels.array.flat).astype(np.float64) / psfImage.array.astype(bool).sum() if psfMaskedFluxFrac > self.config.psfMaskedFluxFracThreshold: return {} # Handle cases where the PSF image is mostly masked @@ -622,13 +636,14 @@ def _fitPsf(self, stampMI: MaskedImageF, psfImage: ImageD | ImageF) -> dict[str, return {} # Handle singular matrix errors if sumSquaredResiduals.size == 0: return {} # Handle cases where sum of the squared residuals are empty - scale = solutions[0] + # scale = solutions[0] + scale = solutions[0] * self.modelScale ######## model scale correction ######## if scale <= 0: return {} # Handle cases where the PSF scale fit has failed - scaleErr = np.sqrt(covarianceMatrix[0, 0]) + scaleErr = np.sqrt(covarianceMatrix[0, 0]) * self.modelScale ######## model scale correction ######## pedestal = solutions[1] pedestalErr = np.sqrt(covarianceMatrix[1, 1]) - scalePedestalCov = covarianceMatrix[0, 1] + scalePedestalCov = covarianceMatrix[0, 1] * self.modelScale ######## model scale correction ######## xGradient = solutions[3] yGradient = solutions[2] @@ -641,6 +656,7 @@ def _fitPsf(self, stampMI: MaskedImageF, psfImage: ImageD | ImageF) -> dict[str, psfBBoxGoodSpans = goodSpans.clippedTo(psfImage.getBBox()) psfBBoxGoodSpansX, psfBBoxGoodSpansY = psfBBoxGoodSpans.indices() psfBBoxData = psfBBoxGoodSpans.flatten(stampMI.image.array, stampMI.getXY0()) + paddedPsfImage.array /= self.modelScale ######## model scale correction ######## psfBBoxModel = ( psfBBoxGoodSpans.flatten(paddedPsfImage.array, stampMI.getXY0()) * scale + pedestal @@ -652,7 +668,6 @@ def _fitPsf(self, stampMI: MaskedImageF, psfImage: ImageD | ImageF) -> dict[str, psfBBoxChiSquared = np.sum(psfBBoxResiduals) psfBBoxDegreesOfFreedom = len(psfBBoxData) - 4 psfBBoxReducedChiSquared = psfBBoxChiSquared / psfBBoxDegreesOfFreedom - return dict( scale=scale, scaleErr=scaleErr, diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py index 2aa40c769..a99d9be3e 100644 --- a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py @@ -24,9 +24,9 @@ __all__ = ["BrightStarStackConnections", "BrightStarStackConfig", "BrightStarStackTask"] import numpy as np -from lsst.afw.image import ImageF, MaskedImageF +from lsst.afw.image import ImageF from lsst.afw.math import StatisticsControl, statisticsStack, stringToStatisticsProperty -from lsst.geom import Box2I, Extent2I, Point2I +from lsst.geom import Point2I from lsst.meas.algorithms import BrightStarStamps from lsst.pex.config import Field, ListField from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct @@ -66,7 +66,7 @@ class BrightStarStackConfig( subsetStampNumber = Field[int]( doc="Number of stamps per subset to generate stacked images for.", - default=20, + default=2, ) globalReducedChiSquaredThreshold = Field[float]( doc="Threshold for global reduced chi-squared for bright star stamps.", @@ -172,9 +172,14 @@ def run( subsetStampMIs = [] tempStampMIs = [] + all_stars = 0 + used_stars = 0 for stampsDDH in brightStarStamps: stamps = stampsDDH.get() + all_stars += len(stamps) for stamp in stamps: + # print("globalReducedChiSquared: stamp ", stamp.globalReducedChiSquared, "config ", self.config.globalReducedChiSquaredThreshold) + # print("psfReducedChiSquared: stamp ", stamp.psfReducedChiSquared, "config ", self.config.psfReducedChiSquaredThreshold) if ( stamp.globalReducedChiSquared > self.config.globalReducedChiSquaredThreshold or stamp.psfReducedChiSquared > self.config.psfReducedChiSquaredThreshold @@ -193,7 +198,11 @@ def run( subsetStampMIs.append(statisticsStack(tempStampMIs, stackTypeProperty, statisticsControl)) # TODO: what to do with remaining temp stamps? tempStampMIs = [] + used_stars += self.config.subsetStampNumber + self.metadata["psfStarCount"] = {} + self.metadata["psfStarCount"]["all"] = all_stars + self.metadata["psfStarCount"]["used"] = used_stars # TODO: which stamp mask plane to use here? # TODO: Amir: there might be cases where subsetStampMIs is an empty list. What do we want to do then? # Currently, we get an "IndexError: list index out of range" diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarSubtract.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarSubtract.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/subtractBrightStar.py b/python/lsst/pipe/tasks/brightStarSubtraction/subtractBrightStar.py index 6cb224f4a..04de11116 100644 --- a/python/lsst/pipe/tasks/brightStarSubtraction/subtractBrightStar.py +++ b/python/lsst/pipe/tasks/brightStarSubtraction/subtractBrightStar.py @@ -27,14 +27,13 @@ from typing import Any import astropy.units as u import numpy as np -from astropy.coordinates import SkyCoord from astropy.table import Table, Column from lsst.afw.cameraGeom import FIELD_ANGLE, PIXELS, TAN_PIXELS from lsst.afw.detection import Footprint, FootprintSet, Threshold -from lsst.afw.geom import SkyWcs, SpanSet, makeModifiedWcs +from lsst.afw.geom import SkyWcs, SpanSet from lsst.afw.geom.transformFactory import makeTransform from lsst.afw.image import ExposureF, ImageD, ImageF, MaskedImageF -from lsst.afw.math import BackgroundList, FixedKernel, WarpingControl, warpImage +from lsst.afw.math import BackgroundList, WarpingControl, warpImage from lsst.daf.butler import DataCoordinate from lsst.geom import ( AffineTransform, @@ -264,27 +263,10 @@ class BrightStarSubtractConfig(PipelineTaskConfig, pipelineConnections=BrightSta doc="Magnitude range in Gaia G. Cutouts will be made for all stars in this range.", default=[0, 18], ) - # excludeArcsecRadius = Field[float]( - # doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeArcsecRadius`` are not used.", - # default=5, - # ) - # excludeMagRange = ListField[float]( - # doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeArcsecRadius`` are not used.", - # default=[0, 20], - # ) minAreaFraction = Field[float]( doc="Minimum fraction of the stamp area, post-masking, that must remain for a cutout to be retained.", default=0.1, ) - # PSF Fitting - # useExtendedPsf = Field[bool]( - # doc="Use the extended PSF model to normalize bright star cutouts.", - # default=False, - # ) - # doFitPsf = Field[bool]( - # doc="Fit a scaled PSF and a pedestal to each bright star cutout.", - # default=True, - # ) useMedianVariance = Field[bool]( doc="Use the median of the variance plane for PSF fitting.", default=False, @@ -293,11 +275,10 @@ class BrightStarSubtractConfig(PipelineTaskConfig, pipelineConnections=BrightSta doc="Maximum allowed fraction of masked PSF flux for PSF fitting to occur.", default=0.97, ) - - # # Misc - # loadReferenceObjectsConfig = ConfigField[LoadReferenceObjectsConfig]( - # doc="Reference object loader for astrometric calibration.", - # ) + scalePsfModel = Field[bool]( + doc="If True, uses a scale factor to bring the PSF model data to the same level of the star data.", + default=True, + ) class BrightStarSubtractTask(PipelineTask): @@ -322,8 +303,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) stampSize = Extent2D(*self.config.stampSize.list()) stampRadius = floor(stampSize / 2) + # Define a central bounding box of the configured stamp size. self.stampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy(stampRadius) - paddedStampSize = stampSize #* self.config.stampSizePadding + paddedStampSize = stampSize self.paddedStampRadius = floor(paddedStampSize / 2) self.paddedStampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy( self.paddedStampRadius @@ -339,7 +321,7 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs): name=self.config.connections.refCat, config=self.config.refObjLoader, ) - # TODO: include the un-subtracted star here! + # TODO: include the un-subtracted stars here! subtractor = self.run(**inputs, dataId=dataId, refObjLoader=refObjLoader) if self.config.doWriteSubtractedExposure: outputExposure = inputs["inputExposure"].clone() @@ -359,35 +341,46 @@ def run( inputExposure: ExposureF, inputCalexp: ExposureF, inputBackground: BackgroundList, - # inputBrightStarStamps, #next plan is to use stamps for pedestal and gradients? inputExtendedPsf: ImageF, dataId: dict[str, Any] | DataCoordinate, - # inputBackground: BackgroundList, refObjLoader: ReferenceObjectLoader, ): - """Identify bright stars within an exposure using a reference catalog, - extract stamps around each, warp/shift stamps onto a common frame and - then optionally fit a PSF plus plane model. + """Generate a bright star subtractor image using scaled extended PSF models. + + Identifies bright stars within the calibrated exposure using a + reference catalog, extracts stamps around each, warps the extended PSF + model onto the stamp frame, fits for a scale factor and pedestal for + each star iteratively, and combines the scaled models into a single + subtractor exposure. Parameters ---------- inputExposure : `~lsst.afw.image.ExposureF` - The background-subtracted image to extract bright star stamps. + The Post-ISR CCD frame. Note: Currently appears unused directly + within this method's main logic, but required by the pipeline + definition. Cutouts are based on `inputCalexp` + `inputBackground`. + inputCalexp: `~lsst.afw.image.ExposureF` + The background-subtracted calibrated exposure used for identifying + stars, extracting stamps, and fitting models. inputBackground : `~lsst.afw.math.BackgroundList` - The background model associated with the input exposure. - refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional - Loader to find objects within a reference catalog. + The background model associated with `inputCalexp`. Added back + before processing stamps. + inputExtendedPsf : `~lsst.afw.image.ImageF` + The extended PSF model (e.g., from MeasureExtendedPsfTask). + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader` + Loader used to query the reference catalog for bright stars within + the exposure footprint. dataId : `dict` or `~lsst.daf.butler.DataCoordinate` - The dataId of the exposure that bright stars are extracted from. - Both 'visit' and 'detector' will be persisted in the output data. + The data identifier for the input exposure. Returns ------- - brightStarResults : `~lsst.pipe.base.Struct` - Results as a struct with attributes: - - ``brightStarStamps`` - (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`) + subtractor : `~lsst.afw.image.ExposureF` + An exposure containing the combined, scaled models of the bright + stars identified and processed. This image can be subtracted from + the original `inputExposure` (or `inputCalexp` + `inputBackground`). + The image plane contains the model flux, while the variance and + mask planes are typically empty or minimal unless specifically populated. """ wcs = inputCalexp.getWcs() bbox = inputCalexp.getBBox() @@ -399,9 +392,13 @@ def run( spherePoints = [SpherePoint(ra, dec) for ra, dec in zipRaDec] pixCoords = wcs.skyToPixel(spherePoints) + # Create image with background added back + # Using calibrated exposure for finding the scale factor and creating subtrator. + # The generated subtractor will be subtracted from PostISRCCd. inputFixed = inputCalexp.getMaskedImage() inputFixed += inputBackground.getImage() inputCalexp.mask.addMaskPlane(NEIGHBOR_MASK_PLANE) + # Associate detected footprints (from DETECTED plane) with the bright reference stars. allFootprints, associations = self._associateFootprints(inputCalexp, pixCoords, plane="DETECTED") subtractorExp = ExposureF(bbox=bbox) @@ -419,10 +416,12 @@ def run( for j in range(self.config.min_iterations): scaleList = [] for starIndex, (obj, pixCoord) in enumerate(zip(refCatBright, pixCoords)): # type: ignore + # Start with the background-added image for each star inputMI = deepcopy(inputFixed) restSubtractor = deepcopy(templateSubtractor) myNumber = 0 for key in self.warpedDataDict.keys(): + # Subtract the *current best models* of all *other* stars before fitting this one. if self.warpedDataDict[key]["subtractor"] is not None and key != obj['id']: restSubtractor.image += self.warpedDataDict[key]["subtractor"].image myNumber += 1 @@ -447,7 +446,7 @@ def run( self.log.debug(f"No stamp for star with refID {obj['id']}") removalIndices.append(starIndex) continue - warpedStamp = self._warpRawStamp(obj["id"], obj["mag"], rawStamp, warpingControl, pixToTan, pixCoord) + warpedStamp = self._warpRawStamp(rawStamp, warpingControl, pixToTan, pixCoord) warpedModel = ImageF(warpedStamp.getBBox()) inputExtendedPsfGeneral = deepcopy(inputExtendedPsf) good_pixels = warpImage(warpedModel, inputExtendedPsfGeneral, pixToPolar.inverted(), warpingControl) @@ -456,6 +455,10 @@ def run( self.warpedDataDict[obj["id"]]["scale"] = None self.warpedDataDict[obj["id"]]["subtractor"] = None fitPsfResults = {} + if self.config.scalePsfModel: + psfNeg = warpedModel.array < 0 + self.modelScale = np.nanmean(warpedStamp.image.array) / np.nanmean(warpedModel.array[~psfNeg]) + warpedModel.array *= self.modelScale ######## model scale correction ######## fitPsfResults = self._fitPsf( warpedStamp, warpedModel) if fitPsfResults: scaleList.append(fitPsfResults["scale"]) @@ -475,6 +478,9 @@ def run( else: scaleList.append(np.nan) + if "subtractor" not in self.warpedDataDict[obj["id"]].keys(): + self.warpedDataDict[obj["id"]]["subtractor"] = None + self.warpedDataDict[obj["id"]]["scale"] = None if j == 0: refCatBright.remove_rows(removalIndices) updatedPixCoords = [item for i, item in enumerate(pixCoords) if i not in removalIndices] @@ -726,14 +732,25 @@ def _fitPsf(self, stampMI: MaskedImageF, psfImage: ImageD | ImageF) -> dict[str, ) def add_psf_mask(self, psfImage, stampMI): - """Add psf frame mask into the stamp's mask. + """Add problematic PSF pixels to the stamp's mask. + + Identifies pixels in the PSF model image that are NaN or negative + and sets the corresponding bits (hardcoded as plane 0, likely 'BAD') + in a copy of the input stamp's mask. - Args: - psfImage (`~lsst.afw.image.ImageF`): PSF data - stampMI (`~lsst.afw.image.MaskedImageF`): the stamp of the star that being fitted. + Parameters + ---------- + psfImage : `~lsst.afw.image.ImageF` + PSF model image defined on the stamp grid. + stampMI : `~lsst.afw.image.MaskedImageF` + The masked image stamp of the star being fitted. Its mask is used + as the base. - Returns: - `~lsst.afw.image.MaskX``: the mask frame containing the stamp plus the psf model mask. + Returns + ------- + mask : `~lsst.afw.image.Mask` + A mask object based on the input stamp's mask, updated to include + masked pixels derived from the PSF model image. """ cond = np.isnan(psfImage.array) cond |= psfImage.array < 0 @@ -788,7 +805,34 @@ def _getCutout(self, inputExposure, coordPix: Point2D, stampSize: list[int]): stamp = None return stamp - def _warpRawStamp(self, obj, mag, rawStamp, warpingControl, pixToTan, pixCoord): + def _warpRawStamp(self, rawStamp, warpingControl, pixToTan, pixCoord): + """Warps a raw image stamp onto a common tangent plane projection. + + Applies a transformation (`pixToTan`) followed by a shift + transform to warp the input `rawStamp` onto a destination + `MaskedImageF` aligned with a tangent plane centered near the object. + The shift aims to place the object center at the center of the + destination image. + + Parameters + ---------- + rawStamp : `~lsst.afw.image.ExposureF` + The raw cutout image stamp (e.g., from `_getCutout`). + warpingControl : `~lsst.afw.math.WarpingControl` + Configuration for the warping process. + pixToTan : `~lsst.afw.geom.Transform` + Transformation from the raw stamp's pixel coordinates to the + common tangent plane coordinates. + pixCoord : `~lsst.geom.Point2D` + Pixel coordinates of the object center in the original exposure, + used to calculate the centering shift. + + Returns + ------- + warped_stamp : `~lsst.afw.image.MaskedImageF` or `None` + The warped and shifted masked image, or None if warping failed + (e.g., due to insufficient good pixels). + """ destImage = MaskedImageF(*self.config.stampSize) bottomLeft = Point2D(rawStamp.getXY0()) newBottomLeft = pixToTan.applyForward(bottomLeft) From e470f6fde59fb31a496195beff3f189088aee386 Mon Sep 17 00:00:00 2001 From: Amir Ebadati Bazkiaei Date: Fri, 24 Oct 2025 08:26:12 +0000 Subject: [PATCH 4/5] modifications to brightStarCutout.py module --- .../brightStarSubtraction/brightStarCutout.py | 368 ++++++-- .../brightStarSubtraction/brightStarStack.py | 2 - .../subtractBrightStar.py | 858 ------------------ tests/test_brightStarCutout.py | 11 +- 4 files changed, 283 insertions(+), 956 deletions(-) delete mode 100644 python/lsst/pipe/tasks/brightStarSubtraction/subtractBrightStar.py diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py index 3be878eb0..9afda1f6f 100644 --- a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py @@ -29,7 +29,7 @@ import numpy as np from astropy.coordinates import SkyCoord from astropy.table import Table -from lsst.afw.cameraGeom import FIELD_ANGLE, PIXELS +from lsst.afw.cameraGeom import FIELD_ANGLE, PIXELS, FOCAL_PLANE from lsst.afw.detection import Footprint, FootprintSet, Threshold from lsst.afw.geom import SkyWcs, SpanSet, makeModifiedWcs from lsst.afw.geom.transformFactory import makeTransform @@ -47,6 +47,7 @@ arcseconds, floor, radians, + Angle, ) from lsst.meas.algorithms import ( BrightStarStamp, @@ -61,6 +62,8 @@ from lsst.pipe.base.connectionTypes import Input, Output, PrerequisiteInput from lsst.utils.timer import timeMethod from copy import deepcopy +import math + NEIGHBOR_MASK_PLANE = "NEIGHBOR" @@ -149,8 +152,6 @@ class BrightStarCutoutConfig( NEIGHBOR_MASK_PLANE, ], ) - - # Cutout geometry stampSize = ListField[int]( doc="Size of the stamps to be extracted, in pixels.", default=(251, 251), @@ -179,9 +180,8 @@ class BrightStarCutoutConfig( "lanczos5": "Lanczos kernel of order 5", }, ) - scalePsfModel = Field[bool]( - doc="If True, uses a scale factor to bring the PSF model data to the same level of the star data.", + doc="If True, uses a scale factor to bring the PSF model data to the same level as the star data.", default=True, ) @@ -202,6 +202,14 @@ class BrightStarCutoutConfig( doc="Maximum allowed fraction of masked PSF flux for PSF fitting to occur.", default=0.97, ) + fitIterations = Field[int]( + doc="Number of iterations over pedestal-gradient and scaling fit.", + default=5, + ) + offFrameMagLim = Field[float]( + doc="Stars fainter than this limit are only included if they appear within the frame boundaries.", + default=15.0, + ) # Misc loadReferenceObjectsConfig = ConfigField[LoadReferenceObjectsConfig]( @@ -238,6 +246,7 @@ def __init__(self, initInputs=None, *args, **kwargs): self.paddedStampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy( self.paddedStampRadius ) + self.modelScale = 1 def runQuantum(self, butlerQC, inputRefs, outputRefs): inputs = butlerQC.get(inputRefs) @@ -299,8 +308,6 @@ def run( # Restore original subtracted background inputMI = inputExposure.getMaskedImage() inputMI += inputBackground.getImage() - # Amir: the above addition to inputMI, also adds to the inputExposure. - # Amir: but the calibration, three lines later, only is applied to the inputMI. # Set up NEIGHBOR mask plane; associate footprints with stars inputExposure.mask.addMaskPlane(NEIGHBOR_MASK_PLANE) @@ -319,16 +326,17 @@ def run( # Loop over each bright star stamps, goodFracs, stamps_fitPsfResults = [], [], [] for starIndex, (obj, pixCoord) in enumerate(zip(refCatBright, pixCoords)): # type: ignore + # Excluding faint stars that are not within the frame. + if obj["mag"] > self.config.offFrameMagLim and not self.star_in_frame(pixCoord, bbox): + continue footprintIndex = associations.get(starIndex, None) stampMI = MaskedImageF(self.paddedStampBBox) # Set NEIGHBOR footprints in the mask plane if footprintIndex: neighborFootprints = [fp for i, fp in enumerate(allFootprints) if i != footprintIndex] - # self._setFootprints(inputExposure, neighborFootprints, NEIGHBOR_MASK_PLANE) self._setFootprints(inputMI, neighborFootprints, NEIGHBOR_MASK_PLANE) else: - # self._setFootprints(inputExposure, allFootprints, NEIGHBOR_MASK_PLANE) self._setFootprints(inputMI, allFootprints, NEIGHBOR_MASK_PLANE) # Define linear shifting to recenter stamps @@ -339,7 +347,6 @@ def run( pixToPolar = pixToFocalPlaneTan.then(shift).then(rotation) # Apply the warp to the star stamp (in-place) - # warpImage(stampMI, inputExposure.maskedImage, pixToPolar, warpingControl) warpImage(stampMI, inputMI, pixToPolar, warpingControl) # Trim to the base stamp size, check mask coverage, update metadata @@ -353,47 +360,62 @@ def run( # Fit a scaled PSF and a pedestal to each bright star cutout psf = WarpedPsf(inputExposure.getPsf(), pixToPolar, warpingControl) constantPsf = KernelPsf(FixedKernel(psf.computeKernelImage(Point2D(0, 0)))) - # TODO: discuss with Lee whether we should warp the psf here as well? if self.config.useExtendedPsf: - psfImage = extendedPsf # Assumed to be warped, center at [0,0] + psfImage = deepcopy(extendedPsf) # Assumed to be warped, center at [0,0] else: psfImage = constantPsf.computeKernelImage(constantPsf.getAveragePosition()) - if self.config.scalePsfModel: - psfNeg = psfImage.array < 0 - self.modelScale = np.nanmean(stampMI.image.array) / np.nanmean(psfImage.array[~psfNeg]) - psfImage.array *= self.modelScale ######## model scale correction ######## - else: - self.modelScale = 1 + # TODO: maybe we want to generate a smaller psf in case the following happens? + # The following could happen for when the user chooses small stampSize ~(50, 50) + if ( + psfImage.array.shape[0] > stampMI.image.array.shape[0] + or psfImage.array.shape[1] > stampMI.image.array.shape[1] + ): + continue + # Computing an scale factor that brings the model to the similar level of the star. + self.computeModelScale(stampMI, psfImage) + psfImage.array *= self.modelScale # ####### model scale correction ######## fitPsfResults = {} + if self.config.doFitPsf: fitPsfResults = self._fitPsf(stampMI, psfImage) stamps_fitPsfResults.append(fitPsfResults) # Save the stamp if the PSF fit was successful or no fit requested if fitPsfResults or not self.config.doFitPsf: + distance_mm, theta_angle = self.star_location_on_focal(pixCoord, detector) + stamp = BrightStarStamp( - maskedImage=stampMI, - # TODO: what to do about this PSF? + stamp_im=stampMI, psf=constantPsf, wcs=makeModifiedWcs(pixToPolar, wcs, False), visit=cast(int, dataId["visit"]), detector=cast(int, dataId["detector"]), - refId=obj["id"], - refMag=obj["mag"], + ref_id=obj["id"], + ref_mag=obj["mag"], position=pixCoord, + focal_plane_radius=distance_mm, + focal_plane_angle=theta_angle, # TODO: add the lsst.geom.Angle here scale=fitPsfResults.get("scale", None), - scaleErr=fitPsfResults.get("scaleErr", None), + scale_err=fitPsfResults.get("scaleErr", None), pedestal=fitPsfResults.get("pedestal", None), - pedestalErr=fitPsfResults.get("pedestalErr", None), - pedestalScaleCov=fitPsfResults.get("pedestalScaleCov", None), - xGradient=fitPsfResults.get("xGradient", None), - yGradient=fitPsfResults.get("yGradient", None), - globalReducedChiSquared=fitPsfResults.get("globalReducedChiSquared", None), - globalDegreesOfFreedom=fitPsfResults.get("globalDegreesOfFreedom", None), - psfReducedChiSquared=fitPsfResults.get("psfReducedChiSquared", None), - psfDegreesOfFreedom=fitPsfResults.get("psfDegreesOfFreedom", None), - psfMaskedFluxFrac=fitPsfResults.get("psfMaskedFluxFrac", None), + pedestal_err=fitPsfResults.get("pedestalErr", None), + pedestal_scale_cov=fitPsfResults.get("pedestalScaleCov", None), + gradient_x=fitPsfResults.get("xGradient", None), + gradient_y=fitPsfResults.get("yGradient", None), + global_reduced_chi_squared=fitPsfResults.get("globalReducedChiSquared", None), + global_degrees_of_freedom=fitPsfResults.get("globalDegreesOfFreedom", None), + psf_reduced_chi_squared=fitPsfResults.get("psfReducedChiSquared", None), + psf_degrees_of_freedom=fitPsfResults.get("psfDegreesOfFreedom", None), + psf_masked_flux_fraction=fitPsfResults.get("psfMaskedFluxFrac", None), + ) + print( + obj["mag"], + fitPsfResults.get("globalReducedChiSquared", None), + fitPsfResults.get("globalDegreesOfFreedom", None), + fitPsfResults.get("psfReducedChiSquared", None), + fitPsfResults.get("psfDegreesOfFreedom", None), + fitPsfResults.get("psfMaskedFluxFrac", None), ) stamps.append(stamp) @@ -414,6 +436,25 @@ def run( brightStarStamps = BrightStarStamps(stamps) return Struct(brightStarStamps=brightStarStamps) + def star_location_on_focal(self, pixCoord, detector): + star_focal_plane_coords = detector.transform(pixCoord, PIXELS, FOCAL_PLANE) + star_x_fp = star_focal_plane_coords.getX() + star_y_fp = star_focal_plane_coords.getY() + distance_mm = np.sqrt(star_x_fp**2 + star_y_fp**2) + theta_rad = math.atan2(star_y_fp, star_x_fp) + theta_angle = Angle(theta_rad, radians) + return distance_mm, theta_angle + + def star_in_frame(self, pixCoord, inputExposureBBox): + if ( + pixCoord[0] < 0 + or pixCoord[1] < 0 + or pixCoord[0] > inputExposureBBox.getDimensions()[0] + or pixCoord[1] > inputExposureBBox.getDimensions()[1] + ): + return False + return True + def _getRefCatBright(self, refObjLoader: ReferenceObjectLoader, wcs: SkyWcs, bbox: Box2I) -> Table: """Get a bright star subset of the reference catalog. @@ -597,76 +638,108 @@ def _fitPsf(self, stampMI: MaskedImageF, psfImage: ImageD | ImageF) -> dict[str, # Calculate the fraction of the PSF image flux masked by bad pixels psfMaskedPixels = ImageF(psfImage.getBBox()) psfMaskedPixels.array[:, :] = (stampMI.mask[psfImage.getBBox()].array & badMaskBitMask).astype(bool) - # TODO: This is np.float64, else FITS metadata serialization fails - # Amir: the following tries to find the fraction of the psf flux in the masked area of the psf image. - psfMaskedFluxFrac = np.dot(psfImage.array.flat, psfMaskedPixels.array.flat).astype(np.float64) / psfImage.array.sum() - # psfMaskedFluxFrac = np.dot(psfImage.array.astype(bool).flat, psfMaskedPixels.array.flat).astype(np.float64) / psfImage.array.astype(bool).sum() + psfMaskedFluxFrac = ( + np.dot(psfImage.array.flat, psfMaskedPixels.array.flat).astype(np.float64) / psfImage.array.sum() + ) if psfMaskedFluxFrac > self.config.psfMaskedFluxFracThreshold: return {} # Handle cases where the PSF image is mostly masked - # Create a padded version of the input constant PSF image - paddedPsfImage = ImageF(stampMI.getBBox()) - paddedPsfImage[psfImage.getBBox()] = psfImage.convertF() - - mask = self.add_psf_mask(paddedPsfImage, stampMI) - # Create consistently masked data - # badSpans = SpanSet.fromMask(stampMI.mask, badMaskBitMask) - badSpans = SpanSet.fromMask(mask, badMaskBitMask) - goodSpans = SpanSet(stampMI.getBBox()).intersectNot(badSpans) - varianceData = goodSpans.flatten(stampMI.variance.array, stampMI.getXY0()) + # Generating good spans for gradient-pedestal fitting (including the star DETECTED mask). + gradientGoodSpans = self.generate_gradient_spans(stampMI, badMaskBitMask) + varianceData = gradientGoodSpans.flatten(stampMI.variance.array, stampMI.getXY0()) if self.config.useMedianVariance: varianceData = np.median(varianceData) sigmaData = np.sqrt(varianceData) - imageData = goodSpans.flatten(stampMI.image.array, stampMI.getXY0()) # B - imageData /= sigmaData - psfData = goodSpans.flatten(paddedPsfImage.array, paddedPsfImage.getXY0()) - psfData /= sigmaData - # Fit the PSF scale factor and global pedestal + + for i in range(self.config.fitIterations): + # Gradient-pedestal fitting: + if i: + # if i > 0, there should be scale factor from the previous fit iteration. Therefore, we can + # remove the star using the scale factor. + stamp = self.remove_star(stampMI, scale, paddedPsfImage) # noqa: F821 + else: + stamp = deepcopy(stampMI.image.array) + + imageDataGr = gradientGoodSpans.flatten(stamp, stampMI.getXY0()) / sigmaData # B + nData = len(imageDataGr) + coefficientMatrix = np.ones((nData, 3), dtype=float) # A + coefficientMatrix[:, 0] /= sigmaData + coefficientMatrix[:, 1:] = gradientGoodSpans.indices().T + coefficientMatrix[:, 1] /= sigmaData + coefficientMatrix[:, 2] /= sigmaData + + try: + grSolutions, grSumSquaredResiduals, *_ = np.linalg.lstsq( + coefficientMatrix, imageDataGr, rcond=None + ) + covarianceMatrix = np.linalg.inv( + np.dot(coefficientMatrix.transpose(), coefficientMatrix) + ) # C + except np.linalg.LinAlgError: + return {} # Handle singular matrix errors + if grSumSquaredResiduals.size == 0: + return {} # Handle cases where sum of the squared residuals are empty + + pedestal = grSolutions[0] + pedestalErr = np.sqrt(covarianceMatrix[0, 0]) + scalePedestalCov = None + xGradient = grSolutions[2] + yGradient = grSolutions[1] + + # Scale fitting: + updatedStampMI = deepcopy(stampMI) + self._removePedestalAndGradient(updatedStampMI, pedestal, xGradient, yGradient) + + # Create a padded version of the input constant PSF image + paddedPsfImage = ImageF(updatedStampMI.getBBox()) + paddedPsfImage[psfImage.getBBox()] = psfImage.convertF() + + # Generating a mask plane while considering bad pixels in the psf model. + mask = self.add_psf_mask(paddedPsfImage, updatedStampMI) + # Create consistently masked data + scaleGoodSpans = self.generate_good_spans(mask, updatedStampMI.getBBox(), badMaskBitMask) + + imageData = scaleGoodSpans.flatten(updatedStampMI.image.array, updatedStampMI.getXY0()) + psfData = scaleGoodSpans.flatten(paddedPsfImage.array, paddedPsfImage.getXY0()) + scaleCoefficientMatrix = psfData.reshape(psfData.shape[0], 1) + + try: + scaleSolution, scaleSumSquaredResiduals, *_ = np.linalg.lstsq( + scaleCoefficientMatrix, imageData, rcond=None + ) + except np.linalg.LinAlgError: + return {} # Handle singular matrix errors + if scaleSumSquaredResiduals.size == 0: + return {} # Handle cases where sum of the squared residuals are empty + scale = scaleSolution[0] + if scale <= 0: + return {} # Handle cases where the PSF scale fit has failed + # TODO: calculate scale error and store it. + scaleErr = None + + scale *= self.modelScale # ####### model scale correction ######## nData = len(imageData) - coefficientMatrix = np.ones((nData, 4), dtype=float) # A - coefficientMatrix[:, 0] = psfData - coefficientMatrix[:, 1] /= sigmaData - coefficientMatrix[:, 2:] = goodSpans.indices().T - coefficientMatrix[:, 2] /= sigmaData - coefficientMatrix[:, 3] /= sigmaData - try: - solutions, sumSquaredResiduals, *_ = np.linalg.lstsq(coefficientMatrix, imageData, rcond=None) - covarianceMatrix = np.linalg.inv(np.dot(coefficientMatrix.transpose(), coefficientMatrix)) # C - except np.linalg.LinAlgError: - return {} # Handle singular matrix errors - if sumSquaredResiduals.size == 0: - return {} # Handle cases where sum of the squared residuals are empty - # scale = solutions[0] - scale = solutions[0] * self.modelScale ######## model scale correction ######## - if scale <= 0: - return {} # Handle cases where the PSF scale fit has failed - scaleErr = np.sqrt(covarianceMatrix[0, 0]) * self.modelScale ######## model scale correction ######## - pedestal = solutions[1] - pedestalErr = np.sqrt(covarianceMatrix[1, 1]) - scalePedestalCov = covarianceMatrix[0, 1] * self.modelScale ######## model scale correction ######## - xGradient = solutions[3] - yGradient = solutions[2] - - # Calculate global (whole image) reduced chi-squared - globalChiSquared = np.sum(sumSquaredResiduals) - globalDegreesOfFreedom = nData - 4 - globalReducedChiSquared = globalChiSquared / globalDegreesOfFreedom + + # Calculate global (whole image) reduced chi-squared (scaling fit is assumed as the main fitting + # process here.) + globalChiSquared = np.sum(scaleSumSquaredResiduals) + globalDegreesOfFreedom = nData - 1 + globalReducedChiSquared = np.float64(globalChiSquared / globalDegreesOfFreedom) # Calculate PSF BBox reduced chi-squared - psfBBoxGoodSpans = goodSpans.clippedTo(psfImage.getBBox()) - psfBBoxGoodSpansX, psfBBoxGoodSpansY = psfBBoxGoodSpans.indices() - psfBBoxData = psfBBoxGoodSpans.flatten(stampMI.image.array, stampMI.getXY0()) - paddedPsfImage.array /= self.modelScale ######## model scale correction ######## + psfBBoxscaleGoodSpans = scaleGoodSpans.clippedTo(psfImage.getBBox()) + psfBBoxscaleGoodSpansX, psfBBoxscaleGoodSpansY = psfBBoxscaleGoodSpans.indices() + psfBBoxData = psfBBoxscaleGoodSpans.flatten(stampMI.image.array, stampMI.getXY0()) + paddedPsfImage.array /= self.modelScale # ####### model scale correction ######## psfBBoxModel = ( - psfBBoxGoodSpans.flatten(paddedPsfImage.array, stampMI.getXY0()) * scale + psfBBoxscaleGoodSpans.flatten(paddedPsfImage.array, stampMI.getXY0()) * scale + pedestal - + psfBBoxGoodSpansX * xGradient - + psfBBoxGoodSpansY * yGradient + + psfBBoxscaleGoodSpansX * xGradient + + psfBBoxscaleGoodSpansY * yGradient ) - psfBBoxVariance = psfBBoxGoodSpans.flatten(stampMI.variance.array, stampMI.getXY0()) - psfBBoxResiduals = (psfBBoxData - psfBBoxModel) ** 2 / psfBBoxVariance + psfBBoxResiduals = (psfBBoxData - psfBBoxModel) ** 2 # / psfBBoxVariance psfBBoxChiSquared = np.sum(psfBBoxResiduals) - psfBBoxDegreesOfFreedom = len(psfBBoxData) - 4 + psfBBoxDegreesOfFreedom = len(psfBBoxData) - 1 psfBBoxReducedChiSquared = psfBBoxChiSquared / psfBBoxDegreesOfFreedom return dict( scale=scale, @@ -683,9 +756,122 @@ def _fitPsf(self, stampMI: MaskedImageF, psfImage: ImageD | ImageF) -> dict[str, psfMaskedFluxFrac=psfMaskedFluxFrac, ) - def add_psf_mask(self, psfImage, stampMI): + def add_psf_mask(self, psfImage, stampMI, maskZeros=True): + """ + Creates a new mask by adding PSF bad pixels to an existing stamp mask. + + This method identifies "bad" pixels in the PSF image (NaNs and + optionally zeros/non-positives) and adds them to a deep copy + of the input stamp's mask. + + Args: + psfImage : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` + The PSF image object. + stampMI: `~lsst.afw.image.MaskedImageF` + The masked image of the bright star cutout. + maskZeros (bool, optional): If True (default), mask pixels + where the PSF is <= 0. If False, only mask pixels < 0. + + Returns: + Any: A new mask object (deep copy) with the PSF mask planes added. + """ cond = np.isnan(psfImage.array) - cond |= psfImage.array < 0 + if maskZeros: + cond |= psfImage.array <= 0 + else: + cond |= psfImage.array < 0 mask = deepcopy(stampMI.mask) mask.array[cond] = np.bitwise_or(mask.array[cond], 1) return mask + + def _removePedestalAndGradient(self, stampMI, pedestal, xGradient, yGradient): + """Apply fitted pedestal and gradients to a single bright star stamp.""" + stampBBox = stampMI.getBBox() + xGrid, yGrid = np.meshgrid(stampBBox.getX().arange(), stampBBox.getY().arange()) + xPlane = ImageF((xGrid * xGradient).astype(np.float32), xy0=stampMI.getXY0()) + yPlane = ImageF((yGrid * yGradient).astype(np.float32), xy0=stampMI.getXY0()) + stampMI -= pedestal + stampMI -= xPlane + stampMI -= yPlane + + def remove_star(self, stampMI, scale, psfImage): + """ + Subtracts a scaled PSF model from a star image. + + This performs a simple subtraction: `image - (psf * scale)`. + + Args: + stampMI: `~lsst.afw.image.MaskedImageF` + The masked image of the bright star cutout. + scale (float): The scaling factor to apply to the PSF. + psfImage: `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` + The PSF image object. + + Returns: + np.ndarray: A new 2D numpy array containing the star-subtracted + image. + """ + star_removed_cutout = stampMI.image.array - psfImage.array * scale + return star_removed_cutout + + def computeModelScale(self, stampMI, psfImage): + """ + Computes the scaling factor of the given model against a star. + + Args: + stampMI : `~lsst.afw.image.MaskedImageF` + The masked image of the bright star cutout. + psfImage : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` + The given PSF model. + """ + cond = stampMI.mask.array == 0 + self.starMedianValue = np.median(stampMI.image.array[cond]).astype(np.float64) + + psfPos = psfImage.array > 0 + + imageArray = stampMI.image.array - self.starMedianValue + imageArrayPos = imageArray > 0 + self.modelScale = np.nanmean(imageArray[imageArrayPos]) / np.nanmean(psfImage.array[psfPos]) + + def generate_gradient_spans(self, stampMI, badMaskBitMask): + """ + Generates spans of "good" pixels for gradient fitting. + + This method creates a combined bitmask by OR-ing the provided + `badMaskBitMask` with the "DETECTED" plane from the stamp's mask. + It then calls `self.generate_good_spans` to find all pixel spans + not covered by this combined mask. + + Args: + stampMI: `~lsst.afw.image.MaskedImageF` + The masked image of the bright star cutout. + badMaskBitMask (int): A bitmask representing planes to be + considered "bad" for gradient fitting. + + Returns: + gradientGoodSpans: A SpanSet object containing the "good" spans. + """ + detectedMaskBitMask = stampMI.mask.getPlaneBitMask("DETECTED") + gradientBitMask = np.bitwise_or(badMaskBitMask, detectedMaskBitMask) + + gradientGoodSpans = self.generate_good_spans(stampMI.mask, stampMI.getBBox(), gradientBitMask) + return gradientGoodSpans + + def generate_good_spans(self, mask, bBox, badBitMask): + """ + Generates a SpanSet of "good" pixels from a mask. + + This method identifies all spans within a given bounding box (`bBox`) + that are *not* flagged by the `badBitMask` in the provided `mask`. + + Args: + mask (lsst.afw.image.MaskedImageF.mask): The mask object (e.g., `stampMI.mask`). + bBox (lsst.geom.Box2I): The bounding box of the image (e.g., `stampMI.getBBox()`). + badBitMask (int): The combined bitmask of planes to exclude. + + Returns: + goodSpans: A SpanSet object representing all "good" spans. + """ + badSpans = SpanSet.fromMask(mask, badBitMask) + goodSpans = SpanSet(bBox).intersectNot(badSpans) + return goodSpans diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py index a99d9be3e..1c2fcad4f 100644 --- a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py @@ -178,8 +178,6 @@ def run( stamps = stampsDDH.get() all_stars += len(stamps) for stamp in stamps: - # print("globalReducedChiSquared: stamp ", stamp.globalReducedChiSquared, "config ", self.config.globalReducedChiSquaredThreshold) - # print("psfReducedChiSquared: stamp ", stamp.psfReducedChiSquared, "config ", self.config.psfReducedChiSquaredThreshold) if ( stamp.globalReducedChiSquared > self.config.globalReducedChiSquaredThreshold or stamp.psfReducedChiSquared > self.config.psfReducedChiSquaredThreshold diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/subtractBrightStar.py b/python/lsst/pipe/tasks/brightStarSubtraction/subtractBrightStar.py deleted file mode 100644 index 04de11116..000000000 --- a/python/lsst/pipe/tasks/brightStarSubtraction/subtractBrightStar.py +++ /dev/null @@ -1,858 +0,0 @@ -# This file is part of pipe_tasks. -# -# Developed for the LSST Data Management System. -# This product includes software developed by the LSST Project -# (https://www.lsst.org). -# See the COPYRIGHT file at the top-level directory of this distribution -# for details of code ownership. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . - -"""Retrieve extended PSF model and subtract bright stars at visit level.""" - -__all__ = ["BrightStarSubtractConnections", "BrightStarSubtractConfig", "BrightStarSubtractTask"] - -import logging -from typing import Any -import astropy.units as u -import numpy as np -from astropy.table import Table, Column -from lsst.afw.cameraGeom import FIELD_ANGLE, PIXELS, TAN_PIXELS -from lsst.afw.detection import Footprint, FootprintSet, Threshold -from lsst.afw.geom import SkyWcs, SpanSet -from lsst.afw.geom.transformFactory import makeTransform -from lsst.afw.image import ExposureF, ImageD, ImageF, MaskedImageF -from lsst.afw.math import BackgroundList, WarpingControl, warpImage -from lsst.daf.butler import DataCoordinate -from lsst.geom import ( - AffineTransform, - Box2I, - Extent2D, - Extent2I, - Point2D, - Point2I, - SpherePoint, - arcseconds, - floor, - radians, -) -from lsst.meas.algorithms import ( - LoadReferenceObjectsConfig, - ReferenceObjectLoader, -) -from lsst.pex.config import ChoiceField, ConfigField, Field, ListField -from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct -from lsst.pipe.base.connectionTypes import Input, Output, PrerequisiteInput -from lsst.utils.timer import timeMethod -from copy import deepcopy - -NEIGHBOR_MASK_PLANE = "NEIGHBOR" - -logger = logging.getLogger(__name__) - - -class BrightStarSubtractConnections( - PipelineTaskConnections, - dimensions=("instrument", "visit", "detector"), - defaultTemplates={ - # "outputExposureName": "brightStar_subtracted", - "outputExposureName": "postISRCCD", - "outputBackgroundName": "brightStars", - "badStampsName": "brightStars", - }, -): - inputCalexp = Input( - name="calexp", - storageClass="ExposureF", - doc="Background-subtracted input exposure from which to extract bright star stamp cutouts.", - dimensions=("visit", "detector"), - ) - inputBackground = Input( - name="calexpBackground", - storageClass="Background", - doc="Background model for the input exposure, to be added back on during processing.", - dimensions=("visit", "detector"), - ) - inputExposure = Input( - doc="Input exposure from which to subtract bright star stamps.", - name="postISRCCD", - storageClass="Exposure", - dimensions=( - "exposure", - "detector", - ), - ) - inputExtendedPsf = Input( - name="extendedPsf2", # extendedPsfDetector ??? - storageClass="ImageF", # MaskedImageF - doc="Extended PSF model, built from stacking bright star cutouts.", - dimensions=("band",), - ) - refCat = PrerequisiteInput( - doc="Reference catalog that contains bright star positions", - name="gaia_dr3_20230707", - storageClass="SimpleCatalog", - dimensions=("skypix",), - multiple=True, - deferLoad=True, - ) - # outputBadStamps = Output( - # doc="The stamps that are not normalized and consequently not subtracted from the exposure.", - # name="{badStampsName}_unsubtracted_stamps", - # storageClass="BrightStarStamps", - # dimensions=( - # "visit", - # "detector", - # ), - # ) - - outputExposure = Output( - doc="Exposure with bright stars subtracted.", - name="{outputExposureName}_subtracted", - storageClass="ExposureF", - dimensions=( - "exposure", - "detector", - ), - ) - outputBackgroundExposure = Output( - doc="Exposure containing only the modelled bright stars.", - name="{outputBackgroundName}_background", - storageClass="ExposureF", - dimensions=( - "visit", - "detector", - ), - ) - # scaledModels = Output( - # doc="Stamps containing models scaled to the level of stars", - # name="scaledModels", - # storageClass="BrightStarStamps", - # dimensions=( - # "visit", - # "detector", - # ), - # ) - - -class BrightStarSubtractConfig(PipelineTaskConfig, pipelineConnections=BrightStarSubtractConnections): - """Configuration parameters for BrightStarSubtractTask""" - - doWriteSubtractor = Field[bool]( - doc="Should an exposure containing all bright star models be written to disk?", - default=True, - ) - doWriteSubtractedExposure = Field[bool]( - doc="Should an exposure with bright stars subtracted be written to disk?", - default=True, - ) - magLimit = Field[float]( - doc="Magnitude limit, in Gaia G; all stars brighter than this value will be subtracted", - default=18, - ) - minValidAnnulusFraction = Field[float]( - doc="Minimum number of valid pixels that must fall within the annulus for the bright star to be " - "saved for subsequent generation of a PSF.", - default=0.0, - ) - numSigmaClip = Field[float]( - doc="Sigma for outlier rejection; ignored if annularFluxStatistic != 'MEANCLIP'.", - default=4, - ) - numIter = Field[int]( - doc="Number of iterations of outlier rejection; ignored if annularFluxStatistic != 'MEANCLIP'.", - default=3, - ) - warpingKernelName = ChoiceField[str]( - doc="Warping kernel", - default="lanczos5", - allowed={ - "bilinear": "bilinear interpolation", - "lanczos3": "Lanczos kernel of order 3", - "lanczos4": "Lanczos kernel of order 4", - "lanczos5": "Lanczos kernel of order 5", - "lanczos6": "Lanczos kernel of order 6", - "lanczos7": "Lanczos kernel of order 7", - }, - ) - scalingType = ChoiceField[str]( - doc="How the model should be scaled to each bright star; implemented options are " - "`annularFlux` to reuse the annular flux of each stamp, or `leastSquare` to perform " - "least square fitting on each pixel with no bad mask plane set.", - default="leastSquare", - allowed={ - "annularFlux": "reuse BrightStarStamp annular flux measurement", - "leastSquare": "find least square scaling factor", - }, - ) - annularFluxStatistic = ChoiceField[str]( - doc="Type of statistic to use to compute annular flux.", - default="MEANCLIP", - allowed={ - "MEAN": "mean", - "MEDIAN": "median", - "MEANCLIP": "clipped mean", - }, - ) - badMaskPlanes = ListField[str]( - doc="Mask planes that, if set, lead to associated pixels not being included in the computation of " - "the scaling factor (`BAD` should always be included). Ignored if scalingType is `annularFlux`, " - "as the stamps are expected to already be normalized.", - # Note that `BAD` should always be included, as secondary detected - # sources (i.e., detected sources other than the primary source of - # interest) also get set to `BAD`. - # Lee: find out the value of "BAD" and set the nan values into that number in the mask plane(?) - default=("BAD", "CR", "CROSSTALK", "EDGE", "NO_DATA", "SAT", "SUSPECT", "UNMASKEDNAN"), - ) - subtractionBox = ListField[int]( - doc="Size of the stamps to be extracted, in pixels.", - default=(250, 250), - ) - subtractionBoxBuffer = Field[float]( - doc=( - "'Buffer' (multiplicative) factor to be applied to determine the size of the stamp the " - "processed stars will be saved in. This is also the size of the extended PSF model. The buffer " - "region is masked and contain no data and subtractionBox determines the region where contains " - "the data." - ), - default=1.1, - ) - doApplySkyCorr = Field[bool]( - doc="Apply full focal plane sky correction before extracting stars?", - default=True, - ) - min_iterations = Field[int]( - doc="Minimum number of iterations to complete before evaluating changes in each iteration.", - default=3, - ) - refObjLoader = ConfigField[LoadReferenceObjectsConfig]( - doc="Reference object loader for astrometric calibration.", - ) - maskWarpingKernelName = ChoiceField[str]( - doc="Warping kernel for mask.", - default="bilinear", - allowed={ - "bilinear": "bilinear interpolation", - "lanczos3": "Lanczos kernel of order 3", - "lanczos4": "Lanczos kernel of order 4", - "lanczos5": "Lanczos kernel of order 5", - }, - ) - # Cutout geometry - stampSize = ListField[int]( - doc="Size of the stamps to be extracted, in pixels.", - default=(251, 251), - ) - stampSizePadding = Field[float]( - doc="Multiplicative factor applied to the cutout stamp size, to guard against post-warp data loss.", - default=1.1, - ) - # Star selection - magRange = ListField[float]( - doc="Magnitude range in Gaia G. Cutouts will be made for all stars in this range.", - default=[0, 18], - ) - minAreaFraction = Field[float]( - doc="Minimum fraction of the stamp area, post-masking, that must remain for a cutout to be retained.", - default=0.1, - ) - useMedianVariance = Field[bool]( - doc="Use the median of the variance plane for PSF fitting.", - default=False, - ) - psfMaskedFluxFracThreshold = Field[float]( - doc="Maximum allowed fraction of masked PSF flux for PSF fitting to occur.", - default=0.97, - ) - scalePsfModel = Field[bool]( - doc="If True, uses a scale factor to bring the PSF model data to the same level of the star data.", - default=True, - ) - - -class BrightStarSubtractTask(PipelineTask): - """Use an extended PSF model to subtract bright stars from a calibrated - exposure (i.e. at single-visit level). - - This task uses both a set of bright star stamps produced by - `~lsst.pipe.tasks.processBrightStars.ProcessBrightStarsTask` - and an extended PSF model produced by - `~lsst.pipe.tasks.extended_psf.MeasureExtendedPsfTask`. - """ - - ConfigClass = BrightStarSubtractConfig - _DefaultName = "subtractBrightStars" - - def __init__(self, *args, **kwargs): - # super().__init__(*args, **kwargs) - # # Placeholders to set up Statistics if scalingType is leastSquare. - # self.statsControl, self.statsFlag = None, None - # # Warping control; only contains shiftingALg provided in config. - - super().__init__(*args, **kwargs) - stampSize = Extent2D(*self.config.stampSize.list()) - stampRadius = floor(stampSize / 2) - # Define a central bounding box of the configured stamp size. - self.stampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy(stampRadius) - paddedStampSize = stampSize - self.paddedStampRadius = floor(paddedStampSize / 2) - self.paddedStampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy( - self.paddedStampRadius - ) - - def runQuantum(self, butlerQC, inputRefs, outputRefs): - # Docstring inherited. - inputs = butlerQC.get(inputRefs) - dataId = butlerQC.quantum.dataId - refObjLoader = ReferenceObjectLoader( - dataIds=[ref.datasetRef.dataId for ref in inputRefs.refCat], - refCats=inputs.pop("refCat"), - name=self.config.connections.refCat, - config=self.config.refObjLoader, - ) - # TODO: include the un-subtracted stars here! - subtractor = self.run(**inputs, dataId=dataId, refObjLoader=refObjLoader) - if self.config.doWriteSubtractedExposure: - outputExposure = inputs["inputExposure"].clone() - outputExposure.image -= subtractor.image - else: - outputExposure = None - outputBackgroundExposure = ExposureF(subtractor) if self.config.doWriteSubtractor else None - output = Struct( - outputExposure=outputExposure, - outputBackgroundExposure=outputBackgroundExposure, - ) - butlerQC.put(output, outputRefs) - - @timeMethod - def run( - self, - inputExposure: ExposureF, - inputCalexp: ExposureF, - inputBackground: BackgroundList, - inputExtendedPsf: ImageF, - dataId: dict[str, Any] | DataCoordinate, - refObjLoader: ReferenceObjectLoader, - ): - """Generate a bright star subtractor image using scaled extended PSF models. - - Identifies bright stars within the calibrated exposure using a - reference catalog, extracts stamps around each, warps the extended PSF - model onto the stamp frame, fits for a scale factor and pedestal for - each star iteratively, and combines the scaled models into a single - subtractor exposure. - - Parameters - ---------- - inputExposure : `~lsst.afw.image.ExposureF` - The Post-ISR CCD frame. Note: Currently appears unused directly - within this method's main logic, but required by the pipeline - definition. Cutouts are based on `inputCalexp` + `inputBackground`. - inputCalexp: `~lsst.afw.image.ExposureF` - The background-subtracted calibrated exposure used for identifying - stars, extracting stamps, and fitting models. - inputBackground : `~lsst.afw.math.BackgroundList` - The background model associated with `inputCalexp`. Added back - before processing stamps. - inputExtendedPsf : `~lsst.afw.image.ImageF` - The extended PSF model (e.g., from MeasureExtendedPsfTask). - refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader` - Loader used to query the reference catalog for bright stars within - the exposure footprint. - dataId : `dict` or `~lsst.daf.butler.DataCoordinate` - The data identifier for the input exposure. - - Returns - ------- - subtractor : `~lsst.afw.image.ExposureF` - An exposure containing the combined, scaled models of the bright - stars identified and processed. This image can be subtracted from - the original `inputExposure` (or `inputCalexp` + `inputBackground`). - The image plane contains the model flux, while the variance and - mask planes are typically empty or minimal unless specifically populated. - """ - wcs = inputCalexp.getWcs() - bbox = inputCalexp.getBBox() - warpingControl = WarpingControl(self.config.warpingKernelName, self.config.maskWarpingKernelName) - - refCatBright = self._getRefCatBright(refObjLoader, wcs, bbox) - refCatBright.sort("mag") - zipRaDec = zip(refCatBright["coord_ra"] * radians, refCatBright["coord_dec"] * radians) - spherePoints = [SpherePoint(ra, dec) for ra, dec in zipRaDec] - pixCoords = wcs.skyToPixel(spherePoints) - - # Create image with background added back - # Using calibrated exposure for finding the scale factor and creating subtrator. - # The generated subtractor will be subtracted from PostISRCCd. - inputFixed = inputCalexp.getMaskedImage() - inputFixed += inputBackground.getImage() - inputCalexp.mask.addMaskPlane(NEIGHBOR_MASK_PLANE) - # Associate detected footprints (from DETECTED plane) with the bright reference stars. - allFootprints, associations = self._associateFootprints(inputCalexp, pixCoords, plane="DETECTED") - - subtractorExp = ExposureF(bbox=bbox) - templateSubtractor = subtractorExp.maskedImage - - detector = inputCalexp.detector - pixelScale = wcs.getPixelScale().asArcseconds() * arcseconds - pixToTan = detector.getTransform(PIXELS, TAN_PIXELS) - pixToFocalPlaneTan = detector.getTransform(PIXELS, FIELD_ANGLE).then( - makeTransform(AffineTransform.makeScaling(1 / pixelScale.asRadians())) - ) - - self.warpedDataDict = {} - removalIndices = [] - for j in range(self.config.min_iterations): - scaleList = [] - for starIndex, (obj, pixCoord) in enumerate(zip(refCatBright, pixCoords)): # type: ignore - # Start with the background-added image for each star - inputMI = deepcopy(inputFixed) - restSubtractor = deepcopy(templateSubtractor) - myNumber = 0 - for key in self.warpedDataDict.keys(): - # Subtract the *current best models* of all *other* stars before fitting this one. - if self.warpedDataDict[key]["subtractor"] is not None and key != obj['id']: - restSubtractor.image += self.warpedDataDict[key]["subtractor"].image - myNumber += 1 - self.log.debug(f"Number of stars subtracted before finding the scale factor for {obj['id']}: ", myNumber) - inputMI.image -= restSubtractor.image - - footprintIndex = associations.get(starIndex, None) - - if footprintIndex: - neighborFootprints = [fp for i, fp in enumerate(allFootprints) if i != footprintIndex] - self._setFootprints(inputMI, neighborFootprints, NEIGHBOR_MASK_PLANE) - else: - self._setFootprints(inputMI, allFootprints, NEIGHBOR_MASK_PLANE) - # Define linear shifting to recenter stamps - coordFocalPlaneTan = pixToFocalPlaneTan.applyForward(pixCoord) # center of warped star - shift = makeTransform(AffineTransform(Point2D(0, 0) - coordFocalPlaneTan)) - angle = np.arctan2(coordFocalPlaneTan.getY(), coordFocalPlaneTan.getX()) * radians - rotation = makeTransform(AffineTransform.makeRotation(-angle)) - pixToPolar = pixToFocalPlaneTan.then(shift).then(rotation) - rawStamp= self._getCutout(inputExposure=inputMI, coordPix=pixCoord, stampSize=self.config.stampSize.list()) - if rawStamp is None: - self.log.debug(f"No stamp for star with refID {obj['id']}") - removalIndices.append(starIndex) - continue - warpedStamp = self._warpRawStamp(rawStamp, warpingControl, pixToTan, pixCoord) - warpedModel = ImageF(warpedStamp.getBBox()) - inputExtendedPsfGeneral = deepcopy(inputExtendedPsf) - good_pixels = warpImage(warpedModel, inputExtendedPsfGeneral, pixToPolar.inverted(), warpingControl) - self.warpedDataDict[obj["id"]] = {"stamp": warpedStamp, "model": warpedModel, "starIndex": starIndex, "pixCoord": pixCoord} - if j == 0: - self.warpedDataDict[obj["id"]]["scale"] = None - self.warpedDataDict[obj["id"]]["subtractor"] = None - fitPsfResults = {} - if self.config.scalePsfModel: - psfNeg = warpedModel.array < 0 - self.modelScale = np.nanmean(warpedStamp.image.array) / np.nanmean(warpedModel.array[~psfNeg]) - warpedModel.array *= self.modelScale ######## model scale correction ######## - fitPsfResults = self._fitPsf( warpedStamp, warpedModel) - if fitPsfResults: - scaleList.append(fitPsfResults["scale"]) - self.warpedDataDict[obj["id"]]["scale"] = fitPsfResults["scale"] - - - cond = np.isnan(warpedModel.array) - warpedModel.array[cond] = 0 - warpedModel.array *= fitPsfResults["scale"] - overlapBBox = Box2I(warpedStamp.getBBox()) - overlapBBox.clip(inputCalexp.getBBox()) - - subtractor = deepcopy(templateSubtractor) - subtractor[overlapBBox] += warpedModel[overlapBBox] - self.warpedDataDict[obj["id"]]["subtractor"] = subtractor - - - else: - scaleList.append(np.nan) - if "subtractor" not in self.warpedDataDict[obj["id"]].keys(): - self.warpedDataDict[obj["id"]]["subtractor"] = None - self.warpedDataDict[obj["id"]]["scale"] = None - if j == 0: - refCatBright.remove_rows(removalIndices) - updatedPixCoords = [item for i, item in enumerate(pixCoords) if i not in removalIndices] - pixCoords = updatedPixCoords - new_scale_column = Column(scaleList, name=f'scale_0{j}') - # The following is handy when developing, not sure if we want to do that in the final version! - refCatBright.add_columns([new_scale_column]) - - subtractor = deepcopy(templateSubtractor) - for key in self.warpedDataDict.keys(): - if self.warpedDataDict[key]["scale"] is not None: - subtractor.image.array += self.warpedDataDict[key]["subtractor"].image.array - return subtractor - - def _getRefCatBright(self, refObjLoader: ReferenceObjectLoader, wcs: SkyWcs, bbox: Box2I) -> Table: - """Get a bright star subset of the reference catalog. - - Trim the reference catalog to only those objects within the exposure - bounding box dilated by half the bright star stamp size. - This ensures all stars that overlap the exposure are included. - - Parameters - ---------- - refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader` - Loader to find objects within a reference catalog. - wcs : `~lsst.afw.geom.SkyWcs` - World coordinate system. - bbox : `~lsst.geom.Box2I` - Bounding box of the exposure. - - Returns - ------- - refCatBright : `~astropy.table.Table` - Bright star subset of the reference catalog. - """ - dilatedBBox = bbox.dilatedBy(self.paddedStampRadius) - withinExposure = refObjLoader.loadPixelBox(dilatedBBox, wcs, filterName="phot_g_mean") - refCatFull = withinExposure.refCat - fluxField: str = withinExposure.fluxField - - brightFluxRange = sorted(((self.config.magRange * u.ABmag).to(u.nJy)).to_value()) - - subsetStars = (refCatFull[fluxField] > brightFluxRange[0]) & ( - refCatFull[fluxField] < brightFluxRange[1] - ) - refCatSubset = Table(refCatFull.extract("id", "coord_ra", "coord_dec", fluxField, where=subsetStars)) - fluxNanojansky = refCatSubset[fluxField][:] * u.nJy # type: ignore - refCatSubset["mag"] = fluxNanojansky.to(u.ABmag).to_value() # AB magnitudes - return refCatSubset - - def _associateFootprints( - self, inputExposure: ExposureF, pixCoords: list[Point2D], plane: str - ) -> tuple[list[Footprint], dict[int, int]]: - """Associate footprints from a given mask plane with specific objects. - - Footprints from the given mask plane are associated with objects at the - coordinates provided, where possible. - - Parameters - ---------- - inputExposure : `~lsst.afw.image.ExposureF` - The input exposure with a mask plane. - pixCoords : `list` [`~lsst.geom.Point2D`] - The pixel coordinates of the objects. - plane : `str` - The mask plane used to identify masked pixels. - - Returns - ------- - footprints : `list` [`~lsst.afw.detection.Footprint`] - The footprints from the input exposure. - associations : `dict`[int, int] - Association indices between objects (key) and footprints (value). - """ - detThreshold = Threshold(inputExposure.mask.getPlaneBitMask(plane), Threshold.BITMASK) - footprintSet = FootprintSet(inputExposure.mask, detThreshold) - footprints = footprintSet.getFootprints() - associations = {} - for starIndex, pixCoord in enumerate(pixCoords): - for footprintIndex, footprint in enumerate(footprints): - if footprint.contains(Point2I(pixCoord)): - associations[starIndex] = footprintIndex - break - self.log.debug( - "Associated %i of %i star%s to one each of the %i %s footprint%s.", - len(associations), - len(pixCoords), - "" if len(pixCoords) == 1 else "s", - len(footprints), - plane, - "" if len(footprints) == 1 else "s", - ) - return footprints, associations - - def _setFootprints(self, inputExposure: ExposureF, footprints: list, maskPlane: str): - """Set footprints in a given mask plane. - - Parameters - ---------- - inputExposure : `~lsst.afw.image.ExposureF` - The input exposure to modify. - footprints : `list` [`~lsst.afw.detection.Footprint`] - The footprints to set in the mask plane. - maskPlane : `str` - The mask plane to set the footprints in. - - Notes - ----- - This method modifies the ``inputExposure`` object in-place. - """ - detThreshold = Threshold(inputExposure.mask.getPlaneBitMask(maskPlane), Threshold.BITMASK) - detThresholdValue = int(detThreshold.getValue()) - footprintSet = FootprintSet(inputExposure.mask, detThreshold) - - # Wipe any existing footprints in the mask plane - inputExposure.mask.clearMaskPlane(int(np.log2(detThresholdValue))) - - # Set the footprints in the mask plane - footprintSet.setFootprints(footprints) - footprintSet.setMask(inputExposure.mask, maskPlane) - - def _fitPsf(self, stampMI: MaskedImageF, psfImage: ImageD | ImageF) -> dict[str, Any]: - """Fit a scaled PSF and a pedestal to each bright star cutout. - - Parameters - ---------- - stampMI : `~lsst.afw.image.MaskedImageF` - The masked image of the bright star cutout. - psfImage : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` - The PSF model to fit. - - Returns - ------- - fitPsfResults : `dict`[`str`, `float`] - The result of the PSF fitting, with keys: - - ``scale`` : `float` - The scale factor. - ``scaleErr`` : `float` - The error on the scale factor. - ``pedestal`` : `float` - The pedestal value. - ``pedestalErr`` : `float` - The error on the pedestal value. - ``pedestalScaleCov`` : `float` - The covariance between the pedestal and scale factor. - ``xGradient`` : `float` - The gradient in the x-direction. - ``yGradient`` : `float` - The gradient in the y-direction. - ``globalReducedChiSquared`` : `float` - The global reduced chi-squared goodness-of-fit. - ``globalDegreesOfFreedom`` : `int` - The global number of degrees of freedom. - ``psfReducedChiSquared`` : `float` - The PSF BBox reduced chi-squared goodness-of-fit. - ``psfDegreesOfFreedom`` : `int` - The PSF BBox number of degrees of freedom. - ``psfMaskedFluxFrac`` : `float` - The fraction of the PSF image flux masked by bad pixels. - """ - badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes) - - # Calculate the fraction of the PSF image flux masked by bad pixels - psfMaskedPixels = ImageF(psfImage.getBBox()) - psfMaskedPixels.array[:, :] = (stampMI.mask[psfImage.getBBox()].array & badMaskBitMask).astype(bool) - # TODO: This is np.float64, else FITS metadata serialization fails - # Amir: what do we want to do for subtraction? we do not have the luxury of removing the star from the process here! - psfMaskedFluxFrac = np.dot(psfImage.array.flat, psfMaskedPixels.array.flat).astype(np.float64) - # if psfMaskedFluxFrac > self.config.psfMaskedFluxFracThreshold: - # return {} # Handle cases where the PSF image is mostly masked - - # Create a padded version of the input constant PSF image - paddedPsfImage = ImageF(stampMI.getBBox()) - paddedPsfImage[psfImage.getBBox()] = psfImage.convertF() - - # Create consistently masked data - mask = self.add_psf_mask(psfImage, stampMI) - badSpans = SpanSet.fromMask(mask, badMaskBitMask) - goodSpans = SpanSet(stampMI.getBBox()).intersectNot(badSpans) - varianceData = goodSpans.flatten(stampMI.variance.array, stampMI.getXY0()) - if self.config.useMedianVariance: - varianceData = np.median(varianceData) - sigmaData = np.sqrt(varianceData) - imageData = goodSpans.flatten(stampMI.image.array, stampMI.getXY0()) # B - imageData /= sigmaData - psfData = goodSpans.flatten(paddedPsfImage.array, paddedPsfImage.getXY0()) - psfData /= sigmaData - - # Fit the PSF scale factor and global pedestal - nData = len(imageData) - coefficientMatrix = np.ones((nData, 4), dtype=float) # A - coefficientMatrix[:, 0] = psfData - coefficientMatrix[:, 1] /= sigmaData - coefficientMatrix[:, 2:] = goodSpans.indices().T - coefficientMatrix[:, 2] /= sigmaData - coefficientMatrix[:, 3] /= sigmaData - try: - solutions, sumSquaredResiduals, *_ = np.linalg.lstsq(coefficientMatrix, imageData, rcond=None) - covarianceMatrix = np.linalg.inv(np.dot(coefficientMatrix.transpose(), coefficientMatrix)) # C - except np.linalg.LinAlgError: - return {} # Handle singular matrix errors - if sumSquaredResiduals.size == 0: - return {} - scale = solutions[0] - if scale <= 0: - return {} # Handle cases where the PSF scale fit has failed - scaleErr = np.sqrt(covarianceMatrix[0, 0]) - pedestal = solutions[1] - pedestalErr = np.sqrt(covarianceMatrix[1, 1]) - scalePedestalCov = covarianceMatrix[0, 1] - xGradient = solutions[3] - yGradient = solutions[2] - - # Calculate global (whole image) reduced chi-squared - globalChiSquared = np.sum(sumSquaredResiduals) - globalDegreesOfFreedom = nData - 4 - globalReducedChiSquared = globalChiSquared / globalDegreesOfFreedom - - # Calculate PSF BBox reduced chi-squared - psfBBoxGoodSpans = goodSpans.clippedTo(psfImage.getBBox()) - psfBBoxGoodSpansX, psfBBoxGoodSpansY = psfBBoxGoodSpans.indices() - psfBBoxData = psfBBoxGoodSpans.flatten(stampMI.image.array, stampMI.getXY0()) - psfBBoxModel = ( - psfBBoxGoodSpans.flatten(paddedPsfImage.array, stampMI.getXY0()) * scale - + pedestal - + psfBBoxGoodSpansX * xGradient - + psfBBoxGoodSpansY * yGradient - ) - psfBBoxVariance = psfBBoxGoodSpans.flatten(stampMI.variance.array, stampMI.getXY0()) - psfBBoxResiduals = (psfBBoxData - psfBBoxModel) ** 2 / psfBBoxVariance - psfBBoxChiSquared = np.sum(psfBBoxResiduals) - psfBBoxDegreesOfFreedom = len(psfBBoxData) - 4 - psfBBoxReducedChiSquared = psfBBoxChiSquared / psfBBoxDegreesOfFreedom - - return dict( - scale=scale, - scaleErr=scaleErr, - pedestal=pedestal, - pedestalErr=pedestalErr, - xGradient=xGradient, - yGradient=yGradient, - pedestalScaleCov=scalePedestalCov, - globalReducedChiSquared=globalReducedChiSquared, - globalDegreesOfFreedom=globalDegreesOfFreedom, - psfReducedChiSquared=psfBBoxReducedChiSquared, - psfDegreesOfFreedom=psfBBoxDegreesOfFreedom, - psfMaskedFluxFrac=psfMaskedFluxFrac, - ) - - def add_psf_mask(self, psfImage, stampMI): - """Add problematic PSF pixels to the stamp's mask. - - Identifies pixels in the PSF model image that are NaN or negative - and sets the corresponding bits (hardcoded as plane 0, likely 'BAD') - in a copy of the input stamp's mask. - - Parameters - ---------- - psfImage : `~lsst.afw.image.ImageF` - PSF model image defined on the stamp grid. - stampMI : `~lsst.afw.image.MaskedImageF` - The masked image stamp of the star being fitted. Its mask is used - as the base. - - Returns - ------- - mask : `~lsst.afw.image.Mask` - A mask object based on the input stamp's mask, updated to include - masked pixels derived from the PSF model image. - """ - cond = np.isnan(psfImage.array) - cond |= psfImage.array < 0 - mask = deepcopy(stampMI.mask) - mask.array[cond] = np.bitwise_or(mask.array[cond], 1) - return mask - - - def _getCutout(self, inputExposure, coordPix: Point2D, stampSize: list[int]): - """Get a cutout from an input exposure, handling edge cases. - - Generate a cutout from an input exposure centered on a given position - and with a given size. - If any part of the cutout is outside the input exposure bounding box, - the cutout is padded with NaNs. - - Parameters - ---------- - inputExposure : `~lsst.afw.image.ExposureF` - The image to extract bright star stamps from. - coordPix : `~lsst.geom.Point2D` - Center of the cutout in pixel space. - stampSize : `list` [`int`] - Size of the cutout, in pixels. - - Returns - ------- - stamp : `~lsst.afw.image.ExposureF` or `None` - The cutout, or `None` if the cutout is entirely outside the input - exposure bounding box. - - Notes - ----- - This method is a short-term workaround until DM-40042 is implemented. - At that point, it should be replaced by a call to the Exposure method - ``getCutout``, which will handle edge cases automatically. - """ - corner = Point2I(np.array(coordPix) - np.array(stampSize) / 2) - dimensions = Extent2I(stampSize) - stampBBox = Box2I(corner, dimensions) - overlapBBox = Box2I(stampBBox) - overlapBBox.clip(inputExposure.getBBox()) - if overlapBBox.getArea() > 0: - # Create full-sized stamp with pixels initially flagged as NO_DATA. - stamp = ExposureF(bbox=stampBBox) - stamp.image[:] = np.nan - stamp.mask.set(inputExposure.mask.getPlaneBitMask("NO_DATA")) - # # Restore pixels which overlap the input exposure. - overlap = inputExposure.Factory(inputExposure, overlapBBox) - stamp.maskedImage[overlapBBox] = overlap - else: - stamp = None - return stamp - - def _warpRawStamp(self, rawStamp, warpingControl, pixToTan, pixCoord): - """Warps a raw image stamp onto a common tangent plane projection. - - Applies a transformation (`pixToTan`) followed by a shift - transform to warp the input `rawStamp` onto a destination - `MaskedImageF` aligned with a tangent plane centered near the object. - The shift aims to place the object center at the center of the - destination image. - - Parameters - ---------- - rawStamp : `~lsst.afw.image.ExposureF` - The raw cutout image stamp (e.g., from `_getCutout`). - warpingControl : `~lsst.afw.math.WarpingControl` - Configuration for the warping process. - pixToTan : `~lsst.afw.geom.Transform` - Transformation from the raw stamp's pixel coordinates to the - common tangent plane coordinates. - pixCoord : `~lsst.geom.Point2D` - Pixel coordinates of the object center in the original exposure, - used to calculate the centering shift. - - Returns - ------- - warped_stamp : `~lsst.afw.image.MaskedImageF` or `None` - The warped and shifted masked image, or None if warping failed - (e.g., due to insufficient good pixels). - """ - destImage = MaskedImageF(*self.config.stampSize) - bottomLeft = Point2D(rawStamp.getXY0()) - newBottomLeft = pixToTan.applyForward(bottomLeft) - newBottomLeft = Point2I(newBottomLeft) - destImage.setXY0(newBottomLeft) - # Define linear shifting to recenter stamps - newCenter = pixToTan.applyForward(pixCoord) - self.modelCenter = self.config.stampSize[0] // 2, self.config.stampSize[1] // 2 - shift = (self.modelCenter[0] + newBottomLeft[0] - newCenter[0], self.modelCenter[1] + newBottomLeft[1] - newCenter[1]) - affineShift = AffineTransform(shift) - shiftTransform = makeTransform(affineShift) - - # Define full transform (warp and shift) - starWarper = pixToTan.then(shiftTransform) - - # Apply it - goodPix = warpImage(destImage, rawStamp.getMaskedImage(), starWarper, warpingControl) - if not goodPix: - return None - return destImage - - # # Arbitrarily set origin of shifted star to 0 - # destImage.setXY0(0, 0) \ No newline at end of file diff --git a/tests/test_brightStarCutout.py b/tests/test_brightStarCutout.py index 67b88d02f..88da41440 100644 --- a/tests/test_brightStarCutout.py +++ b/tests/test_brightStarCutout.py @@ -51,7 +51,8 @@ def setUp(self): psfArray = np.exp(-dist_from_center / 5) psfArray /= np.sum(psfArray) fixedKernel = FixedKernel(ImageD(psfArray)) - self.psf = KernelPsf(fixedKernel) + psf = KernelPsf(fixedKernel) + self.psf = psf.computeKernelImage(psf.getAveragePosition()) # Bring everything together to construct a stamp masked image stampArray = psfArray * self.scale + pedestal + xPlane * self.xGradient + yPlane * self.yGradient @@ -83,10 +84,10 @@ def test_fitPsf(self): self.stampMI, self.psf, ) - self.assertAlmostEqual(fitPsfResults["scale"], self.scale, delta=1e-3) - self.assertAlmostEqual(fitPsfResults["pedestal"], self.pedestal, delta=1e-5) - self.assertAlmostEqual(fitPsfResults["xGradient"], self.xGradient, delta=1e-7) - self.assertAlmostEqual(fitPsfResults["yGradient"], self.yGradient, delta=1e-7) + assert abs(fitPsfResults["scale"] - self.scale) / self.scale < 1e-6 + assert abs(fitPsfResults["pedestal"] - self.pedestal) / self.pedestal < 1e-6 + assert abs(fitPsfResults["xGradient"] - self.xGradient) / self.xGradient < 1e-6 + assert abs(fitPsfResults["yGradient"] - self.yGradient) / self.yGradient < 1e-6 def setup_module(module): From 7f7391458977ed1b138636084f522115ea70578b Mon Sep 17 00:00:00 2001 From: Amir Ebadati Bazkiaei Date: Thu, 4 Dec 2025 10:18:41 +0000 Subject: [PATCH 5/5] cleaning the PR --- .../brightStarSubtraction/brightStarCutout.py | 646 ++++++++++-------- .../brightStarSubtraction/brightStarStack.py | 214 +++--- tests/test_brightStarCutout.py | 50 +- 3 files changed, 524 insertions(+), 386 deletions(-) diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py index 9afda1f6f..f9ba03cf0 100644 --- a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py @@ -110,7 +110,7 @@ class BrightStarCutoutConnections( def __init__(self, *, config: "BrightStarCutoutConfig | None" = None): super().__init__(config=config) assert config is not None - if not config.useExtendedPsf: + if not config.use_extended_psf: self.inputs.remove("extendedPsf") @@ -121,24 +121,26 @@ class BrightStarCutoutConfig( """Configuration parameters for BrightStarCutoutTask.""" # Star selection - magRange = ListField[float]( + mag_range = ListField[float]( doc="Magnitude range in Gaia G. Cutouts will be made for all stars in this range.", default=[0, 18], ) - excludeArcsecRadius = Field[float]( - doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeArcsecRadius`` are not used.", + exclude_arcsec_radius = Field[float]( + doc="Stars with a star in the range ``exclude_mag_range`` mag in ``exclude_arcsec_radius`` are not " + "used.", default=5, ) - excludeMagRange = ListField[float]( - doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeArcsecRadius`` are not used.", + exclude_mag_range = ListField[float]( + doc="Stars with a star in the range ``exclude_mag_range`` mag in ``exclude_arcsec_radius`` are not " + "used.", default=[0, 20], ) - minAreaFraction = Field[float]( + min_area_fraction = Field[float]( doc="Minimum fraction of the stamp area, post-masking, that must remain for a cutout to be retained.", default=0.1, ) - badMaskPlanes = ListField[str]( - doc="Mask planes that identify excluded pixels for the calculation of ``minAreaFraction`` and, " + bad_mask_planes = ListField[str]( + doc="Mask planes that identify excluded pixels for the calculation of ``min_area_fraction`` and, " "optionally, fitting of the PSF.", default=[ "BAD", @@ -152,15 +154,15 @@ class BrightStarCutoutConfig( NEIGHBOR_MASK_PLANE, ], ) - stampSize = ListField[int]( + stamp_size = ListField[int]( doc="Size of the stamps to be extracted, in pixels.", default=(251, 251), ) - stampSizePadding = Field[float]( + stamp_size_padding = Field[float]( doc="Multiplicative factor applied to the cutout stamp size, to guard against post-warp data loss.", default=1.1, ) - warpingKernelName = ChoiceField[str]( + warping_kernel_name = ChoiceField[str]( doc="Warping kernel.", default="lanczos5", allowed={ @@ -170,7 +172,7 @@ class BrightStarCutoutConfig( "lanczos5": "Lanczos kernel of order 5", }, ) - maskWarpingKernelName = ChoiceField[str]( + mask_warping_kernel_name = ChoiceField[str]( doc="Warping kernel for mask.", default="bilinear", allowed={ @@ -180,39 +182,36 @@ class BrightStarCutoutConfig( "lanczos5": "Lanczos kernel of order 5", }, ) - scalePsfModel = Field[bool]( - doc="If True, uses a scale factor to bring the PSF model data to the same level as the star data.", - default=True, + off_frame_mag_lim = Field[float]( + doc="Stars fainter than this limit are only included if they appear within the frame boundaries.", + default=15.0, ) # PSF Fitting - useExtendedPsf = Field[bool]( + use_extended_psf = Field[bool]( doc="Use the extended PSF model to normalize bright star cutouts.", default=False, ) - doFitPsf = Field[bool]( + do_fit_psf = Field[bool]( doc="Fit a scaled PSF and a pedestal to each bright star cutout.", default=True, ) - useMedianVariance = Field[bool]( + use_median_variance = Field[bool]( doc="Use the median of the variance plane for PSF fitting.", default=False, ) - psfMaskedFluxFracThreshold = Field[float]( + psf_masked_flux_frac_threshold = Field[float]( doc="Maximum allowed fraction of masked PSF flux for PSF fitting to occur.", default=0.97, ) - fitIterations = Field[int]( + fit_iterations = Field[int]( doc="Number of iterations over pedestal-gradient and scaling fit.", default=5, ) - offFrameMagLim = Field[float]( - doc="Stars fainter than this limit are only included if they appear within the frame boundaries.", - default=15.0, - ) # Misc - loadReferenceObjectsConfig = ConfigField[LoadReferenceObjectsConfig]( + + load_reference_objects_config = ConfigField[LoadReferenceObjectsConfig]( doc="Reference object loader for astrometric calibration.", ) @@ -236,17 +235,17 @@ class BrightStarCutoutTask(PipelineTask): _DefaultName = "brightStarCutout" config: BrightStarCutoutConfig - def __init__(self, initInputs=None, *args, **kwargs): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - stampSize = Extent2D(*self.config.stampSize.list()) - stampRadius = floor(stampSize / 2) - self.stampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy(stampRadius) - paddedStampSize = stampSize * self.config.stampSizePadding - self.paddedStampRadius = floor(paddedStampSize / 2) - self.paddedStampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy( - self.paddedStampRadius + stamp_size = Extent2D(*self.config.stamp_size.list()) + stamp_radius = floor(stamp_size / 2) + self.stamp_bbox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy(stamp_radius) + padded_stamp_size = stamp_size * self.config.stamp_size_padding + self.padded_stamp_radius = floor(padded_stamp_size / 2) + self.padded_stamp_bbox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy( + self.padded_stamp_radius ) - self.modelScale = 1 + self.model_scale = 1 def runQuantum(self, butlerQC, inputRefs, outputRefs): inputs = butlerQC.get(inputRefs) @@ -255,7 +254,7 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs): dataIds=[ref.datasetRef.dataId for ref in inputRefs.refCat], refCats=inputs.pop("refCat"), name=self.config.connections.refCat, - config=self.config.loadReferenceObjectsConfig, + config=self.config.load_reference_objects_config, ) extendedPsf = inputs.pop("extendedPsf", None) output = self.run(**inputs, extendedPsf=extendedPsf, refObjLoader=refObjLoader) @@ -282,6 +281,8 @@ def run( The background-subtracted image to extract bright star stamps. inputBackground : `~lsst.afw.math.BackgroundList` The background model associated with the input exposure. + extendedPsf: `~lsst.afw.image.ImageF` + The extended PSF model from previous iteration(s). refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional Loader to find objects within a reference catalog. dataId : `dict` or `~lsst.daf.butler.DataCoordinate` @@ -298,12 +299,14 @@ def run( """ wcs = inputExposure.getWcs() bbox = inputExposure.getBBox() - warpingControl = WarpingControl(self.config.warpingKernelName, self.config.maskWarpingKernelName) + warping_control = WarpingControl( + self.config.warping_kernel_name, self.config.mask_warping_kernel_name + ) - refCatBright = self._getRefCatBright(refObjLoader, wcs, bbox) - zipRaDec = zip(refCatBright["coord_ra"] * radians, refCatBright["coord_dec"] * radians) - spherePoints = [SpherePoint(ra, dec) for ra, dec in zipRaDec] - pixCoords = wcs.skyToPixel(spherePoints) + ref_cat_bright = self._get_ref_cat_bright(refObjLoader, wcs, bbox) + zip_ra_dec = zip(ref_cat_bright["coord_ra"] * radians, ref_cat_bright["coord_dec"] * radians) + sphere_points = [SpherePoint(ra, dec) for ra, dec in zip_ra_dec] + pix_coords = wcs.skyToPixel(sphere_points) # Restore original subtracted background inputMI = inputExposure.getMaskedImage() @@ -311,111 +314,106 @@ def run( # Set up NEIGHBOR mask plane; associate footprints with stars inputExposure.mask.addMaskPlane(NEIGHBOR_MASK_PLANE) - allFootprints, associations = self._associateFootprints(inputExposure, pixCoords, plane="DETECTED") + all_footprints, associations = self._associate_footprints(inputExposure, pix_coords, plane="DETECTED") # TODO: If we eventually have better PhotoCalibs (eg FGCM), apply here inputMI = inputExposure.getPhotoCalib().calibrateImage(inputMI, False) # Set up transform detector = inputExposure.detector - pixelScale = wcs.getPixelScale().asArcseconds() * arcseconds - pixToFocalPlaneTan = detector.getTransform(PIXELS, FIELD_ANGLE).then( - makeTransform(AffineTransform.makeScaling(1 / pixelScale.asRadians())) + pixel_scale = wcs.getPixelScale().asArcseconds() * arcseconds + pix_to_focal_plane_tan = detector.getTransform(PIXELS, FIELD_ANGLE).then( + makeTransform(AffineTransform.makeScaling(1 / pixel_scale.asRadians())) ) # Loop over each bright star - stamps, goodFracs, stamps_fitPsfResults = [], [], [] - for starIndex, (obj, pixCoord) in enumerate(zip(refCatBright, pixCoords)): # type: ignore + stamps, good_fracs, stamps_fit_psf_results = [], [], [] + for star_index, (obj, pix_coord) in enumerate(zip(ref_cat_bright, pix_coords)): # type: ignore # Excluding faint stars that are not within the frame. - if obj["mag"] > self.config.offFrameMagLim and not self.star_in_frame(pixCoord, bbox): + if obj["mag"] > self.config.off_frame_mag_lim and not self.star_in_frame(pix_coord, bbox): continue - footprintIndex = associations.get(starIndex, None) - stampMI = MaskedImageF(self.paddedStampBBox) + footprint_index = associations.get(star_index, None) + stampMI = MaskedImageF(self.padded_stamp_bbox) # Set NEIGHBOR footprints in the mask plane - if footprintIndex: - neighborFootprints = [fp for i, fp in enumerate(allFootprints) if i != footprintIndex] - self._setFootprints(inputMI, neighborFootprints, NEIGHBOR_MASK_PLANE) + if footprint_index: + neighbor_footprints = [fp for i, fp in enumerate(all_footprints) if i != footprint_index] + self._set_footprints(inputMI, neighbor_footprints, NEIGHBOR_MASK_PLANE) else: - self._setFootprints(inputMI, allFootprints, NEIGHBOR_MASK_PLANE) + self._set_footprints(inputMI, all_footprints, NEIGHBOR_MASK_PLANE) # Define linear shifting to recenter stamps - coordFocalPlaneTan = pixToFocalPlaneTan.applyForward(pixCoord) # center of warped star - shift = makeTransform(AffineTransform(Point2D(0, 0) - coordFocalPlaneTan)) - angle = np.arctan2(coordFocalPlaneTan.getY(), coordFocalPlaneTan.getX()) * radians + coord_focal_plane_tan = pix_to_focal_plane_tan.applyForward(pix_coord) # center of warped star + shift = makeTransform(AffineTransform(Point2D(0, 0) - coord_focal_plane_tan)) + angle = np.arctan2(coord_focal_plane_tan.getY(), coord_focal_plane_tan.getX()) * radians rotation = makeTransform(AffineTransform.makeRotation(-angle)) - pixToPolar = pixToFocalPlaneTan.then(shift).then(rotation) + pix_to_polar = pix_to_focal_plane_tan.then(shift).then(rotation) # Apply the warp to the star stamp (in-place) - warpImage(stampMI, inputMI, pixToPolar, warpingControl) + warpImage(stampMI, inputMI, pix_to_polar, warping_control) # Trim to the base stamp size, check mask coverage, update metadata - stampMI = stampMI[self.stampBBox] - badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes) - goodFrac = np.sum(stampMI.mask.array & badMaskBitMask == 0) / stampMI.mask.array.size - goodFracs.append(goodFrac) - if goodFrac < self.config.minAreaFraction: + stampMI = stampMI[self.stamp_bbox] + bad_mask_bit_mask = stampMI.mask.getPlaneBitMask(self.config.bad_mask_planes) + good_frac = np.sum(stampMI.mask.array & bad_mask_bit_mask == 0) / stampMI.mask.array.size + good_fracs.append(good_frac) + if good_frac < self.config.min_area_fraction: continue # Fit a scaled PSF and a pedestal to each bright star cutout - psf = WarpedPsf(inputExposure.getPsf(), pixToPolar, warpingControl) - constantPsf = KernelPsf(FixedKernel(psf.computeKernelImage(Point2D(0, 0)))) - if self.config.useExtendedPsf: - psfImage = deepcopy(extendedPsf) # Assumed to be warped, center at [0,0] + psf = WarpedPsf(inputExposure.getPsf(), pix_to_polar, warping_control) + constant_psf = KernelPsf(FixedKernel(psf.computeKernelImage(Point2D(0, 0)))) + if self.config.use_extended_psf: + psf_image = deepcopy(extendedPsf) # Assumed to be warped, center at [0,0] else: - psfImage = constantPsf.computeKernelImage(constantPsf.getAveragePosition()) + psf_image = constant_psf.computeKernelImage(constant_psf.getAveragePosition()) # TODO: maybe we want to generate a smaller psf in case the following happens? - # The following could happen for when the user chooses small stampSize ~(50, 50) + # The following could happen for when the user chooses small stamp_size ~(50, 50) if ( - psfImage.array.shape[0] > stampMI.image.array.shape[0] - or psfImage.array.shape[1] > stampMI.image.array.shape[1] + psf_image.array.shape[0] > stampMI.image.array.shape[0] + or psf_image.array.shape[1] > stampMI.image.array.shape[1] ): continue # Computing an scale factor that brings the model to the similar level of the star. - self.computeModelScale(stampMI, psfImage) - psfImage.array *= self.modelScale # ####### model scale correction ######## + self.estimate_model_scale_value(stampMI, psf_image) + psf_image.array *= self.model_scale # ####### model scale correction ######## - fitPsfResults = {} + fit_psf_results = {} - if self.config.doFitPsf: - fitPsfResults = self._fitPsf(stampMI, psfImage) - stamps_fitPsfResults.append(fitPsfResults) + if self.config.do_fit_psf: + fit_psf_results = self._fit_psf(stampMI, psf_image) + stamps_fit_psf_results.append(fit_psf_results) # Save the stamp if the PSF fit was successful or no fit requested - if fitPsfResults or not self.config.doFitPsf: - distance_mm, theta_angle = self.star_location_on_focal(pixCoord, detector) + if fit_psf_results or not self.config.do_fit_psf: + distance_mm, theta_angle = self.star_location_on_focal(pix_coord, detector) stamp = BrightStarStamp( stamp_im=stampMI, - psf=constantPsf, - wcs=makeModifiedWcs(pixToPolar, wcs, False), + psf=constant_psf, + wcs=makeModifiedWcs(pix_to_polar, wcs, False), visit=cast(int, dataId["visit"]), detector=cast(int, dataId["detector"]), ref_id=obj["id"], ref_mag=obj["mag"], - position=pixCoord, + position=pix_coord, focal_plane_radius=distance_mm, focal_plane_angle=theta_angle, # TODO: add the lsst.geom.Angle here - scale=fitPsfResults.get("scale", None), - scale_err=fitPsfResults.get("scaleErr", None), - pedestal=fitPsfResults.get("pedestal", None), - pedestal_err=fitPsfResults.get("pedestalErr", None), - pedestal_scale_cov=fitPsfResults.get("pedestalScaleCov", None), - gradient_x=fitPsfResults.get("xGradient", None), - gradient_y=fitPsfResults.get("yGradient", None), - global_reduced_chi_squared=fitPsfResults.get("globalReducedChiSquared", None), - global_degrees_of_freedom=fitPsfResults.get("globalDegreesOfFreedom", None), - psf_reduced_chi_squared=fitPsfResults.get("psfReducedChiSquared", None), - psf_degrees_of_freedom=fitPsfResults.get("psfDegreesOfFreedom", None), - psf_masked_flux_fraction=fitPsfResults.get("psfMaskedFluxFrac", None), - ) - print( - obj["mag"], - fitPsfResults.get("globalReducedChiSquared", None), - fitPsfResults.get("globalDegreesOfFreedom", None), - fitPsfResults.get("psfReducedChiSquared", None), - fitPsfResults.get("psfDegreesOfFreedom", None), - fitPsfResults.get("psfMaskedFluxFrac", None), + scale=fit_psf_results.get("scale", None), + scale_err=fit_psf_results.get("scale_err", None), + pedestal=fit_psf_results.get("pedestal", None), + pedestal_err=fit_psf_results.get("pedestal_err", None), + pedestal_scale_cov=fit_psf_results.get("pedestal_scale_cov", None), + gradient_x=fit_psf_results.get("x_gradient", None), + gradient_y=fit_psf_results.get("y_gradient", None), + curvature_x=fit_psf_results.get("curvature_x", None), + curvature_y=fit_psf_results.get("curvature_y", None), + cross_tilt=fit_psf_results.get("cross_tilt", None), + global_reduced_chi_squared=fit_psf_results.get("global_reduced_chi_squared", None), + global_degrees_of_freedom=fit_psf_results.get("global_degrees_of_freedom", None), + psf_reduced_chi_squared=fit_psf_results.get("psf_reduced_chi_squared", None), + psf_degrees_of_freedom=fit_psf_results.get("psf_degrees_of_freedom", None), + psf_masked_flux_fraction=fit_psf_results.get("psf_masked_flux_frac", None), ) stamps.append(stamp) @@ -424,38 +422,69 @@ def run( "Excluded %i star%s: insufficient area (%i), PSF fit failure (%i).", len(stamps), "" if len(stamps) == 1 else "s", - len(refCatBright) - len(stamps), - "" if len(refCatBright) - len(stamps) == 1 else "s", - np.sum(np.array(goodFracs) < self.config.minAreaFraction), + len(ref_cat_bright) - len(stamps), + "" if len(ref_cat_bright) - len(stamps) == 1 else "s", + np.sum(np.array(good_fracs) < self.config.min_area_fraction), ( - np.sum(np.isnan([x.get("pedestal", np.nan) for x in stamps_fitPsfResults])) - if self.config.doFitPsf + np.sum(np.isnan([x.get("pedestal", np.nan) for x in stamps_fit_psf_results])) + if self.config.do_fit_psf else 0 ), ) brightStarStamps = BrightStarStamps(stamps) return Struct(brightStarStamps=brightStarStamps) - def star_location_on_focal(self, pixCoord, detector): - star_focal_plane_coords = detector.transform(pixCoord, PIXELS, FOCAL_PLANE) + def star_location_on_focal(self, pix_coord, detector): + """ + Calculates the radial coordinates of a star on the focal plane. + + Transforms the given pixel coordinates to the focal plane and computes + the radial distance and angle relative to the optical axis. + + Args: + pix_coord: `~lsst.geom.Point2D` or tuple + The (x, y) coordinates of the star on the + detector in pixels. + detector (Detector): `~lsst.afw.cameraGeom.Detector` + The detector object capable of transforming coordinates + from PIXELS to FOCAL_PLANE. + + Returns: + tuple: A tuple containing: + - distance_mm (float): The radial distance from the center in millimeters. + - theta_angle (Angle): The azimuthal angle of the star on the focal plane. + """ + star_focal_plane_coords = detector.transform(pix_coord, PIXELS, FOCAL_PLANE) star_x_fp = star_focal_plane_coords.getX() star_y_fp = star_focal_plane_coords.getY() - distance_mm = np.sqrt(star_x_fp**2 + star_y_fp**2) + distance_mm = np.sqrt(star_x_fp ** 2 + star_y_fp ** 2) theta_rad = math.atan2(star_y_fp, star_x_fp) theta_angle = Angle(theta_rad, radians) return distance_mm, theta_angle - def star_in_frame(self, pixCoord, inputExposureBBox): + def star_in_frame(self, pix_coord, inputExposureBBox): + """ + Checks if a star's pixel coordinates lie within the exposure boundaries. + + Args: + pix_coord: `~lsst.geom.Point2D` or tuple + The (x, y) pixel coordinates of the star. + inputExposureBBox : `~lsst.geom.Box2I` + Bounding box of the exposure. + + Returns: + bool: True if the coordinates are within the frame limits, False otherwise. + """ if ( - pixCoord[0] < 0 - or pixCoord[1] < 0 - or pixCoord[0] > inputExposureBBox.getDimensions()[0] - or pixCoord[1] > inputExposureBBox.getDimensions()[1] + pix_coord[0] < 0 + or pix_coord[1] < 0 + or pix_coord[0] > inputExposureBBox.getDimensions()[0] + or pix_coord[1] > inputExposureBBox.getDimensions()[1] ): return False return True - def _getRefCatBright(self, refObjLoader: ReferenceObjectLoader, wcs: SkyWcs, bbox: Box2I) -> Table: + def _get_ref_cat_bright(self, refObjLoader: ReferenceObjectLoader, wcs: SkyWcs, bbox: Box2I) -> Table: """Get a bright star subset of the reference catalog. Trim the reference catalog to only those objects within the exposure @@ -473,57 +502,59 @@ def _getRefCatBright(self, refObjLoader: ReferenceObjectLoader, wcs: SkyWcs, bbo Returns ------- - refCatBright : `~astropy.table.Table` + ref_cat_bright : `~astropy.table.Table` Bright star subset of the reference catalog. """ - dilatedBBox = bbox.dilatedBy(self.paddedStampRadius) - withinExposure = refObjLoader.loadPixelBox(dilatedBBox, wcs, filterName="phot_g_mean") - refCatFull = withinExposure.refCat - fluxField: str = withinExposure.fluxField + dilated_bbox = bbox.dilatedBy(self.padded_stamp_radius) + within_exposure = refObjLoader.loadPixelBox(dilated_bbox, wcs, filterName="phot_g_mean") + ref_cat_full = within_exposure.refCat + flux_field: str = within_exposure.fluxField - proxFluxRange = sorted(((self.config.excludeMagRange * u.ABmag).to(u.nJy)).to_value()) - brightFluxRange = sorted(((self.config.magRange * u.ABmag).to(u.nJy)).to_value()) + prox_flux_range = sorted(((self.config.exclude_mag_range * u.ABmag).to(u.nJy)).to_value()) + bright_flux_range = sorted(((self.config.mag_range * u.ABmag).to(u.nJy)).to_value()) - subsetStars = (refCatFull[fluxField] > np.min((proxFluxRange[0], brightFluxRange[0]))) & ( - refCatFull[fluxField] < np.max((proxFluxRange[1], brightFluxRange[1])) + subset_stars = (ref_cat_full[flux_field] > np.min((prox_flux_range[0], bright_flux_range[0]))) & ( + ref_cat_full[flux_field] < np.max((prox_flux_range[1], bright_flux_range[1])) + ) + ref_cat_subset = Table( + ref_cat_full.extract("id", "coord_ra", "coord_dec", flux_field, where=subset_stars) ) - refCatSubset = Table(refCatFull.extract("id", "coord_ra", "coord_dec", fluxField, where=subsetStars)) - proxStars = (refCatSubset[fluxField] >= proxFluxRange[0]) & ( - refCatSubset[fluxField] <= proxFluxRange[1] + prox_stars = (ref_cat_subset[flux_field] >= prox_flux_range[0]) & ( + ref_cat_subset[flux_field] <= prox_flux_range[1] ) - brightStars = (refCatSubset[fluxField] >= brightFluxRange[0]) & ( - refCatSubset[fluxField] <= brightFluxRange[1] + bright_stars = (ref_cat_subset[flux_field] >= bright_flux_range[0]) & ( + ref_cat_subset[flux_field] <= bright_flux_range[1] ) - coords = SkyCoord(refCatSubset["coord_ra"], refCatSubset["coord_dec"], unit="rad") - excludeArcsecRadius = self.config.excludeArcsecRadius * u.arcsec # type: ignore - refCatBrightIsolated = [] - for coord in cast(Iterable[SkyCoord], coords[brightStars]): - neighbors = coords[proxStars] + coords = SkyCoord(ref_cat_subset["coord_ra"], ref_cat_subset["coord_dec"], unit="rad") + exclude_arcsec_radius = self.config.exclude_arcsec_radius * u.arcsec # type: ignore + ref_cat_bright_isolated = [] + for coord in cast(Iterable[SkyCoord], coords[bright_stars]): + neighbors = coords[prox_stars] seps = coord.separation(neighbors).to(u.arcsec) - tooClose = (seps > 0) & (seps <= excludeArcsecRadius) # not self matched - refCatBrightIsolated.append(not tooClose.any()) + too_close = (seps > 0) & (seps <= exclude_arcsec_radius) # not self matched + ref_cat_bright_isolated.append(not too_close.any()) - refCatBright = cast(Table, refCatSubset[brightStars][refCatBrightIsolated]) + ref_cat_bright = cast(Table, ref_cat_subset[bright_stars][ref_cat_bright_isolated]) - fluxNanojansky = refCatBright[fluxField][:] * u.nJy # type: ignore - refCatBright["mag"] = fluxNanojansky.to(u.ABmag).to_value() # AB magnitudes + flux_nanojansky = ref_cat_bright[flux_field][:] * u.nJy # type: ignore + ref_cat_bright["mag"] = flux_nanojansky.to(u.ABmag).to_value() # AB magnitudes self.log.info( "Identified %i of %i star%s which satisfy: frame overlap; in the range %s mag; no neighboring " "stars within %s arcsec.", - len(refCatBright), - len(refCatFull), - "" if len(refCatFull) == 1 else "s", - self.config.magRange, - self.config.excludeArcsecRadius, + len(ref_cat_bright), + len(ref_cat_full), + "" if len(ref_cat_full) == 1 else "s", + self.config.mag_range, + self.config.exclude_arcsec_radius, ) - return refCatBright + return ref_cat_bright - def _associateFootprints( - self, inputExposure: ExposureF, pixCoords: list[Point2D], plane: str + def _associate_footprints( + self, inputExposure: ExposureF, pix_coords: list[Point2D], plane: str ) -> tuple[list[Footprint], dict[int, int]]: """Associate footprints from a given mask plane with specific objects. @@ -534,7 +565,7 @@ def _associateFootprints( ---------- inputExposure : `~lsst.afw.image.ExposureF` The input exposure with a mask plane. - pixCoords : `list` [`~lsst.geom.Point2D`] + pix_coords : `list` [`~lsst.geom.Point2D`] The pixel coordinates of the objects. plane : `str` The mask plane used to identify masked pixels. @@ -546,27 +577,27 @@ def _associateFootprints( associations : `dict`[int, int] Association indices between objects (key) and footprints (value). """ - detThreshold = Threshold(inputExposure.mask.getPlaneBitMask(plane), Threshold.BITMASK) - footprintSet = FootprintSet(inputExposure.mask, detThreshold) + det_threshold = Threshold(inputExposure.mask.getPlaneBitMask(plane), Threshold.BITMASK) + footprintSet = FootprintSet(inputExposure.mask, det_threshold) footprints = footprintSet.getFootprints() associations = {} - for starIndex, pixCoord in enumerate(pixCoords): - for footprintIndex, footprint in enumerate(footprints): - if footprint.contains(Point2I(pixCoord)): - associations[starIndex] = footprintIndex + for star_index, pix_coord in enumerate(pix_coords): + for footprint_index, footprint in enumerate(footprints): + if footprint.contains(Point2I(pix_coord)): + associations[star_index] = footprint_index break self.log.debug( "Associated %i of %i star%s to one each of the %i %s footprint%s.", len(associations), - len(pixCoords), - "" if len(pixCoords) == 1 else "s", + len(pix_coords), + "" if len(pix_coords) == 1 else "s", len(footprints), plane, "" if len(footprints) == 1 else "s", ) return footprints, associations - def _setFootprints(self, inputExposure: ExposureF, footprints: list, maskPlane: str): + def _set_footprints(self, inputExposure: ExposureF, footprints: list, mask_plane: str): """Set footprints in a given mask plane. Parameters @@ -575,188 +606,217 @@ def _setFootprints(self, inputExposure: ExposureF, footprints: list, maskPlane: The input exposure to modify. footprints : `list` [`~lsst.afw.detection.Footprint`] The footprints to set in the mask plane. - maskPlane : `str` + mask_plane : `str` The mask plane to set the footprints in. Notes ----- This method modifies the ``inputExposure`` object in-place. """ - detThreshold = Threshold(inputExposure.mask.getPlaneBitMask(maskPlane), Threshold.BITMASK) - detThresholdValue = int(detThreshold.getValue()) - footprintSet = FootprintSet(inputExposure.mask, detThreshold) + det_threshold = Threshold(inputExposure.mask.getPlaneBitMask(mask_plane), Threshold.BITMASK) + det_threshold_value = int(det_threshold.getValue()) + footprint_set = FootprintSet(inputExposure.mask, det_threshold) # Wipe any existing footprints in the mask plane - inputExposure.mask.clearMaskPlane(int(np.log2(detThresholdValue))) + inputExposure.mask.clearMaskPlane(int(np.log2(det_threshold_value))) # Set the footprints in the mask plane - footprintSet.setFootprints(footprints) - footprintSet.setMask(inputExposure.mask, maskPlane) + footprint_set.setFootprints(footprints) + footprint_set.setMask(inputExposure.mask, mask_plane) - def _fitPsf(self, stampMI: MaskedImageF, psfImage: ImageD | ImageF) -> dict[str, Any]: + def _fit_psf(self, stampMI: MaskedImageF, psf_image: ImageD | ImageF) -> dict[str, Any]: """Fit a scaled PSF and a pedestal to each bright star cutout. Parameters ---------- stampMI : `~lsst.afw.image.MaskedImageF` The masked image of the bright star cutout. - psfImage : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` + psf_image : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` The PSF model to fit. Returns ------- - fitPsfResults : `dict`[`str`, `float`] + fit_psf_results : `dict`[`str`, `float`] The result of the PSF fitting, with keys: ``scale`` : `float` The scale factor. - ``scaleErr`` : `float` + ``scale_err`` : `float` The error on the scale factor. ``pedestal`` : `float` The pedestal value. - ``pedestalErr`` : `float` + ``pedestal_err`` : `float` The error on the pedestal value. - ``pedestalScaleCov`` : `float` + ``pedestal_scale_cov`` : `float` The covariance between the pedestal and scale factor. - ``xGradient`` : `float` + ``x_gradient`` : `float` The gradient in the x-direction. - ``yGradient`` : `float` + ``y_gradient`` : `float` The gradient in the y-direction. - ``globalReducedChiSquared`` : `float` + ``global_reduced_chi_squared`` : `float` The global reduced chi-squared goodness-of-fit. - ``globalDegreesOfFreedom`` : `int` + ``global_degrees_of_freedom`` : `int` The global number of degrees of freedom. - ``psfReducedChiSquared`` : `float` + ``psf_reduced_chi_squared`` : `float` The PSF BBox reduced chi-squared goodness-of-fit. - ``psfDegreesOfFreedom`` : `int` + ``psf_degrees_of_freedom`` : `int` The PSF BBox number of degrees of freedom. - ``psfMaskedFluxFrac`` : `float` + ``psf_masked_flux_frac`` : `float` The fraction of the PSF image flux masked by bad pixels. """ - badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes) + bad_mask_bit_mask = stampMI.mask.getPlaneBitMask(self.config.bad_mask_planes) # Calculate the fraction of the PSF image flux masked by bad pixels - psfMaskedPixels = ImageF(psfImage.getBBox()) - psfMaskedPixels.array[:, :] = (stampMI.mask[psfImage.getBBox()].array & badMaskBitMask).astype(bool) - psfMaskedFluxFrac = ( - np.dot(psfImage.array.flat, psfMaskedPixels.array.flat).astype(np.float64) / psfImage.array.sum() + psf_masked_pixels = ImageF(psf_image.getBBox()) + psf_masked_pixels.array[:, :] = (stampMI.mask[psf_image.getBBox()].array & bad_mask_bit_mask).astype( + bool + ) + psf_masked_flux_frac = ( + np.dot(psf_image.array.flat, psf_masked_pixels.array.flat).astype(np.float64) + / psf_image.array.sum() ) - if psfMaskedFluxFrac > self.config.psfMaskedFluxFracThreshold: + if psf_masked_flux_frac > self.config.psf_masked_flux_frac_threshold: return {} # Handle cases where the PSF image is mostly masked # Generating good spans for gradient-pedestal fitting (including the star DETECTED mask). - gradientGoodSpans = self.generate_gradient_spans(stampMI, badMaskBitMask) - varianceData = gradientGoodSpans.flatten(stampMI.variance.array, stampMI.getXY0()) - if self.config.useMedianVariance: - varianceData = np.median(varianceData) - sigmaData = np.sqrt(varianceData) + gradient_good_spans = self.generate_gradient_spans(stampMI, bad_mask_bit_mask) + variance_data = gradient_good_spans.flatten(stampMI.variance.array, stampMI.getXY0()) + if self.config.use_median_variance: + variance_data = np.median(variance_data) + sigma_data = np.sqrt(variance_data) - for i in range(self.config.fitIterations): + for i in range(self.config.fit_iterations): # Gradient-pedestal fitting: if i: # if i > 0, there should be scale factor from the previous fit iteration. Therefore, we can # remove the star using the scale factor. - stamp = self.remove_star(stampMI, scale, paddedPsfImage) # noqa: F821 + stamp = self.remove_star(stampMI, scale, padded_psf_image) # noqa: F821 else: stamp = deepcopy(stampMI.image.array) - imageDataGr = gradientGoodSpans.flatten(stamp, stampMI.getXY0()) / sigmaData # B - nData = len(imageDataGr) - coefficientMatrix = np.ones((nData, 3), dtype=float) # A - coefficientMatrix[:, 0] /= sigmaData - coefficientMatrix[:, 1:] = gradientGoodSpans.indices().T - coefficientMatrix[:, 1] /= sigmaData - coefficientMatrix[:, 2] /= sigmaData + image_data_gr = gradient_good_spans.flatten(stamp, stampMI.getXY0()) / sigma_data # B + n_data = len(image_data_gr) + + xy = gradient_good_spans.indices() + y = xy[0, :] + x = xy[1, :] + coefficient_matrix = np.ones((n_data, 6), dtype=float) # A + coefficient_matrix[:, 0] /= sigma_data + coefficient_matrix[:, 1] = y / sigma_data + coefficient_matrix[:, 2] = x / sigma_data + coefficient_matrix[:, 3] = y ** 2 / sigma_data + coefficient_matrix[:, 4] = x ** 2 / sigma_data + coefficient_matrix[:, 5] = x * y / sigma_data + # scikit might have a fitting tool try: - grSolutions, grSumSquaredResiduals, *_ = np.linalg.lstsq( - coefficientMatrix, imageDataGr, rcond=None + gr_solutions, gr_sum_squared_residuals, *_ = np.linalg.lstsq( + coefficient_matrix, image_data_gr, rcond=None ) - covarianceMatrix = np.linalg.inv( - np.dot(coefficientMatrix.transpose(), coefficientMatrix) + covariance_matrix = np.linalg.inv( + np.dot(coefficient_matrix.transpose(), coefficient_matrix) ) # C except np.linalg.LinAlgError: return {} # Handle singular matrix errors - if grSumSquaredResiduals.size == 0: + if gr_sum_squared_residuals.size == 0: return {} # Handle cases where sum of the squared residuals are empty - pedestal = grSolutions[0] - pedestalErr = np.sqrt(covarianceMatrix[0, 0]) - scalePedestalCov = None - xGradient = grSolutions[2] - yGradient = grSolutions[1] + pedestal = gr_solutions[0] + pedestal_err = np.sqrt(covariance_matrix[0, 0]) + pedestal_scale_cov = None + x_gradient = gr_solutions[2] + y_gradient = gr_solutions[1] + x_curvature = gr_solutions[4] + y_curvature = gr_solutions[3] + cross_tilt = gr_solutions[5] # Scale fitting: updatedStampMI = deepcopy(stampMI) - self._removePedestalAndGradient(updatedStampMI, pedestal, xGradient, yGradient) + self._removePedestalAndGradient( + updatedStampMI, pedestal, x_gradient, y_gradient, x_curvature, y_curvature, cross_tilt + ) # Create a padded version of the input constant PSF image - paddedPsfImage = ImageF(updatedStampMI.getBBox()) - paddedPsfImage[psfImage.getBBox()] = psfImage.convertF() + padded_psf_image = ImageF(updatedStampMI.getBBox()) + padded_psf_image[psf_image.getBBox()] = psf_image.convertF() # Generating a mask plane while considering bad pixels in the psf model. - mask = self.add_psf_mask(paddedPsfImage, updatedStampMI) + mask = self.add_psf_mask(padded_psf_image, updatedStampMI) # Create consistently masked data - scaleGoodSpans = self.generate_good_spans(mask, updatedStampMI.getBBox(), badMaskBitMask) + scale_good_spans = self.generate_good_spans(mask, updatedStampMI.getBBox(), bad_mask_bit_mask) - imageData = scaleGoodSpans.flatten(updatedStampMI.image.array, updatedStampMI.getXY0()) - psfData = scaleGoodSpans.flatten(paddedPsfImage.array, paddedPsfImage.getXY0()) - scaleCoefficientMatrix = psfData.reshape(psfData.shape[0], 1) + variance_data_scale = scale_good_spans.flatten(stampMI.variance.array, stampMI.getXY0()) + if self.config.use_median_variance: + variance_data_scale = np.median(variance_data_scale) + sigma_data_scale = np.sqrt(variance_data_scale) + image_data = scale_good_spans.flatten(updatedStampMI.image.array, updatedStampMI.getXY0()) + psf_data = scale_good_spans.flatten(padded_psf_image.array, padded_psf_image.getXY0()) + + image_data /= sigma_data_scale + psf_data /= sigma_data_scale + scale_coefficient_matrix = psf_data.reshape(psf_data.shape[0], 1) try: - scaleSolution, scaleSumSquaredResiduals, *_ = np.linalg.lstsq( - scaleCoefficientMatrix, imageData, rcond=None + scale_solution, scale_sum_squared_residuals, *_ = np.linalg.lstsq( + scale_coefficient_matrix, image_data, rcond=None ) except np.linalg.LinAlgError: return {} # Handle singular matrix errors - if scaleSumSquaredResiduals.size == 0: + if scale_sum_squared_residuals.size == 0: return {} # Handle cases where sum of the squared residuals are empty - scale = scaleSolution[0] + scale = scale_solution[0] if scale <= 0: return {} # Handle cases where the PSF scale fit has failed - # TODO: calculate scale error and store it. - scaleErr = None - scale *= self.modelScale # ####### model scale correction ######## - nData = len(imageData) + scale *= self.model_scale # ####### model scale correction ######## + n_data = len(image_data) + + scale_covariance_matrix = np.linalg.inv( + np.dot(scale_coefficient_matrix.transpose(), scale_coefficient_matrix) + ) # C + scale_err = scale_covariance_matrix[0].astype(float)[0] - # Calculate global (whole image) reduced chi-squared (scaling fit is assumed as the main fitting + # Calculate global (whole image) reduced chi-squared (scale fit is assumed as the main fitting # process here.) - globalChiSquared = np.sum(scaleSumSquaredResiduals) - globalDegreesOfFreedom = nData - 1 - globalReducedChiSquared = np.float64(globalChiSquared / globalDegreesOfFreedom) + global_chi_squared = np.sum(scale_sum_squared_residuals) + global_degrees_of_freedom = n_data - 1 + global_reduced_chi_squared = np.float64(global_chi_squared / global_degrees_of_freedom) # Calculate PSF BBox reduced chi-squared - psfBBoxscaleGoodSpans = scaleGoodSpans.clippedTo(psfImage.getBBox()) - psfBBoxscaleGoodSpansX, psfBBoxscaleGoodSpansY = psfBBoxscaleGoodSpans.indices() - psfBBoxData = psfBBoxscaleGoodSpans.flatten(stampMI.image.array, stampMI.getXY0()) - paddedPsfImage.array /= self.modelScale # ####### model scale correction ######## - psfBBoxModel = ( - psfBBoxscaleGoodSpans.flatten(paddedPsfImage.array, stampMI.getXY0()) * scale + psf_bbox_scale_good_spans = scale_good_spans.clippedTo(psf_image.getBBox()) + psf_bbox_scale_good_spans_x, psf_bbox_scale_good_spans_y = psf_bbox_scale_good_spans.indices() + psf_bbox_data = psf_bbox_scale_good_spans.flatten(stampMI.image.array, stampMI.getXY0()) + padded_psf_image.array /= self.model_scale # ####### model scale correction ######## + psf_bbox_model = ( + psf_bbox_scale_good_spans.flatten(padded_psf_image.array, stampMI.getXY0()) * scale + pedestal - + psfBBoxscaleGoodSpansX * xGradient - + psfBBoxscaleGoodSpansY * yGradient + + psf_bbox_scale_good_spans_x * x_gradient + + psf_bbox_scale_good_spans_y * y_gradient ) - psfBBoxResiduals = (psfBBoxData - psfBBoxModel) ** 2 # / psfBBoxVariance - psfBBoxChiSquared = np.sum(psfBBoxResiduals) - psfBBoxDegreesOfFreedom = len(psfBBoxData) - 1 - psfBBoxReducedChiSquared = psfBBoxChiSquared / psfBBoxDegreesOfFreedom + psf_bbox_residuals = (psf_bbox_data - psf_bbox_model) ** 2 # / psfBBoxVariance + psf_bbox_chi_squared = np.sum(psf_bbox_residuals) + psf_bbox_degrees_of_freedom = len(psf_bbox_data) - 1 + psf_bbox_reduced_chi_squared = psf_bbox_chi_squared / psf_bbox_degrees_of_freedom + return dict( scale=scale, - scaleErr=scaleErr, + scale_err=scale_err, pedestal=pedestal, - pedestalErr=pedestalErr, - xGradient=xGradient, - yGradient=yGradient, - pedestalScaleCov=scalePedestalCov, - globalReducedChiSquared=globalReducedChiSquared, - globalDegreesOfFreedom=globalDegreesOfFreedom, - psfReducedChiSquared=psfBBoxReducedChiSquared, - psfDegreesOfFreedom=psfBBoxDegreesOfFreedom, - psfMaskedFluxFrac=psfMaskedFluxFrac, + pedestal_err=pedestal_err, + x_gradient=x_gradient, + y_gradient=y_gradient, + curvature_x=x_curvature, + curvature_y=y_curvature, + cross_tilt=cross_tilt, + pedestal_scale_cov=pedestal_scale_cov, + global_reduced_chi_squared=global_reduced_chi_squared, + global_degrees_of_freedom=global_degrees_of_freedom, + psf_reduced_chi_squared=psf_bbox_reduced_chi_squared, + psf_degrees_of_freedom=psf_bbox_degrees_of_freedom, + psf_masked_flux_frac=psf_masked_flux_frac, ) - def add_psf_mask(self, psfImage, stampMI, maskZeros=True): + def add_psf_mask(self, psf_image, stampMI, maskZeros=True): """ Creates a new mask by adding PSF bad pixels to an existing stamp mask. @@ -765,7 +825,7 @@ def add_psf_mask(self, psfImage, stampMI, maskZeros=True): of the input stamp's mask. Args: - psfImage : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` + psf_image : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` The PSF image object. stampMI: `~lsst.afw.image.MaskedImageF` The masked image of the bright star cutout. @@ -775,26 +835,34 @@ def add_psf_mask(self, psfImage, stampMI, maskZeros=True): Returns: Any: A new mask object (deep copy) with the PSF mask planes added. """ - cond = np.isnan(psfImage.array) + cond = np.isnan(psf_image.array) if maskZeros: - cond |= psfImage.array <= 0 + cond |= psf_image.array <= 0 else: - cond |= psfImage.array < 0 + cond |= psf_image.array < 0 mask = deepcopy(stampMI.mask) mask.array[cond] = np.bitwise_or(mask.array[cond], 1) return mask - def _removePedestalAndGradient(self, stampMI, pedestal, xGradient, yGradient): + def _removePedestalAndGradient( + self, stampMI, pedestal, x_gradient, y_gradient, x_curvature, y_curvature, cross_tilt + ): """Apply fitted pedestal and gradients to a single bright star stamp.""" - stampBBox = stampMI.getBBox() - xGrid, yGrid = np.meshgrid(stampBBox.getX().arange(), stampBBox.getY().arange()) - xPlane = ImageF((xGrid * xGradient).astype(np.float32), xy0=stampMI.getXY0()) - yPlane = ImageF((yGrid * yGradient).astype(np.float32), xy0=stampMI.getXY0()) + stamp_bbox = stampMI.getBBox() + x_grid, y_grid = np.meshgrid(stamp_bbox.getX().arange(), stamp_bbox.getY().arange()) + x_plane = ImageF((x_grid * x_gradient).astype(np.float32), xy0=stampMI.getXY0()) + y_plane = ImageF((y_grid * y_gradient).astype(np.float32), xy0=stampMI.getXY0()) + x_curve = ImageF((x_grid ** 2 * x_curvature).astype(np.float32), xy0=stampMI.getXY0()) + y_curve = ImageF((y_grid ** 2 * y_curvature).astype(np.float32), xy0=stampMI.getXY0()) + cross_tilt = ImageF((x_grid * y_grid * cross_tilt).astype(np.float32), xy0=stampMI.getXY0()) stampMI -= pedestal - stampMI -= xPlane - stampMI -= yPlane + stampMI -= x_plane + stampMI -= y_plane + stampMI -= x_curve + stampMI -= y_curve + stampMI -= cross_tilt - def remove_star(self, stampMI, scale, psfImage): + def remove_star(self, stampMI, scale, psf_image): """ Subtracts a scaled PSF model from a star image. @@ -804,74 +872,76 @@ def remove_star(self, stampMI, scale, psfImage): stampMI: `~lsst.afw.image.MaskedImageF` The masked image of the bright star cutout. scale (float): The scaling factor to apply to the PSF. - psfImage: `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` + psf_image: `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` The PSF image object. Returns: np.ndarray: A new 2D numpy array containing the star-subtracted image. """ - star_removed_cutout = stampMI.image.array - psfImage.array * scale + star_removed_cutout = stampMI.image.array - psf_image.array * scale return star_removed_cutout - def computeModelScale(self, stampMI, psfImage): + def estimate_model_scale_value(self, stampMI, psf_image): """ Computes the scaling factor of the given model against a star. Args: stampMI : `~lsst.afw.image.MaskedImageF` The masked image of the bright star cutout. - psfImage : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` + psf_image : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` The given PSF model. """ cond = stampMI.mask.array == 0 - self.starMedianValue = np.median(stampMI.image.array[cond]).astype(np.float64) + self.star_median = np.median(stampMI.image.array[cond]).astype(np.float64) - psfPos = psfImage.array > 0 + psf_positives = psf_image.array > 0 - imageArray = stampMI.image.array - self.starMedianValue - imageArrayPos = imageArray > 0 - self.modelScale = np.nanmean(imageArray[imageArrayPos]) / np.nanmean(psfImage.array[psfPos]) + image_array = stampMI.image.array - self.star_median + image_array_positives = image_array > 0 + self.model_scale = np.nanmean(image_array[image_array_positives]) / np.nanmean( + psf_image.array[psf_positives] + ) - def generate_gradient_spans(self, stampMI, badMaskBitMask): + def generate_gradient_spans(self, stampMI, bad_mask_bit_mask): """ Generates spans of "good" pixels for gradient fitting. This method creates a combined bitmask by OR-ing the provided - `badMaskBitMask` with the "DETECTED" plane from the stamp's mask. + `bad_mask_bit_mask` with the "DETECTED" plane from the stamp's mask. It then calls `self.generate_good_spans` to find all pixel spans not covered by this combined mask. Args: stampMI: `~lsst.afw.image.MaskedImageF` The masked image of the bright star cutout. - badMaskBitMask (int): A bitmask representing planes to be + bad_mask_bit_mask (int): A bitmask representing planes to be considered "bad" for gradient fitting. Returns: - gradientGoodSpans: A SpanSet object containing the "good" spans. + gradient_good_spans: A SpanSet object containing the "good" spans. """ - detectedMaskBitMask = stampMI.mask.getPlaneBitMask("DETECTED") - gradientBitMask = np.bitwise_or(badMaskBitMask, detectedMaskBitMask) + bit_mask_detected = stampMI.mask.getPlaneBitMask("DETECTED") + gradient_bit_mask = np.bitwise_or(bad_mask_bit_mask, bit_mask_detected) - gradientGoodSpans = self.generate_good_spans(stampMI.mask, stampMI.getBBox(), gradientBitMask) - return gradientGoodSpans + gradient_good_spans = self.generate_good_spans(stampMI.mask, stampMI.getBBox(), gradient_bit_mask) + return gradient_good_spans - def generate_good_spans(self, mask, bBox, badBitMask): + def generate_good_spans(self, mask, bBox, bad_bit_mask): """ Generates a SpanSet of "good" pixels from a mask. This method identifies all spans within a given bounding box (`bBox`) - that are *not* flagged by the `badBitMask` in the provided `mask`. + that are *not* flagged by the `bad_bit_mask` in the provided `mask`. Args: mask (lsst.afw.image.MaskedImageF.mask): The mask object (e.g., `stampMI.mask`). bBox (lsst.geom.Box2I): The bounding box of the image (e.g., `stampMI.getBBox()`). - badBitMask (int): The combined bitmask of planes to exclude. + bad_bit_mask (int): The combined bitmask of planes to exclude. Returns: - goodSpans: A SpanSet object representing all "good" spans. + good_spans: A SpanSet object representing all "good" spans. """ - badSpans = SpanSet.fromMask(mask, badBitMask) - goodSpans = SpanSet(bBox).intersectNot(badSpans) - return goodSpans + bad_spans = SpanSet.fromMask(mask, bad_bit_mask) + good_spans = SpanSet(bBox).intersectNot(bad_spans) + return good_spans diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py index 1c2fcad4f..a28518803 100644 --- a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py @@ -52,7 +52,7 @@ class BrightStarStackConnections( ) extendedPsf = Output( name="extendedPsf2", # extendedPsfDetector ??? - storageClass="ImageF", # MaskedImageF + storageClass="ImageF", # stamp_imF doc="Extended PSF model, built from stacking bright star cutouts.", dimensions=("band",), ) @@ -64,20 +64,28 @@ class BrightStarStackConfig( ): """Configuration parameters for BrightStarStackTask.""" - subsetStampNumber = Field[int]( - doc="Number of stamps per subset to generate stacked images for.", - default=2, - ) - globalReducedChiSquaredThreshold = Field[float]( + global_reduced_chi_squared_threshold = Field[float]( doc="Threshold for global reduced chi-squared for bright star stamps.", default=5.0, ) - psfReducedChiSquaredThreshold = Field[float]( + psf_reduced_chi_squared_threshold = Field[float]( doc="Threshold for PSF reduced chi-squared for bright star stamps.", default=50.0, ) + bright_star_threshold = Field[float]( + doc="Stars brighter than this magnitude, are considered as bright stars.", + default=12.0, + ) + bright_global_reduced_chi_squared_threshold = Field[float]( + doc="Threshold for global reduced chi-squared for bright star stamps.", + default=250.0, + ) + psf_bright_reduced_chi_squared_threshold = Field[float]( + doc="Threshold for PSF reduced chi-squared for bright star stamps.", + default=400.0, + ) - badMaskPlanes = ListField[str]( + bad_mask_planes = ListField[str]( doc="Mask planes that identify excluded (masked) pixels.", default=[ "BAD", @@ -85,24 +93,43 @@ class BrightStarStackConfig( "CROSSTALK", "EDGE", "NO_DATA", - # "SAT", - # "SUSPECT", + "SAT", + "SUSPECT", "UNMASKEDNAN", NEIGHBOR_MASK_PLANE, ], ) - stackType = Field[str]( - default="MEANCLIP", + stack_type = Field[str]( + default="WEIGHTED_MEDIAN", doc="Statistic name to use for stacking (from `~lsst.afw.math.Property`)", ) - stackNumSigmaClip = Field[float]( + stack_num_sigma_clip = Field[float]( doc="Number of sigma to use for clipping when stacking.", default=3.0, ) - stackNumIter = Field[int]( + stack_num_iter = Field[int]( doc="Number of iterations to use for clipping when stacking.", default=5, ) + magnitude_bins = ListField[int]( + doc="Only used if stack_type == WEIGHTED_MEDIAN. Bins of magnitudes for weighting purposes.", + default=[20, 19, 18, 17, 16, 15, 13, 10], + ) + subset_stamp_number = ListField[int]( + doc="Only used if stack_type == WEIGHTED_MEDIAN. Number of stamps per subset to generate stacked " + "images for. The length of this parameter must be equal to the length of magnitude_bins plus one.", + default=[300, 200, 150, 100, 100, 100, 1], + ) + min_focal_plane_radius = Field[float]( + doc="Minimum distance to focal plane center in mm. Stars with a focal plane radius smaller than " + "this will be omitted.", + default=-1.0, + ) + max_focal_plane_radius = Field[float]( + doc="Maximum distance to focal plane center in mm. Stars with a focal plane radius greater than " + "this will be omitted.", + default=2000.0, + ) class BrightStarStackTask(PipelineTask): @@ -122,14 +149,24 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs): def _applyStampFit(self, stamp): """Apply fitted stamp components to a single bright star stamp.""" - stampMI = stamp.maskedImage - stampBBox = stampMI.getBBox() - xGrid, yGrid = np.meshgrid(stampBBox.getX().arange(), stampBBox.getY().arange()) - xPlane = ImageF((xGrid * stamp.xGradient).astype(np.float32), xy0=stampMI.getXY0()) - yPlane = ImageF((yGrid * stamp.yGradient).astype(np.float32), xy0=stampMI.getXY0()) + stampMI = stamp.stamp_im + stamp_bbox = stampMI.getBBox() + + x_grid, y_grid = np.meshgrid(stamp_bbox.getX().arange(), stamp_bbox.getY().arange()) + + x_plane = ImageF((x_grid * stamp.gradient_x).astype(np.float32), xy0=stampMI.getXY0()) + y_plane = ImageF((y_grid * stamp.gradient_y).astype(np.float32), xy0=stampMI.getXY0()) + + x_curve = ImageF((x_grid**2 * stamp.curvature_x).astype(np.float32), xy0=stampMI.getXY0()) + y_curve = ImageF((y_grid**2 * stamp.curvature_y).astype(np.float32), xy0=stampMI.getXY0()) + cross_tilt = ImageF((x_grid * y_grid * stamp.cross_tilt).astype(np.float32), xy0=stampMI.getXY0()) + stampMI -= stamp.pedestal - stampMI -= xPlane - stampMI -= yPlane + stampMI -= x_plane + stampMI -= y_plane + stampMI -= x_curve + stampMI -= y_curve + stampMI -= cross_tilt stampMI /= stamp.scale @timeMethod @@ -164,75 +201,92 @@ def run( ``brightStarStamps`` (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`) """ - stackTypeProperty = stringToStatisticsProperty(self.config.stackType) - statisticsControl = StatisticsControl( - numSigmaClip=self.config.stackNumSigmaClip, - numIter=self.config.stackNumIter, + if self.config.stack_type == "WEIGHTED_MEDIAN": + stack_type_property = stringToStatisticsProperty("MEDIAN") + else: + stack_type_property = stringToStatisticsProperty(self.config.stack_type) + statistics_control = StatisticsControl( + numSigmaClip=self.config.stack_num_sigma_clip, + numIter=self.config.stack_num_iter, ) - subsetStampMIs = [] - tempStampMIs = [] - all_stars = 0 - used_stars = 0 + mag_bins_dict = {} + subset_stampMIs = {} + self.metadata["psf_star_count"] = {} + self.metadata["psf_star_count"]["all"] = 0 + for i in range(len(self.config.subset_stamp_number)): + self.metadata["psf_star_count"][str(self.config.magnitude_bins[i + 1])] = 0 for stampsDDH in brightStarStamps: stamps = stampsDDH.get() - all_stars += len(stamps) + self.metadata["psf_star_count"]["all"] += len(stamps) for stamp in stamps: - if ( - stamp.globalReducedChiSquared > self.config.globalReducedChiSquaredThreshold - or stamp.psfReducedChiSquared > self.config.psfReducedChiSquaredThreshold - ): - continue - stampMI = stamp.maskedImage - self._applyStampFit(stamp) - tempStampMIs.append(stampMI) - - badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes) - statisticsControl.setAndMask(badMaskBitMask) - - # Amir: In case the total number of stamps is less than 20, the following will result in an - # empty subsetStampMIs list. - if len(tempStampMIs) == self.config.subsetStampNumber: - subsetStampMIs.append(statisticsStack(tempStampMIs, stackTypeProperty, statisticsControl)) - # TODO: what to do with remaining temp stamps? - tempStampMIs = [] - used_stars += self.config.subsetStampNumber - - self.metadata["psfStarCount"] = {} - self.metadata["psfStarCount"]["all"] = all_stars - self.metadata["psfStarCount"]["used"] = used_stars + if stamp.ref_mag >= self.config.bright_star_threshold: + global_reduced_chi_squared_threshold = self.config.global_reduced_chi_squared_threshold + psf_reduced_chi_squared_threshold = self.config.psf_reduced_chi_squared_threshold + else: + global_reduced_chi_squared_threshold = ( + self.config.bright_global_reduced_chi_squared_threshold + ) + psf_reduced_chi_squared_threshold = self.config.psf_bright_reduced_chi_squared_threshold + for i in range(len(self.config.subset_stamp_number)): + if ( + stamp.global_reduced_chi_squared > global_reduced_chi_squared_threshold + or stamp.psf_reduced_chi_squared > psf_reduced_chi_squared_threshold + or stamp.focal_plane_radius < self.config.min_focal_plane_radius + or stamp.focal_plane_radius > self.config.max_focal_plane_radius + ): + continue + + if ( + stamp.ref_mag < self.config.magnitude_bins[i] + and stamp.ref_mag > self.config.magnitude_bins[i + 1] + ): + if not self.config.magnitude_bins[i + 1] in mag_bins_dict.keys(): + mag_bins_dict[self.config.magnitude_bins[i + 1]] = [] + stampMI = stamp.stamp_im + self._applyStampFit(stamp) + mag_bins_dict[self.config.magnitude_bins[i + 1]].append(stampMI) + badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.bad_mask_planes) + statistics_control.setAndMask(badMaskBitMask) + if ( + len(mag_bins_dict[self.config.magnitude_bins[i + 1]]) + == self.config.subset_stamp_number[i] + ): + if self.config.magnitude_bins[i + 1] not in subset_stampMIs.keys(): + subset_stampMIs[self.config.magnitude_bins[i + 1]] = [] + subset_stampMIs[self.config.magnitude_bins[i + 1]].append( + statisticsStack( + mag_bins_dict[self.config.magnitude_bins[i + 1]], + stack_type_property, + statistics_control, + ) + ) + self.metadata["psf_star_count"][str(self.config.magnitude_bins[i + 1])] += len( + mag_bins_dict[self.config.magnitude_bins[i + 1]] + ) + mag_bins_dict[self.config.magnitude_bins[i + 1]] = [] + + for key in mag_bins_dict.keys(): + if key not in subset_stampMIs.keys(): + subset_stampMIs[key] = [] + subset_stampMIs[key].append( + statisticsStack(mag_bins_dict[key], stack_type_property, statistics_control) + ) + self.metadata["psf_star_count"][str(key)] += len(mag_bins_dict[key]) + # TODO: which stamp mask plane to use here? - # TODO: Amir: there might be cases where subsetStampMIs is an empty list. What do we want to do then? + # TODO: Amir: there might be cases where subset_stampMIs is an empty list. What do we want to do + # then? # Currently, we get an "IndexError: list index out of range" - badMaskBitMask = subsetStampMIs[0].mask.getPlaneBitMask(self.config.badMaskPlanes) - statisticsControl.setAndMask(badMaskBitMask) - extendedPsfMI = statisticsStack(subsetStampMIs, stackTypeProperty, statisticsControl) + final_subset_stampMIs = [] + for key in subset_stampMIs.keys(): + final_subset_stampMIs.extend(subset_stampMIs[key]) + badMaskBitMask = final_subset_stampMIs[0].mask.getPlaneBitMask(self.config.bad_mask_planes) + statistics_control.setAndMask(badMaskBitMask) + extendedPsfMI = statisticsStack(final_subset_stampMIs, stack_type_property, statistics_control) extendedPsfExtent = extendedPsfMI.getBBox().getDimensions() extendedPsfOrigin = Point2I(-1 * (extendedPsfExtent.x // 2), -1 * (extendedPsfExtent.y // 2)) extendedPsfMI.setXY0(extendedPsfOrigin) - # return Struct(extendedPsf=[extendedPsfMI]) return Struct(extendedPsf=extendedPsfMI.getImage()) - - # stack = [] - # chiStack = [] - # for loop over all groups: - # load up all visits for this detector - # drop all with GOF > thresh - # sigma-clip mean stack the rest - # append to stack - # compute the scatter (MAD/sigma-clipped var, etc) of the rest - # divide by sqrt(var plane), and append to chiStack - # after for-loop, combine images in median stack for final result - # also combine chi-images, save separately - - # idea: run with two different thresholds, and compare the results - - # medianStack = [] - # for loop over all groups: - # load up all visits for this detector - # drop all with GOF > thresh - # median/sigma-clip stack the rest - # append to medianStack - # after for-loop, combine images in median stack for final result diff --git a/tests/test_brightStarCutout.py b/tests/test_brightStarCutout.py index 88da41440..0474480e3 100644 --- a/tests/test_brightStarCutout.py +++ b/tests/test_brightStarCutout.py @@ -37,26 +37,36 @@ def setUp(self): # Fit values self.scale = 2.34e5 self.pedestal = 3210.1 - self.xGradient = 5.432 - self.yGradient = 10.987 + self.x_gradient = 5.432 + self.y_gradient = 10.987 + self.curvature_x = 0.1 + self.curvature_y = -0.2 + self.cross_tilt = 1e-2 # Create a pedestal + 2D plane - xCoords = np.linspace(-50, 50, 101) - yCoords = np.linspace(-50, 50, 101) - xPlane, yPlane = np.meshgrid(xCoords, yCoords) - pedestal = np.ones_like(xPlane) * self.pedestal + x_coords = np.linspace(-50, 50, 101) + y_coords = np.linspace(-50, 50, 101) + x_plane, y_plane = np.meshgrid(x_coords, y_coords) + pedestal = np.ones_like(x_plane) * self.pedestal # Create a pseudo-PSF - dist_from_center = np.sqrt(xPlane**2 + yPlane**2) - psfArray = np.exp(-dist_from_center / 5) - psfArray /= np.sum(psfArray) - fixedKernel = FixedKernel(ImageD(psfArray)) - psf = KernelPsf(fixedKernel) + dist_from_center = np.sqrt(x_plane**2 + y_plane**2) + psf_array = np.exp(-dist_from_center / 5) + psf_array /= np.sum(psf_array) + fixed_kernel = FixedKernel(ImageD(psf_array)) + psf = KernelPsf(fixed_kernel) self.psf = psf.computeKernelImage(psf.getAveragePosition()) # Bring everything together to construct a stamp masked image - stampArray = psfArray * self.scale + pedestal + xPlane * self.xGradient + yPlane * self.yGradient - stampIm = ImageF((stampArray).astype(np.float32)) + stamp_array = ( + psf_array * self.scale + pedestal + x_plane * self.x_gradient + y_plane * self.y_gradient + ) + stamp_array += ( + x_plane**2 * self.curvature_x + + y_plane**2 * self.curvature_y + + x_plane * y_plane * self.cross_tilt + ) + stampIm = ImageF((stamp_array).astype(np.float32)) stampVa = ImageF(stampIm.getBBox(), 654.321) self.stampMI = MaskedImageF(image=stampIm, variance=stampVa) self.stampMI.setXY0(Point2I(-50, -50)) @@ -80,14 +90,18 @@ def test_fitPsf(self): """Test the PSF fitting method.""" brightStarCutoutConfig = BrightStarCutoutConfig() brightStarCutoutTask = BrightStarCutoutTask(config=brightStarCutoutConfig) - fitPsfResults = brightStarCutoutTask._fitPsf( + fit_psf_results = brightStarCutoutTask._fit_psf( self.stampMI, self.psf, ) - assert abs(fitPsfResults["scale"] - self.scale) / self.scale < 1e-6 - assert abs(fitPsfResults["pedestal"] - self.pedestal) / self.pedestal < 1e-6 - assert abs(fitPsfResults["xGradient"] - self.xGradient) / self.xGradient < 1e-6 - assert abs(fitPsfResults["yGradient"] - self.yGradient) / self.yGradient < 1e-6 + + assert abs(fit_psf_results["scale"] - self.scale) / self.scale < 1e-3 + assert abs(fit_psf_results["pedestal"] - self.pedestal) / self.pedestal < 1e-3 + assert abs(fit_psf_results["x_gradient"] - self.x_gradient) / self.x_gradient < 1e-3 + assert abs(fit_psf_results["y_gradient"] - self.y_gradient) / self.y_gradient < 1e-3 + assert abs(fit_psf_results["curvature_x"] - self.curvature_x) / self.curvature_x < 1e-3 + assert abs(fit_psf_results["curvature_y"] - self.curvature_y) / self.curvature_y < 1e-3 + assert abs(fit_psf_results["cross_tilt"] - self.cross_tilt) / self.cross_tilt < 1e-3 def setup_module(module):