diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py b/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py new file mode 100644 index 000000000..fe5088369 --- /dev/null +++ b/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py @@ -0,0 +1,2 @@ +from .brightStarCutout import * +from .brightStarStack import * diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py new file mode 100644 index 000000000..f9ba03cf0 --- /dev/null +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py @@ -0,0 +1,947 @@ +# This file is part of pipe_tasks. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +"""Extract bright star cutouts; normalize and warp, optionally fit the PSF.""" + +__all__ = ["BrightStarCutoutConnections", "BrightStarCutoutConfig", "BrightStarCutoutTask"] + +from typing import Any, Iterable, cast + +import astropy.units as u +import numpy as np +from astropy.coordinates import SkyCoord +from astropy.table import Table +from lsst.afw.cameraGeom import FIELD_ANGLE, PIXELS, FOCAL_PLANE +from lsst.afw.detection import Footprint, FootprintSet, Threshold +from lsst.afw.geom import SkyWcs, SpanSet, makeModifiedWcs +from lsst.afw.geom.transformFactory import makeTransform +from lsst.afw.image import ExposureF, ImageD, ImageF, MaskedImageF +from lsst.afw.math import BackgroundList, FixedKernel, WarpingControl, warpImage +from lsst.daf.butler import DataCoordinate +from lsst.geom import ( + AffineTransform, + Box2I, + Extent2D, + Extent2I, + Point2D, + Point2I, + SpherePoint, + arcseconds, + floor, + radians, + Angle, +) +from lsst.meas.algorithms import ( + BrightStarStamp, + BrightStarStamps, + KernelPsf, + LoadReferenceObjectsConfig, + ReferenceObjectLoader, + WarpedPsf, +) +from lsst.pex.config import ChoiceField, ConfigField, Field, ListField +from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct +from lsst.pipe.base.connectionTypes import Input, Output, PrerequisiteInput +from lsst.utils.timer import timeMethod +from copy import deepcopy +import math + + +NEIGHBOR_MASK_PLANE = "NEIGHBOR" + + +class BrightStarCutoutConnections( + PipelineTaskConnections, + dimensions=("instrument", "visit", "detector"), +): + """Connections for BrightStarCutoutTask.""" + + refCat = PrerequisiteInput( + name="gaia_dr3_20230707", + storageClass="SimpleCatalog", + doc="Reference catalog that contains bright star positions.", + dimensions=("skypix",), + multiple=True, + deferLoad=True, + ) + inputExposure = Input( + name="calexp", + storageClass="ExposureF", + doc="Background-subtracted input exposure from which to extract bright star stamp cutouts.", + dimensions=("visit", "detector"), + ) + inputBackground = Input( + name="calexpBackground", + storageClass="Background", + doc="Background model for the input exposure, to be added back on during processing.", + dimensions=("visit", "detector"), + ) + extendedPsf = Input( + name="extendedPsf2", + storageClass="ImageF", + doc="Extended PSF model, built from stacking bright star cutouts.", + dimensions=("band",), + ) + brightStarStamps = Output( + name="brightStarStamps", + storageClass="BrightStarStamps", + doc="Set of preprocessed postage stamp cutouts, each centered on a single bright star.", + dimensions=("visit", "detector"), + ) + + def __init__(self, *, config: "BrightStarCutoutConfig | None" = None): + super().__init__(config=config) + assert config is not None + if not config.use_extended_psf: + self.inputs.remove("extendedPsf") + + +class BrightStarCutoutConfig( + PipelineTaskConfig, + pipelineConnections=BrightStarCutoutConnections, +): + """Configuration parameters for BrightStarCutoutTask.""" + + # Star selection + mag_range = ListField[float]( + doc="Magnitude range in Gaia G. Cutouts will be made for all stars in this range.", + default=[0, 18], + ) + exclude_arcsec_radius = Field[float]( + doc="Stars with a star in the range ``exclude_mag_range`` mag in ``exclude_arcsec_radius`` are not " + "used.", + default=5, + ) + exclude_mag_range = ListField[float]( + doc="Stars with a star in the range ``exclude_mag_range`` mag in ``exclude_arcsec_radius`` are not " + "used.", + default=[0, 20], + ) + min_area_fraction = Field[float]( + doc="Minimum fraction of the stamp area, post-masking, that must remain for a cutout to be retained.", + default=0.1, + ) + bad_mask_planes = ListField[str]( + doc="Mask planes that identify excluded pixels for the calculation of ``min_area_fraction`` and, " + "optionally, fitting of the PSF.", + default=[ + "BAD", + "CR", + "CROSSTALK", + "EDGE", + "NO_DATA", + "SAT", + "SUSPECT", + "UNMASKEDNAN", + NEIGHBOR_MASK_PLANE, + ], + ) + stamp_size = ListField[int]( + doc="Size of the stamps to be extracted, in pixels.", + default=(251, 251), + ) + stamp_size_padding = Field[float]( + doc="Multiplicative factor applied to the cutout stamp size, to guard against post-warp data loss.", + default=1.1, + ) + warping_kernel_name = ChoiceField[str]( + doc="Warping kernel.", + default="lanczos5", + allowed={ + "bilinear": "bilinear interpolation", + "lanczos3": "Lanczos kernel of order 3", + "lanczos4": "Lanczos kernel of order 4", + "lanczos5": "Lanczos kernel of order 5", + }, + ) + mask_warping_kernel_name = ChoiceField[str]( + doc="Warping kernel for mask.", + default="bilinear", + allowed={ + "bilinear": "bilinear interpolation", + "lanczos3": "Lanczos kernel of order 3", + "lanczos4": "Lanczos kernel of order 4", + "lanczos5": "Lanczos kernel of order 5", + }, + ) + off_frame_mag_lim = Field[float]( + doc="Stars fainter than this limit are only included if they appear within the frame boundaries.", + default=15.0, + ) + + # PSF Fitting + use_extended_psf = Field[bool]( + doc="Use the extended PSF model to normalize bright star cutouts.", + default=False, + ) + do_fit_psf = Field[bool]( + doc="Fit a scaled PSF and a pedestal to each bright star cutout.", + default=True, + ) + use_median_variance = Field[bool]( + doc="Use the median of the variance plane for PSF fitting.", + default=False, + ) + psf_masked_flux_frac_threshold = Field[float]( + doc="Maximum allowed fraction of masked PSF flux for PSF fitting to occur.", + default=0.97, + ) + fit_iterations = Field[int]( + doc="Number of iterations over pedestal-gradient and scaling fit.", + default=5, + ) + + # Misc + + load_reference_objects_config = ConfigField[LoadReferenceObjectsConfig]( + doc="Reference object loader for astrometric calibration.", + ) + + +class BrightStarCutoutTask(PipelineTask): + """Extract bright star cutouts; normalize and warp to the same pixel grid. + + The BrightStarCutoutTask is used to extract, process, and store small image + cutouts (or "postage stamps") around bright stars. + This task essentially consists of three principal steps. + First, it identifies bright stars within an exposure using a reference + catalog and extracts a stamp around each. + Second, it shifts and warps each stamp to remove optical distortions and + sample all stars on the same pixel grid. + Finally, it optionally fits a PSF plus plane flux model to the cutout. + This final fitting procedure may be used to normalize each bright star + stamp prior to stacking when producing extended PSF models. + """ + + ConfigClass = BrightStarCutoutConfig + _DefaultName = "brightStarCutout" + config: BrightStarCutoutConfig + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + stamp_size = Extent2D(*self.config.stamp_size.list()) + stamp_radius = floor(stamp_size / 2) + self.stamp_bbox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy(stamp_radius) + padded_stamp_size = stamp_size * self.config.stamp_size_padding + self.padded_stamp_radius = floor(padded_stamp_size / 2) + self.padded_stamp_bbox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy( + self.padded_stamp_radius + ) + self.model_scale = 1 + + def runQuantum(self, butlerQC, inputRefs, outputRefs): + inputs = butlerQC.get(inputRefs) + inputs["dataId"] = butlerQC.quantum.dataId + refObjLoader = ReferenceObjectLoader( + dataIds=[ref.datasetRef.dataId for ref in inputRefs.refCat], + refCats=inputs.pop("refCat"), + name=self.config.connections.refCat, + config=self.config.load_reference_objects_config, + ) + extendedPsf = inputs.pop("extendedPsf", None) + output = self.run(**inputs, extendedPsf=extendedPsf, refObjLoader=refObjLoader) + # Only ingest Stamp if it exists; prevents ingesting an empty FITS file + if output: + butlerQC.put(output, outputRefs) + + @timeMethod + def run( + self, + inputExposure: ExposureF, + inputBackground: BackgroundList, + extendedPsf: ImageF | None, + refObjLoader: ReferenceObjectLoader, + dataId: dict[str, Any] | DataCoordinate, + ): + """Identify bright stars within an exposure using a reference catalog, + extract stamps around each, warp/shift stamps onto a common frame and + then optionally fit a PSF plus plane model. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The background-subtracted image to extract bright star stamps. + inputBackground : `~lsst.afw.math.BackgroundList` + The background model associated with the input exposure. + extendedPsf: `~lsst.afw.image.ImageF` + The extended PSF model from previous iteration(s). + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional + Loader to find objects within a reference catalog. + dataId : `dict` or `~lsst.daf.butler.DataCoordinate` + The dataId of the exposure that bright stars are extracted from. + Both 'visit' and 'detector' will be persisted in the output data. + + Returns + ------- + brightStarResults : `~lsst.pipe.base.Struct` + Results as a struct with attributes: + + ``brightStarStamps`` + (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`) + """ + wcs = inputExposure.getWcs() + bbox = inputExposure.getBBox() + warping_control = WarpingControl( + self.config.warping_kernel_name, self.config.mask_warping_kernel_name + ) + + ref_cat_bright = self._get_ref_cat_bright(refObjLoader, wcs, bbox) + zip_ra_dec = zip(ref_cat_bright["coord_ra"] * radians, ref_cat_bright["coord_dec"] * radians) + sphere_points = [SpherePoint(ra, dec) for ra, dec in zip_ra_dec] + pix_coords = wcs.skyToPixel(sphere_points) + + # Restore original subtracted background + inputMI = inputExposure.getMaskedImage() + inputMI += inputBackground.getImage() + + # Set up NEIGHBOR mask plane; associate footprints with stars + inputExposure.mask.addMaskPlane(NEIGHBOR_MASK_PLANE) + all_footprints, associations = self._associate_footprints(inputExposure, pix_coords, plane="DETECTED") + + # TODO: If we eventually have better PhotoCalibs (eg FGCM), apply here + inputMI = inputExposure.getPhotoCalib().calibrateImage(inputMI, False) + + # Set up transform + detector = inputExposure.detector + pixel_scale = wcs.getPixelScale().asArcseconds() * arcseconds + pix_to_focal_plane_tan = detector.getTransform(PIXELS, FIELD_ANGLE).then( + makeTransform(AffineTransform.makeScaling(1 / pixel_scale.asRadians())) + ) + + # Loop over each bright star + stamps, good_fracs, stamps_fit_psf_results = [], [], [] + for star_index, (obj, pix_coord) in enumerate(zip(ref_cat_bright, pix_coords)): # type: ignore + # Excluding faint stars that are not within the frame. + if obj["mag"] > self.config.off_frame_mag_lim and not self.star_in_frame(pix_coord, bbox): + continue + footprint_index = associations.get(star_index, None) + stampMI = MaskedImageF(self.padded_stamp_bbox) + + # Set NEIGHBOR footprints in the mask plane + if footprint_index: + neighbor_footprints = [fp for i, fp in enumerate(all_footprints) if i != footprint_index] + self._set_footprints(inputMI, neighbor_footprints, NEIGHBOR_MASK_PLANE) + else: + self._set_footprints(inputMI, all_footprints, NEIGHBOR_MASK_PLANE) + + # Define linear shifting to recenter stamps + coord_focal_plane_tan = pix_to_focal_plane_tan.applyForward(pix_coord) # center of warped star + shift = makeTransform(AffineTransform(Point2D(0, 0) - coord_focal_plane_tan)) + angle = np.arctan2(coord_focal_plane_tan.getY(), coord_focal_plane_tan.getX()) * radians + rotation = makeTransform(AffineTransform.makeRotation(-angle)) + pix_to_polar = pix_to_focal_plane_tan.then(shift).then(rotation) + + # Apply the warp to the star stamp (in-place) + warpImage(stampMI, inputMI, pix_to_polar, warping_control) + + # Trim to the base stamp size, check mask coverage, update metadata + stampMI = stampMI[self.stamp_bbox] + bad_mask_bit_mask = stampMI.mask.getPlaneBitMask(self.config.bad_mask_planes) + good_frac = np.sum(stampMI.mask.array & bad_mask_bit_mask == 0) / stampMI.mask.array.size + good_fracs.append(good_frac) + if good_frac < self.config.min_area_fraction: + continue + + # Fit a scaled PSF and a pedestal to each bright star cutout + psf = WarpedPsf(inputExposure.getPsf(), pix_to_polar, warping_control) + constant_psf = KernelPsf(FixedKernel(psf.computeKernelImage(Point2D(0, 0)))) + if self.config.use_extended_psf: + psf_image = deepcopy(extendedPsf) # Assumed to be warped, center at [0,0] + else: + psf_image = constant_psf.computeKernelImage(constant_psf.getAveragePosition()) + # TODO: maybe we want to generate a smaller psf in case the following happens? + # The following could happen for when the user chooses small stamp_size ~(50, 50) + if ( + psf_image.array.shape[0] > stampMI.image.array.shape[0] + or psf_image.array.shape[1] > stampMI.image.array.shape[1] + ): + continue + # Computing an scale factor that brings the model to the similar level of the star. + self.estimate_model_scale_value(stampMI, psf_image) + psf_image.array *= self.model_scale # ####### model scale correction ######## + + fit_psf_results = {} + + if self.config.do_fit_psf: + fit_psf_results = self._fit_psf(stampMI, psf_image) + stamps_fit_psf_results.append(fit_psf_results) + + # Save the stamp if the PSF fit was successful or no fit requested + if fit_psf_results or not self.config.do_fit_psf: + distance_mm, theta_angle = self.star_location_on_focal(pix_coord, detector) + + stamp = BrightStarStamp( + stamp_im=stampMI, + psf=constant_psf, + wcs=makeModifiedWcs(pix_to_polar, wcs, False), + visit=cast(int, dataId["visit"]), + detector=cast(int, dataId["detector"]), + ref_id=obj["id"], + ref_mag=obj["mag"], + position=pix_coord, + focal_plane_radius=distance_mm, + focal_plane_angle=theta_angle, # TODO: add the lsst.geom.Angle here + scale=fit_psf_results.get("scale", None), + scale_err=fit_psf_results.get("scale_err", None), + pedestal=fit_psf_results.get("pedestal", None), + pedestal_err=fit_psf_results.get("pedestal_err", None), + pedestal_scale_cov=fit_psf_results.get("pedestal_scale_cov", None), + gradient_x=fit_psf_results.get("x_gradient", None), + gradient_y=fit_psf_results.get("y_gradient", None), + curvature_x=fit_psf_results.get("curvature_x", None), + curvature_y=fit_psf_results.get("curvature_y", None), + cross_tilt=fit_psf_results.get("cross_tilt", None), + global_reduced_chi_squared=fit_psf_results.get("global_reduced_chi_squared", None), + global_degrees_of_freedom=fit_psf_results.get("global_degrees_of_freedom", None), + psf_reduced_chi_squared=fit_psf_results.get("psf_reduced_chi_squared", None), + psf_degrees_of_freedom=fit_psf_results.get("psf_degrees_of_freedom", None), + psf_masked_flux_fraction=fit_psf_results.get("psf_masked_flux_frac", None), + ) + stamps.append(stamp) + + self.log.info( + "Extracted %i bright star stamp%s. " + "Excluded %i star%s: insufficient area (%i), PSF fit failure (%i).", + len(stamps), + "" if len(stamps) == 1 else "s", + len(ref_cat_bright) - len(stamps), + "" if len(ref_cat_bright) - len(stamps) == 1 else "s", + np.sum(np.array(good_fracs) < self.config.min_area_fraction), + ( + np.sum(np.isnan([x.get("pedestal", np.nan) for x in stamps_fit_psf_results])) + if self.config.do_fit_psf + else 0 + ), + ) + brightStarStamps = BrightStarStamps(stamps) + return Struct(brightStarStamps=brightStarStamps) + + def star_location_on_focal(self, pix_coord, detector): + """ + Calculates the radial coordinates of a star on the focal plane. + + Transforms the given pixel coordinates to the focal plane and computes + the radial distance and angle relative to the optical axis. + + Args: + pix_coord: `~lsst.geom.Point2D` or tuple + The (x, y) coordinates of the star on the + detector in pixels. + detector (Detector): `~lsst.afw.cameraGeom.Detector` + The detector object capable of transforming coordinates + from PIXELS to FOCAL_PLANE. + + Returns: + tuple: A tuple containing: + - distance_mm (float): The radial distance from the center in millimeters. + - theta_angle (Angle): The azimuthal angle of the star on the focal plane. + """ + star_focal_plane_coords = detector.transform(pix_coord, PIXELS, FOCAL_PLANE) + star_x_fp = star_focal_plane_coords.getX() + star_y_fp = star_focal_plane_coords.getY() + distance_mm = np.sqrt(star_x_fp ** 2 + star_y_fp ** 2) + theta_rad = math.atan2(star_y_fp, star_x_fp) + theta_angle = Angle(theta_rad, radians) + return distance_mm, theta_angle + + def star_in_frame(self, pix_coord, inputExposureBBox): + """ + Checks if a star's pixel coordinates lie within the exposure boundaries. + + Args: + pix_coord: `~lsst.geom.Point2D` or tuple + The (x, y) pixel coordinates of the star. + inputExposureBBox : `~lsst.geom.Box2I` + Bounding box of the exposure. + + Returns: + bool: True if the coordinates are within the frame limits, False otherwise. + """ + if ( + pix_coord[0] < 0 + or pix_coord[1] < 0 + or pix_coord[0] > inputExposureBBox.getDimensions()[0] + or pix_coord[1] > inputExposureBBox.getDimensions()[1] + ): + return False + return True + + def _get_ref_cat_bright(self, refObjLoader: ReferenceObjectLoader, wcs: SkyWcs, bbox: Box2I) -> Table: + """Get a bright star subset of the reference catalog. + + Trim the reference catalog to only those objects within the exposure + bounding box dilated by half the bright star stamp size. + This ensures all stars that overlap the exposure are included. + + Parameters + ---------- + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader` + Loader to find objects within a reference catalog. + wcs : `~lsst.afw.geom.SkyWcs` + World coordinate system. + bbox : `~lsst.geom.Box2I` + Bounding box of the exposure. + + Returns + ------- + ref_cat_bright : `~astropy.table.Table` + Bright star subset of the reference catalog. + """ + dilated_bbox = bbox.dilatedBy(self.padded_stamp_radius) + within_exposure = refObjLoader.loadPixelBox(dilated_bbox, wcs, filterName="phot_g_mean") + ref_cat_full = within_exposure.refCat + flux_field: str = within_exposure.fluxField + + prox_flux_range = sorted(((self.config.exclude_mag_range * u.ABmag).to(u.nJy)).to_value()) + bright_flux_range = sorted(((self.config.mag_range * u.ABmag).to(u.nJy)).to_value()) + + subset_stars = (ref_cat_full[flux_field] > np.min((prox_flux_range[0], bright_flux_range[0]))) & ( + ref_cat_full[flux_field] < np.max((prox_flux_range[1], bright_flux_range[1])) + ) + ref_cat_subset = Table( + ref_cat_full.extract("id", "coord_ra", "coord_dec", flux_field, where=subset_stars) + ) + + prox_stars = (ref_cat_subset[flux_field] >= prox_flux_range[0]) & ( + ref_cat_subset[flux_field] <= prox_flux_range[1] + ) + bright_stars = (ref_cat_subset[flux_field] >= bright_flux_range[0]) & ( + ref_cat_subset[flux_field] <= bright_flux_range[1] + ) + + coords = SkyCoord(ref_cat_subset["coord_ra"], ref_cat_subset["coord_dec"], unit="rad") + exclude_arcsec_radius = self.config.exclude_arcsec_radius * u.arcsec # type: ignore + ref_cat_bright_isolated = [] + for coord in cast(Iterable[SkyCoord], coords[bright_stars]): + neighbors = coords[prox_stars] + seps = coord.separation(neighbors).to(u.arcsec) + too_close = (seps > 0) & (seps <= exclude_arcsec_radius) # not self matched + ref_cat_bright_isolated.append(not too_close.any()) + + ref_cat_bright = cast(Table, ref_cat_subset[bright_stars][ref_cat_bright_isolated]) + + flux_nanojansky = ref_cat_bright[flux_field][:] * u.nJy # type: ignore + ref_cat_bright["mag"] = flux_nanojansky.to(u.ABmag).to_value() # AB magnitudes + + self.log.info( + "Identified %i of %i star%s which satisfy: frame overlap; in the range %s mag; no neighboring " + "stars within %s arcsec.", + len(ref_cat_bright), + len(ref_cat_full), + "" if len(ref_cat_full) == 1 else "s", + self.config.mag_range, + self.config.exclude_arcsec_radius, + ) + + return ref_cat_bright + + def _associate_footprints( + self, inputExposure: ExposureF, pix_coords: list[Point2D], plane: str + ) -> tuple[list[Footprint], dict[int, int]]: + """Associate footprints from a given mask plane with specific objects. + + Footprints from the given mask plane are associated with objects at the + coordinates provided, where possible. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The input exposure with a mask plane. + pix_coords : `list` [`~lsst.geom.Point2D`] + The pixel coordinates of the objects. + plane : `str` + The mask plane used to identify masked pixels. + + Returns + ------- + footprints : `list` [`~lsst.afw.detection.Footprint`] + The footprints from the input exposure. + associations : `dict`[int, int] + Association indices between objects (key) and footprints (value). + """ + det_threshold = Threshold(inputExposure.mask.getPlaneBitMask(plane), Threshold.BITMASK) + footprintSet = FootprintSet(inputExposure.mask, det_threshold) + footprints = footprintSet.getFootprints() + associations = {} + for star_index, pix_coord in enumerate(pix_coords): + for footprint_index, footprint in enumerate(footprints): + if footprint.contains(Point2I(pix_coord)): + associations[star_index] = footprint_index + break + self.log.debug( + "Associated %i of %i star%s to one each of the %i %s footprint%s.", + len(associations), + len(pix_coords), + "" if len(pix_coords) == 1 else "s", + len(footprints), + plane, + "" if len(footprints) == 1 else "s", + ) + return footprints, associations + + def _set_footprints(self, inputExposure: ExposureF, footprints: list, mask_plane: str): + """Set footprints in a given mask plane. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The input exposure to modify. + footprints : `list` [`~lsst.afw.detection.Footprint`] + The footprints to set in the mask plane. + mask_plane : `str` + The mask plane to set the footprints in. + + Notes + ----- + This method modifies the ``inputExposure`` object in-place. + """ + det_threshold = Threshold(inputExposure.mask.getPlaneBitMask(mask_plane), Threshold.BITMASK) + det_threshold_value = int(det_threshold.getValue()) + footprint_set = FootprintSet(inputExposure.mask, det_threshold) + + # Wipe any existing footprints in the mask plane + inputExposure.mask.clearMaskPlane(int(np.log2(det_threshold_value))) + + # Set the footprints in the mask plane + footprint_set.setFootprints(footprints) + footprint_set.setMask(inputExposure.mask, mask_plane) + + def _fit_psf(self, stampMI: MaskedImageF, psf_image: ImageD | ImageF) -> dict[str, Any]: + """Fit a scaled PSF and a pedestal to each bright star cutout. + + Parameters + ---------- + stampMI : `~lsst.afw.image.MaskedImageF` + The masked image of the bright star cutout. + psf_image : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` + The PSF model to fit. + + Returns + ------- + fit_psf_results : `dict`[`str`, `float`] + The result of the PSF fitting, with keys: + + ``scale`` : `float` + The scale factor. + ``scale_err`` : `float` + The error on the scale factor. + ``pedestal`` : `float` + The pedestal value. + ``pedestal_err`` : `float` + The error on the pedestal value. + ``pedestal_scale_cov`` : `float` + The covariance between the pedestal and scale factor. + ``x_gradient`` : `float` + The gradient in the x-direction. + ``y_gradient`` : `float` + The gradient in the y-direction. + ``global_reduced_chi_squared`` : `float` + The global reduced chi-squared goodness-of-fit. + ``global_degrees_of_freedom`` : `int` + The global number of degrees of freedom. + ``psf_reduced_chi_squared`` : `float` + The PSF BBox reduced chi-squared goodness-of-fit. + ``psf_degrees_of_freedom`` : `int` + The PSF BBox number of degrees of freedom. + ``psf_masked_flux_frac`` : `float` + The fraction of the PSF image flux masked by bad pixels. + """ + bad_mask_bit_mask = stampMI.mask.getPlaneBitMask(self.config.bad_mask_planes) + + # Calculate the fraction of the PSF image flux masked by bad pixels + psf_masked_pixels = ImageF(psf_image.getBBox()) + psf_masked_pixels.array[:, :] = (stampMI.mask[psf_image.getBBox()].array & bad_mask_bit_mask).astype( + bool + ) + psf_masked_flux_frac = ( + np.dot(psf_image.array.flat, psf_masked_pixels.array.flat).astype(np.float64) + / psf_image.array.sum() + ) + if psf_masked_flux_frac > self.config.psf_masked_flux_frac_threshold: + return {} # Handle cases where the PSF image is mostly masked + + # Generating good spans for gradient-pedestal fitting (including the star DETECTED mask). + gradient_good_spans = self.generate_gradient_spans(stampMI, bad_mask_bit_mask) + variance_data = gradient_good_spans.flatten(stampMI.variance.array, stampMI.getXY0()) + if self.config.use_median_variance: + variance_data = np.median(variance_data) + sigma_data = np.sqrt(variance_data) + + for i in range(self.config.fit_iterations): + # Gradient-pedestal fitting: + if i: + # if i > 0, there should be scale factor from the previous fit iteration. Therefore, we can + # remove the star using the scale factor. + stamp = self.remove_star(stampMI, scale, padded_psf_image) # noqa: F821 + else: + stamp = deepcopy(stampMI.image.array) + + image_data_gr = gradient_good_spans.flatten(stamp, stampMI.getXY0()) / sigma_data # B + n_data = len(image_data_gr) + + xy = gradient_good_spans.indices() + y = xy[0, :] + x = xy[1, :] + coefficient_matrix = np.ones((n_data, 6), dtype=float) # A + coefficient_matrix[:, 0] /= sigma_data + coefficient_matrix[:, 1] = y / sigma_data + coefficient_matrix[:, 2] = x / sigma_data + coefficient_matrix[:, 3] = y ** 2 / sigma_data + coefficient_matrix[:, 4] = x ** 2 / sigma_data + coefficient_matrix[:, 5] = x * y / sigma_data + # scikit might have a fitting tool + + try: + gr_solutions, gr_sum_squared_residuals, *_ = np.linalg.lstsq( + coefficient_matrix, image_data_gr, rcond=None + ) + covariance_matrix = np.linalg.inv( + np.dot(coefficient_matrix.transpose(), coefficient_matrix) + ) # C + except np.linalg.LinAlgError: + return {} # Handle singular matrix errors + if gr_sum_squared_residuals.size == 0: + return {} # Handle cases where sum of the squared residuals are empty + + pedestal = gr_solutions[0] + pedestal_err = np.sqrt(covariance_matrix[0, 0]) + pedestal_scale_cov = None + x_gradient = gr_solutions[2] + y_gradient = gr_solutions[1] + x_curvature = gr_solutions[4] + y_curvature = gr_solutions[3] + cross_tilt = gr_solutions[5] + + # Scale fitting: + updatedStampMI = deepcopy(stampMI) + self._removePedestalAndGradient( + updatedStampMI, pedestal, x_gradient, y_gradient, x_curvature, y_curvature, cross_tilt + ) + + # Create a padded version of the input constant PSF image + padded_psf_image = ImageF(updatedStampMI.getBBox()) + padded_psf_image[psf_image.getBBox()] = psf_image.convertF() + + # Generating a mask plane while considering bad pixels in the psf model. + mask = self.add_psf_mask(padded_psf_image, updatedStampMI) + # Create consistently masked data + scale_good_spans = self.generate_good_spans(mask, updatedStampMI.getBBox(), bad_mask_bit_mask) + + variance_data_scale = scale_good_spans.flatten(stampMI.variance.array, stampMI.getXY0()) + if self.config.use_median_variance: + variance_data_scale = np.median(variance_data_scale) + sigma_data_scale = np.sqrt(variance_data_scale) + + image_data = scale_good_spans.flatten(updatedStampMI.image.array, updatedStampMI.getXY0()) + psf_data = scale_good_spans.flatten(padded_psf_image.array, padded_psf_image.getXY0()) + + image_data /= sigma_data_scale + psf_data /= sigma_data_scale + scale_coefficient_matrix = psf_data.reshape(psf_data.shape[0], 1) + try: + scale_solution, scale_sum_squared_residuals, *_ = np.linalg.lstsq( + scale_coefficient_matrix, image_data, rcond=None + ) + except np.linalg.LinAlgError: + return {} # Handle singular matrix errors + if scale_sum_squared_residuals.size == 0: + return {} # Handle cases where sum of the squared residuals are empty + scale = scale_solution[0] + if scale <= 0: + return {} # Handle cases where the PSF scale fit has failed + + scale *= self.model_scale # ####### model scale correction ######## + n_data = len(image_data) + + scale_covariance_matrix = np.linalg.inv( + np.dot(scale_coefficient_matrix.transpose(), scale_coefficient_matrix) + ) # C + scale_err = scale_covariance_matrix[0].astype(float)[0] + + # Calculate global (whole image) reduced chi-squared (scale fit is assumed as the main fitting + # process here.) + global_chi_squared = np.sum(scale_sum_squared_residuals) + global_degrees_of_freedom = n_data - 1 + global_reduced_chi_squared = np.float64(global_chi_squared / global_degrees_of_freedom) + + # Calculate PSF BBox reduced chi-squared + psf_bbox_scale_good_spans = scale_good_spans.clippedTo(psf_image.getBBox()) + psf_bbox_scale_good_spans_x, psf_bbox_scale_good_spans_y = psf_bbox_scale_good_spans.indices() + psf_bbox_data = psf_bbox_scale_good_spans.flatten(stampMI.image.array, stampMI.getXY0()) + padded_psf_image.array /= self.model_scale # ####### model scale correction ######## + psf_bbox_model = ( + psf_bbox_scale_good_spans.flatten(padded_psf_image.array, stampMI.getXY0()) * scale + + pedestal + + psf_bbox_scale_good_spans_x * x_gradient + + psf_bbox_scale_good_spans_y * y_gradient + ) + psf_bbox_residuals = (psf_bbox_data - psf_bbox_model) ** 2 # / psfBBoxVariance + psf_bbox_chi_squared = np.sum(psf_bbox_residuals) + psf_bbox_degrees_of_freedom = len(psf_bbox_data) - 1 + psf_bbox_reduced_chi_squared = psf_bbox_chi_squared / psf_bbox_degrees_of_freedom + + return dict( + scale=scale, + scale_err=scale_err, + pedestal=pedestal, + pedestal_err=pedestal_err, + x_gradient=x_gradient, + y_gradient=y_gradient, + curvature_x=x_curvature, + curvature_y=y_curvature, + cross_tilt=cross_tilt, + pedestal_scale_cov=pedestal_scale_cov, + global_reduced_chi_squared=global_reduced_chi_squared, + global_degrees_of_freedom=global_degrees_of_freedom, + psf_reduced_chi_squared=psf_bbox_reduced_chi_squared, + psf_degrees_of_freedom=psf_bbox_degrees_of_freedom, + psf_masked_flux_frac=psf_masked_flux_frac, + ) + + def add_psf_mask(self, psf_image, stampMI, maskZeros=True): + """ + Creates a new mask by adding PSF bad pixels to an existing stamp mask. + + This method identifies "bad" pixels in the PSF image (NaNs and + optionally zeros/non-positives) and adds them to a deep copy + of the input stamp's mask. + + Args: + psf_image : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` + The PSF image object. + stampMI: `~lsst.afw.image.MaskedImageF` + The masked image of the bright star cutout. + maskZeros (bool, optional): If True (default), mask pixels + where the PSF is <= 0. If False, only mask pixels < 0. + + Returns: + Any: A new mask object (deep copy) with the PSF mask planes added. + """ + cond = np.isnan(psf_image.array) + if maskZeros: + cond |= psf_image.array <= 0 + else: + cond |= psf_image.array < 0 + mask = deepcopy(stampMI.mask) + mask.array[cond] = np.bitwise_or(mask.array[cond], 1) + return mask + + def _removePedestalAndGradient( + self, stampMI, pedestal, x_gradient, y_gradient, x_curvature, y_curvature, cross_tilt + ): + """Apply fitted pedestal and gradients to a single bright star stamp.""" + stamp_bbox = stampMI.getBBox() + x_grid, y_grid = np.meshgrid(stamp_bbox.getX().arange(), stamp_bbox.getY().arange()) + x_plane = ImageF((x_grid * x_gradient).astype(np.float32), xy0=stampMI.getXY0()) + y_plane = ImageF((y_grid * y_gradient).astype(np.float32), xy0=stampMI.getXY0()) + x_curve = ImageF((x_grid ** 2 * x_curvature).astype(np.float32), xy0=stampMI.getXY0()) + y_curve = ImageF((y_grid ** 2 * y_curvature).astype(np.float32), xy0=stampMI.getXY0()) + cross_tilt = ImageF((x_grid * y_grid * cross_tilt).astype(np.float32), xy0=stampMI.getXY0()) + stampMI -= pedestal + stampMI -= x_plane + stampMI -= y_plane + stampMI -= x_curve + stampMI -= y_curve + stampMI -= cross_tilt + + def remove_star(self, stampMI, scale, psf_image): + """ + Subtracts a scaled PSF model from a star image. + + This performs a simple subtraction: `image - (psf * scale)`. + + Args: + stampMI: `~lsst.afw.image.MaskedImageF` + The masked image of the bright star cutout. + scale (float): The scaling factor to apply to the PSF. + psf_image: `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` + The PSF image object. + + Returns: + np.ndarray: A new 2D numpy array containing the star-subtracted + image. + """ + star_removed_cutout = stampMI.image.array - psf_image.array * scale + return star_removed_cutout + + def estimate_model_scale_value(self, stampMI, psf_image): + """ + Computes the scaling factor of the given model against a star. + + Args: + stampMI : `~lsst.afw.image.MaskedImageF` + The masked image of the bright star cutout. + psf_image : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` + The given PSF model. + """ + cond = stampMI.mask.array == 0 + self.star_median = np.median(stampMI.image.array[cond]).astype(np.float64) + + psf_positives = psf_image.array > 0 + + image_array = stampMI.image.array - self.star_median + image_array_positives = image_array > 0 + self.model_scale = np.nanmean(image_array[image_array_positives]) / np.nanmean( + psf_image.array[psf_positives] + ) + + def generate_gradient_spans(self, stampMI, bad_mask_bit_mask): + """ + Generates spans of "good" pixels for gradient fitting. + + This method creates a combined bitmask by OR-ing the provided + `bad_mask_bit_mask` with the "DETECTED" plane from the stamp's mask. + It then calls `self.generate_good_spans` to find all pixel spans + not covered by this combined mask. + + Args: + stampMI: `~lsst.afw.image.MaskedImageF` + The masked image of the bright star cutout. + bad_mask_bit_mask (int): A bitmask representing planes to be + considered "bad" for gradient fitting. + + Returns: + gradient_good_spans: A SpanSet object containing the "good" spans. + """ + bit_mask_detected = stampMI.mask.getPlaneBitMask("DETECTED") + gradient_bit_mask = np.bitwise_or(bad_mask_bit_mask, bit_mask_detected) + + gradient_good_spans = self.generate_good_spans(stampMI.mask, stampMI.getBBox(), gradient_bit_mask) + return gradient_good_spans + + def generate_good_spans(self, mask, bBox, bad_bit_mask): + """ + Generates a SpanSet of "good" pixels from a mask. + + This method identifies all spans within a given bounding box (`bBox`) + that are *not* flagged by the `bad_bit_mask` in the provided `mask`. + + Args: + mask (lsst.afw.image.MaskedImageF.mask): The mask object (e.g., `stampMI.mask`). + bBox (lsst.geom.Box2I): The bounding box of the image (e.g., `stampMI.getBBox()`). + bad_bit_mask (int): The combined bitmask of planes to exclude. + + Returns: + good_spans: A SpanSet object representing all "good" spans. + """ + bad_spans = SpanSet.fromMask(mask, bad_bit_mask) + good_spans = SpanSet(bBox).intersectNot(bad_spans) + return good_spans diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py new file mode 100644 index 000000000..a28518803 --- /dev/null +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py @@ -0,0 +1,292 @@ +# This file is part of pipe_tasks. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +"""Stack bright star postage stamp cutouts to produce an extended PSF model.""" + +__all__ = ["BrightStarStackConnections", "BrightStarStackConfig", "BrightStarStackTask"] + +import numpy as np +from lsst.afw.image import ImageF +from lsst.afw.math import StatisticsControl, statisticsStack, stringToStatisticsProperty +from lsst.geom import Point2I +from lsst.meas.algorithms import BrightStarStamps +from lsst.pex.config import Field, ListField +from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct +from lsst.pipe.base.connectionTypes import Input, Output +from lsst.utils.timer import timeMethod + +NEIGHBOR_MASK_PLANE = "NEIGHBOR" + + +class BrightStarStackConnections( + PipelineTaskConnections, + dimensions=("instrument", "detector"), +): + """Connections for BrightStarStackTask.""" + + brightStarStamps = Input( + name="brightStarStamps", + storageClass="BrightStarStamps", + doc="Set of preprocessed postage stamp cutouts, each centered on a single bright star.", + dimensions=("visit", "detector"), + multiple=True, + deferLoad=True, + ) + extendedPsf = Output( + name="extendedPsf2", # extendedPsfDetector ??? + storageClass="ImageF", # stamp_imF + doc="Extended PSF model, built from stacking bright star cutouts.", + dimensions=("band",), + ) + + +class BrightStarStackConfig( + PipelineTaskConfig, + pipelineConnections=BrightStarStackConnections, +): + """Configuration parameters for BrightStarStackTask.""" + + global_reduced_chi_squared_threshold = Field[float]( + doc="Threshold for global reduced chi-squared for bright star stamps.", + default=5.0, + ) + psf_reduced_chi_squared_threshold = Field[float]( + doc="Threshold for PSF reduced chi-squared for bright star stamps.", + default=50.0, + ) + bright_star_threshold = Field[float]( + doc="Stars brighter than this magnitude, are considered as bright stars.", + default=12.0, + ) + bright_global_reduced_chi_squared_threshold = Field[float]( + doc="Threshold for global reduced chi-squared for bright star stamps.", + default=250.0, + ) + psf_bright_reduced_chi_squared_threshold = Field[float]( + doc="Threshold for PSF reduced chi-squared for bright star stamps.", + default=400.0, + ) + + bad_mask_planes = ListField[str]( + doc="Mask planes that identify excluded (masked) pixels.", + default=[ + "BAD", + "CR", + "CROSSTALK", + "EDGE", + "NO_DATA", + "SAT", + "SUSPECT", + "UNMASKEDNAN", + NEIGHBOR_MASK_PLANE, + ], + ) + stack_type = Field[str]( + default="WEIGHTED_MEDIAN", + doc="Statistic name to use for stacking (from `~lsst.afw.math.Property`)", + ) + stack_num_sigma_clip = Field[float]( + doc="Number of sigma to use for clipping when stacking.", + default=3.0, + ) + stack_num_iter = Field[int]( + doc="Number of iterations to use for clipping when stacking.", + default=5, + ) + magnitude_bins = ListField[int]( + doc="Only used if stack_type == WEIGHTED_MEDIAN. Bins of magnitudes for weighting purposes.", + default=[20, 19, 18, 17, 16, 15, 13, 10], + ) + subset_stamp_number = ListField[int]( + doc="Only used if stack_type == WEIGHTED_MEDIAN. Number of stamps per subset to generate stacked " + "images for. The length of this parameter must be equal to the length of magnitude_bins plus one.", + default=[300, 200, 150, 100, 100, 100, 1], + ) + min_focal_plane_radius = Field[float]( + doc="Minimum distance to focal plane center in mm. Stars with a focal plane radius smaller than " + "this will be omitted.", + default=-1.0, + ) + max_focal_plane_radius = Field[float]( + doc="Maximum distance to focal plane center in mm. Stars with a focal plane radius greater than " + "this will be omitted.", + default=2000.0, + ) + + +class BrightStarStackTask(PipelineTask): + """Stack bright star postage stamps to produce an extended PSF model.""" + + ConfigClass = BrightStarStackConfig + _DefaultName = "brightStarStack" + config: BrightStarStackConfig + + def __init__(self, initInputs=None, *args, **kwargs): + super().__init__(*args, **kwargs) + + def runQuantum(self, butlerQC, inputRefs, outputRefs): + inputs = butlerQC.get(inputRefs) + output = self.run(**inputs) + butlerQC.put(output, outputRefs) + + def _applyStampFit(self, stamp): + """Apply fitted stamp components to a single bright star stamp.""" + stampMI = stamp.stamp_im + stamp_bbox = stampMI.getBBox() + + x_grid, y_grid = np.meshgrid(stamp_bbox.getX().arange(), stamp_bbox.getY().arange()) + + x_plane = ImageF((x_grid * stamp.gradient_x).astype(np.float32), xy0=stampMI.getXY0()) + y_plane = ImageF((y_grid * stamp.gradient_y).astype(np.float32), xy0=stampMI.getXY0()) + + x_curve = ImageF((x_grid**2 * stamp.curvature_x).astype(np.float32), xy0=stampMI.getXY0()) + y_curve = ImageF((y_grid**2 * stamp.curvature_y).astype(np.float32), xy0=stampMI.getXY0()) + cross_tilt = ImageF((x_grid * y_grid * stamp.cross_tilt).astype(np.float32), xy0=stampMI.getXY0()) + + stampMI -= stamp.pedestal + stampMI -= x_plane + stampMI -= y_plane + stampMI -= x_curve + stampMI -= y_curve + stampMI -= cross_tilt + stampMI /= stamp.scale + + @timeMethod + def run( + self, + brightStarStamps: BrightStarStamps, + ): + """Identify bright stars within an exposure using a reference catalog, + extract stamps around each, then preprocess them. + + Bright star preprocessing steps are: shifting, warping and potentially + rotating them to the same pixel grid; computing their annular flux, + and; normalizing them. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The image from which bright star stamps should be extracted. + inputBackground : `~lsst.afw.image.Background` + The background model for the input exposure. + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional + Loader to find objects within a reference catalog. + dataId : `dict` or `~lsst.daf.butler.DataCoordinate` + The dataId of the exposure (including detector) that bright stars + should be extracted from. + + Returns + ------- + brightStarResults : `~lsst.pipe.base.Struct` + Results as a struct with attributes: + + ``brightStarStamps`` + (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`) + """ + if self.config.stack_type == "WEIGHTED_MEDIAN": + stack_type_property = stringToStatisticsProperty("MEDIAN") + else: + stack_type_property = stringToStatisticsProperty(self.config.stack_type) + statistics_control = StatisticsControl( + numSigmaClip=self.config.stack_num_sigma_clip, + numIter=self.config.stack_num_iter, + ) + + mag_bins_dict = {} + subset_stampMIs = {} + self.metadata["psf_star_count"] = {} + self.metadata["psf_star_count"]["all"] = 0 + for i in range(len(self.config.subset_stamp_number)): + self.metadata["psf_star_count"][str(self.config.magnitude_bins[i + 1])] = 0 + for stampsDDH in brightStarStamps: + stamps = stampsDDH.get() + self.metadata["psf_star_count"]["all"] += len(stamps) + for stamp in stamps: + if stamp.ref_mag >= self.config.bright_star_threshold: + global_reduced_chi_squared_threshold = self.config.global_reduced_chi_squared_threshold + psf_reduced_chi_squared_threshold = self.config.psf_reduced_chi_squared_threshold + else: + global_reduced_chi_squared_threshold = ( + self.config.bright_global_reduced_chi_squared_threshold + ) + psf_reduced_chi_squared_threshold = self.config.psf_bright_reduced_chi_squared_threshold + for i in range(len(self.config.subset_stamp_number)): + if ( + stamp.global_reduced_chi_squared > global_reduced_chi_squared_threshold + or stamp.psf_reduced_chi_squared > psf_reduced_chi_squared_threshold + or stamp.focal_plane_radius < self.config.min_focal_plane_radius + or stamp.focal_plane_radius > self.config.max_focal_plane_radius + ): + continue + + if ( + stamp.ref_mag < self.config.magnitude_bins[i] + and stamp.ref_mag > self.config.magnitude_bins[i + 1] + ): + if not self.config.magnitude_bins[i + 1] in mag_bins_dict.keys(): + mag_bins_dict[self.config.magnitude_bins[i + 1]] = [] + stampMI = stamp.stamp_im + self._applyStampFit(stamp) + mag_bins_dict[self.config.magnitude_bins[i + 1]].append(stampMI) + badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.bad_mask_planes) + statistics_control.setAndMask(badMaskBitMask) + if ( + len(mag_bins_dict[self.config.magnitude_bins[i + 1]]) + == self.config.subset_stamp_number[i] + ): + if self.config.magnitude_bins[i + 1] not in subset_stampMIs.keys(): + subset_stampMIs[self.config.magnitude_bins[i + 1]] = [] + subset_stampMIs[self.config.magnitude_bins[i + 1]].append( + statisticsStack( + mag_bins_dict[self.config.magnitude_bins[i + 1]], + stack_type_property, + statistics_control, + ) + ) + self.metadata["psf_star_count"][str(self.config.magnitude_bins[i + 1])] += len( + mag_bins_dict[self.config.magnitude_bins[i + 1]] + ) + mag_bins_dict[self.config.magnitude_bins[i + 1]] = [] + + for key in mag_bins_dict.keys(): + if key not in subset_stampMIs.keys(): + subset_stampMIs[key] = [] + subset_stampMIs[key].append( + statisticsStack(mag_bins_dict[key], stack_type_property, statistics_control) + ) + self.metadata["psf_star_count"][str(key)] += len(mag_bins_dict[key]) + + # TODO: which stamp mask plane to use here? + # TODO: Amir: there might be cases where subset_stampMIs is an empty list. What do we want to do + # then? + # Currently, we get an "IndexError: list index out of range" + final_subset_stampMIs = [] + for key in subset_stampMIs.keys(): + final_subset_stampMIs.extend(subset_stampMIs[key]) + badMaskBitMask = final_subset_stampMIs[0].mask.getPlaneBitMask(self.config.bad_mask_planes) + statistics_control.setAndMask(badMaskBitMask) + extendedPsfMI = statisticsStack(final_subset_stampMIs, stack_type_property, statistics_control) + + extendedPsfExtent = extendedPsfMI.getBBox().getDimensions() + extendedPsfOrigin = Point2I(-1 * (extendedPsfExtent.x // 2), -1 * (extendedPsfExtent.y // 2)) + extendedPsfMI.setXY0(extendedPsfOrigin) + + return Struct(extendedPsf=extendedPsfMI.getImage()) diff --git a/tests/test_brightStarCutout.py b/tests/test_brightStarCutout.py new file mode 100644 index 000000000..0474480e3 --- /dev/null +++ b/tests/test_brightStarCutout.py @@ -0,0 +1,117 @@ +# This file is part of pipe_tasks. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import unittest + +import lsst.afw.cameraGeom.testUtils +import lsst.afw.image +import lsst.utils.tests +import numpy as np +from lsst.afw.image import ImageD, ImageF, MaskedImageF +from lsst.afw.math import FixedKernel +from lsst.geom import Point2I +from lsst.meas.algorithms import KernelPsf +from lsst.pipe.tasks.brightStarSubtraction import BrightStarCutoutConfig, BrightStarCutoutTask + + +class BrightStarCutoutTestCase(lsst.utils.tests.TestCase): + def setUp(self): + # Fit values + self.scale = 2.34e5 + self.pedestal = 3210.1 + self.x_gradient = 5.432 + self.y_gradient = 10.987 + self.curvature_x = 0.1 + self.curvature_y = -0.2 + self.cross_tilt = 1e-2 + + # Create a pedestal + 2D plane + x_coords = np.linspace(-50, 50, 101) + y_coords = np.linspace(-50, 50, 101) + x_plane, y_plane = np.meshgrid(x_coords, y_coords) + pedestal = np.ones_like(x_plane) * self.pedestal + + # Create a pseudo-PSF + dist_from_center = np.sqrt(x_plane**2 + y_plane**2) + psf_array = np.exp(-dist_from_center / 5) + psf_array /= np.sum(psf_array) + fixed_kernel = FixedKernel(ImageD(psf_array)) + psf = KernelPsf(fixed_kernel) + self.psf = psf.computeKernelImage(psf.getAveragePosition()) + + # Bring everything together to construct a stamp masked image + stamp_array = ( + psf_array * self.scale + pedestal + x_plane * self.x_gradient + y_plane * self.y_gradient + ) + stamp_array += ( + x_plane**2 * self.curvature_x + + y_plane**2 * self.curvature_y + + x_plane * y_plane * self.cross_tilt + ) + stampIm = ImageF((stamp_array).astype(np.float32)) + stampVa = ImageF(stampIm.getBBox(), 654.321) + self.stampMI = MaskedImageF(image=stampIm, variance=stampVa) + self.stampMI.setXY0(Point2I(-50, -50)) + + # Ensure that all mask planes required by the task are in-place; + # new mask plane entries will be created as necessary + badMaskPlanes = [ + "BAD", + "CR", + "CROSSTALK", + "EDGE", + "NO_DATA", + "SAT", + "SUSPECT", + "UNMASKEDNAN", + "NEIGHBOR", + ] + _ = [self.stampMI.mask.addMaskPlane(mask) for mask in badMaskPlanes] + + def test_fitPsf(self): + """Test the PSF fitting method.""" + brightStarCutoutConfig = BrightStarCutoutConfig() + brightStarCutoutTask = BrightStarCutoutTask(config=brightStarCutoutConfig) + fit_psf_results = brightStarCutoutTask._fit_psf( + self.stampMI, + self.psf, + ) + + assert abs(fit_psf_results["scale"] - self.scale) / self.scale < 1e-3 + assert abs(fit_psf_results["pedestal"] - self.pedestal) / self.pedestal < 1e-3 + assert abs(fit_psf_results["x_gradient"] - self.x_gradient) / self.x_gradient < 1e-3 + assert abs(fit_psf_results["y_gradient"] - self.y_gradient) / self.y_gradient < 1e-3 + assert abs(fit_psf_results["curvature_x"] - self.curvature_x) / self.curvature_x < 1e-3 + assert abs(fit_psf_results["curvature_y"] - self.curvature_y) / self.curvature_y < 1e-3 + assert abs(fit_psf_results["cross_tilt"] - self.cross_tilt) / self.cross_tilt < 1e-3 + + +def setup_module(module): + lsst.utils.tests.init() + + +class MemoryTestCase(lsst.utils.tests.MemoryTestCase): + pass + + +if __name__ == "__main__": + lsst.utils.tests.init() + unittest.main()