From c572beb0908f82cbdbe11a9007b66ed6cfe91888 Mon Sep 17 00:00:00 2001 From: jacotay7 Date: Wed, 19 Nov 2025 16:53:49 -1000 Subject: [PATCH 1/2] some typing thing --- src/aobasis/base.py | 8 ++++---- src/aobasis/utils.py | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/aobasis/base.py b/src/aobasis/base.py index c9a12bb..643458a 100644 --- a/src/aobasis/base.py +++ b/src/aobasis/base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod import numpy as np from pathlib import Path -from typing import Tuple, Optional +from typing import Tuple, Optional, Union from .utils import plot_basis_modes class BasisGenerator(ABC): @@ -31,7 +31,7 @@ def generate(self, n_modes: int, **kwargs) -> np.ndarray: """ pass - def save(self, filepath: str | Path) -> None: + def save(self, filepath: Union[str, Path]) -> None: """ Save the generated basis and actuator positions to a .npz file. """ @@ -46,7 +46,7 @@ def save(self, filepath: str | Path) -> None: ) @classmethod - def load(cls, filepath: str | Path) -> 'BasisGenerator': + def load(cls, filepath: Union[str, Path]) -> 'BasisGenerator': """ Load a basis from a .npz file. Note: This returns a generic container or re-instantiates the specific class if possible. @@ -62,7 +62,7 @@ def load(cls, filepath: str | Path) -> 'BasisGenerator': instance.modes = modes return instance - def plot(self, count: int = 6, outfile: Optional[str | Path] = None, **kwargs): + def plot(self, count: int = 6, outfile: Optional[Union[str, Path]] = None, **kwargs): """Plot the generated modes.""" if self.modes is None: raise ValueError("No modes to plot.") diff --git a/src/aobasis/utils.py b/src/aobasis/utils.py index cc429b2..4cbe3f1 100644 --- a/src/aobasis/utils.py +++ b/src/aobasis/utils.py @@ -1,6 +1,7 @@ import numpy as np import matplotlib.pyplot as plt from pathlib import Path +from typing import Union import math from scipy.interpolate import griddata @@ -20,7 +21,7 @@ def plot_basis_modes( modes: np.ndarray, positions: np.ndarray, count: int = 6, - outfile: Path | str | None = None, + outfile: Union[Path, str, None] = None, cmap: str = "coolwarm", title_prefix: str = "Mode", interpolate: bool = False, From 74590c194dd58bf533a6f0c637375257ee859dcf Mon Sep 17 00:00:00 2001 From: jacotay7 Date: Wed, 19 Nov 2025 17:00:45 -1000 Subject: [PATCH 2/2] trying to fix for python 3.8 --- .github/workflows/ci.yml | 2 +- pyproject.toml | 1 + src/aobasis/zernike.py | 3 +- tests/test_utils.py | 100 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 104 insertions(+), 2 deletions(-) create mode 100644 tests/test_utils.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1ea262a..d3d8f42 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,4 +25,4 @@ jobs: pip install -e .[dev] - name: Run tests run: | - pytest + pytest --cov=src/aobasis --cov-fail-under=90 diff --git a/pyproject.toml b/pyproject.toml index 33e0086..f987c5f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ requires-python = ">=3.8" [project.optional-dependencies] dev = [ "pytest", + "pytest-cov", "imageio", ] diff --git a/src/aobasis/zernike.py b/src/aobasis/zernike.py index 5052c4a..38a324a 100644 --- a/src/aobasis/zernike.py +++ b/src/aobasis/zernike.py @@ -1,5 +1,6 @@ import numpy as np import math +from typing import Tuple from .base import BasisGenerator class ZernikeBasisGenerator(BasisGenerator): @@ -62,7 +63,7 @@ def generate(self, n_modes: int, ignore_piston: bool = False, **kwargs) -> np.nd self.modes = np.column_stack(modes_list) return self.modes - def _noll_to_nm(self, j: int) -> tuple[int, int]: + def _noll_to_nm(self, j: int) -> Tuple[int, int]: """ Convert Noll index j to radial order n and azimuthal frequency m. Based on Noll, J. Opt. Soc. Am. 66, 207 (1976). diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..70a53b4 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,100 @@ +import numpy as np +import pytest +from pathlib import Path +from unittest.mock import patch, MagicMock +from aobasis.utils import make_circular_actuator_grid, make_concentric_actuator_grid, plot_basis_modes + +def test_make_circular_actuator_grid(): + diameter = 10.0 + grid_size = 10 + positions = make_circular_actuator_grid(diameter, grid_size) + + assert isinstance(positions, np.ndarray) + assert positions.shape[1] == 2 + + # Check that all points are within the radius + radius = diameter / 2 + distances = np.linalg.norm(positions, axis=1) + assert np.all(distances <= radius * 1.0000001) + +def test_make_concentric_actuator_grid(): + diameter = 10.0 + n_rings = 3 + n_points_innermost = 6 + positions = make_concentric_actuator_grid(diameter, n_rings, n_points_innermost) + + assert isinstance(positions, np.ndarray) + assert positions.shape[1] == 2 + + # Expected number of points: 1 (center) + 6*1 + 6*2 + 6*3 = 1 + 6 + 12 + 18 = 37 + expected_points = 1 + sum(n_points_innermost * i for i in range(1, n_rings + 1)) + assert positions.shape[0] == expected_points + +@patch("aobasis.utils.plt") +def test_plot_basis_modes(mock_plt): + # Setup mock data + n_actuators = 20 + n_modes = 5 + positions = np.random.rand(n_actuators, 2) + modes = np.random.rand(n_actuators, n_modes) + + # Configure mock to return a tuple + mock_fig = MagicMock() + mock_axes = MagicMock() + mock_plt.subplots.return_value = (mock_fig, mock_axes) + + # Test basic plotting + plot_basis_modes(modes, positions, count=3) + + assert mock_plt.subplots.called + assert mock_plt.show.called + +@patch("aobasis.utils.plt") +def test_plot_basis_modes_save(mock_plt, tmp_path): + # Setup mock data + n_actuators = 20 + n_modes = 5 + positions = np.random.rand(n_actuators, 2) + modes = np.random.rand(n_actuators, n_modes) + outfile = tmp_path / "test_plot.png" + + # Configure mock to return a tuple + mock_fig = MagicMock() + mock_axes = MagicMock() + mock_plt.subplots.return_value = (mock_fig, mock_axes) + + # Test saving to file + plot_basis_modes(modes, positions, count=3, outfile=outfile) + + assert mock_plt.subplots.called + mock_plt.savefig.assert_called_with(outfile, dpi=150) + assert mock_plt.close.called + +@patch("aobasis.utils.plt") +def test_plot_basis_modes_interpolate(mock_plt): + # Setup mock data + n_actuators = 20 + n_modes = 5 + positions = np.random.rand(n_actuators, 2) + modes = np.random.rand(n_actuators, n_modes) + + # Configure mock to return a tuple + mock_fig = MagicMock() + mock_axes = MagicMock() + mock_plt.subplots.return_value = (mock_fig, mock_axes) + + # Test interpolation + plot_basis_modes(modes, positions, count=3, interpolate=True) + + assert mock_plt.subplots.called + # We can't easily check if imshow was called on the axes objects without more complex mocking, + # but we can check that no errors were raised. + +def test_plot_basis_modes_invalid_shape(): + n_actuators = 20 + n_modes = 5 + positions = np.random.rand(n_actuators, 2) + modes = np.random.rand(n_actuators + 1, n_modes) # Mismatch + + with pytest.raises(ValueError, match="Mode dimension 0"): + plot_basis_modes(modes, positions)