Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions auto_stretch/stretch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from astropy.io import fits
import numpy as np

"""
Expand All @@ -8,11 +7,11 @@

class Stretch:

def __init__(self, target_bkg=0.25, shadows_clip=-1.25):
def __init__(self, target_bkg:float=0.25, shadows_clip:float=-1.25) -> None:
self.shadows_clip = shadows_clip
self.target_bkg = target_bkg

def _get_avg_dev(self, data):
def _get_avg_dev(self, data: np.ndarray) -> float:
"""Return the average deviation from the median.

Args:
Expand All @@ -25,7 +24,7 @@ def _get_avg_dev(self, data):
return avg_dev


def _mtf(self, m, x):
def _mtf(self, m: float, x: np.ndarray) -> np.ndarray:
"""Midtones Transfer Function

MTF(m, x) = {
Expand Down Expand Up @@ -61,7 +60,7 @@ def _mtf(self, m, x):
return x.reshape(shape)


def _get_stretch_parameters(self, data):
def _get_stretch_parameters(self, data: np.ndarray) -> dict[str, float]:
""" Get the stretch parameters automatically.
m (float) is the midtones balance
c0 (float) is the shadows clipping point
Expand All @@ -80,7 +79,7 @@ def _get_stretch_parameters(self, data):
}


def stretch(self, data):
def stretch(self, data: np.ndarray) -> np.ndarray:
""" Stretch the image.

Args:
Expand Down Expand Up @@ -111,5 +110,5 @@ def stretch(self, data):
return d

# Wrapper function for simpler interface
def apply_stretch(data, target_bkg=0.25, shadows_clip=-1.25):
def apply_stretch(data: np.ndarray, target_bkg: float=0.25, shadows_clip: float=-1.25) -> np.ndarray:
return Stretch(target_bkg, shadows_clip).stretch(data)
39 changes: 39 additions & 0 deletions tests/TestFitsGen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Generate a test fits file to ensure autostretch works in production

import numpy as np
from astropy.io import fits

# Parameters
ny, nx = 1024, 1024 # image size
n_stars = 50 # number of stars
seed = 42
output_file = "test_image.fits"

# Random seed for reproducibility
rng = np.random.default_rng(seed)

# Background
image = np.full((ny, nx), 1000.0, dtype=np.float32)

# Add some stars as Gaussian spots
y, x = np.mgrid[0:ny, 0:nx]
for _ in range(n_stars):
x0 = rng.uniform(0, nx)
y0 = rng.uniform(0, ny)
amp = rng.uniform(500, 5000)
sigma = rng.uniform(1.0, 2.5)
image += amp * np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))

# Add noise
image += rng.normal(0, 5, size=image.shape)

# Save to FITS
hdu = fits.PrimaryHDU(image)

hdu.header["NUMSTARS"] = n_stars
hdu.header["SEED"] = seed
hdu.header["HISTORY"] = "Created synthetic image with stars for auto_stretch test"

hdu.writeto(output_file, overwrite=True)

print(f"Created {output_file}")
Binary file added tests/test_image.fits
Binary file not shown.
16 changes: 15 additions & 1 deletion tests/test_stretch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import numpy as np
from astropy.io import fits
from auto_stretch import stretch
from auto_stretch import apply_stretch

Expand All @@ -16,4 +17,17 @@ def test_stretch_from_standalone():
stretched_image = apply_stretch(image)
print(f"Image: {image}")
print(f"Stretched image: {stretched_image}")
assert np.shape(stretched_image) == np.shape(image)
assert np.shape(stretched_image) == np.shape(image)

def test_stretch_with_fits_file():
file = fits.open("test_image.fits")
image = file[0].data
s = stretch.Stretch()
stretched_image = s.stretch(image)
print(f"Image: {image}")
print(f"Stretched image: {stretched_image}")
assert np.shape(stretched_image) == np.shape(image)
fits.writeto("stretched_test_image.fits", stretched_image, header=file[0].header, overwrite=True)

if __name__ == "__main__":
pytest.main([__file__])