Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ jobs:
pip install -e .[dev]
- name: Run tests
run: |
pytest
pytest --cov=src/aobasis --cov-fail-under=90
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ requires-python = ">=3.8"
[project.optional-dependencies]
dev = [
"pytest",
"pytest-cov",
"imageio",
]

Expand Down
8 changes: 4 additions & 4 deletions src/aobasis/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
import numpy as np
from pathlib import Path
from typing import Tuple, Optional
from typing import Tuple, Optional, Union
from .utils import plot_basis_modes

class BasisGenerator(ABC):
Expand Down Expand Up @@ -31,7 +31,7 @@ def generate(self, n_modes: int, **kwargs) -> np.ndarray:
"""
pass

def save(self, filepath: str | Path) -> None:
def save(self, filepath: Union[str, Path]) -> None:
"""
Save the generated basis and actuator positions to a .npz file.
"""
Expand All @@ -46,7 +46,7 @@ def save(self, filepath: str | Path) -> None:
)

@classmethod
def load(cls, filepath: str | Path) -> 'BasisGenerator':
def load(cls, filepath: Union[str, Path]) -> 'BasisGenerator':
"""
Load a basis from a .npz file.
Note: This returns a generic container or re-instantiates the specific class if possible.
Expand All @@ -62,7 +62,7 @@ def load(cls, filepath: str | Path) -> 'BasisGenerator':
instance.modes = modes
return instance

def plot(self, count: int = 6, outfile: Optional[str | Path] = None, **kwargs):
def plot(self, count: int = 6, outfile: Optional[Union[str, Path]] = None, **kwargs):
"""Plot the generated modes."""
if self.modes is None:
raise ValueError("No modes to plot.")
Expand Down
3 changes: 2 additions & 1 deletion src/aobasis/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Union
import math
from scipy.interpolate import griddata

Expand All @@ -20,7 +21,7 @@ def plot_basis_modes(
modes: np.ndarray,
positions: np.ndarray,
count: int = 6,
outfile: Path | str | None = None,
outfile: Union[Path, str, None] = None,
cmap: str = "coolwarm",
title_prefix: str = "Mode",
interpolate: bool = False,
Expand Down
3 changes: 2 additions & 1 deletion src/aobasis/zernike.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import math
from typing import Tuple
from .base import BasisGenerator

class ZernikeBasisGenerator(BasisGenerator):
Expand Down Expand Up @@ -62,7 +63,7 @@ def generate(self, n_modes: int, ignore_piston: bool = False, **kwargs) -> np.nd
self.modes = np.column_stack(modes_list)
return self.modes

def _noll_to_nm(self, j: int) -> tuple[int, int]:
def _noll_to_nm(self, j: int) -> Tuple[int, int]:
"""
Convert Noll index j to radial order n and azimuthal frequency m.
Based on Noll, J. Opt. Soc. Am. 66, 207 (1976).
Expand Down
100 changes: 100 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import numpy as np
import pytest
from pathlib import Path
from unittest.mock import patch, MagicMock
from aobasis.utils import make_circular_actuator_grid, make_concentric_actuator_grid, plot_basis_modes

def test_make_circular_actuator_grid():
diameter = 10.0
grid_size = 10
positions = make_circular_actuator_grid(diameter, grid_size)

assert isinstance(positions, np.ndarray)
assert positions.shape[1] == 2

# Check that all points are within the radius
radius = diameter / 2
distances = np.linalg.norm(positions, axis=1)
assert np.all(distances <= radius * 1.0000001)

def test_make_concentric_actuator_grid():
diameter = 10.0
n_rings = 3
n_points_innermost = 6
positions = make_concentric_actuator_grid(diameter, n_rings, n_points_innermost)

assert isinstance(positions, np.ndarray)
assert positions.shape[1] == 2

# Expected number of points: 1 (center) + 6*1 + 6*2 + 6*3 = 1 + 6 + 12 + 18 = 37
expected_points = 1 + sum(n_points_innermost * i for i in range(1, n_rings + 1))
assert positions.shape[0] == expected_points

@patch("aobasis.utils.plt")
def test_plot_basis_modes(mock_plt):
# Setup mock data
n_actuators = 20
n_modes = 5
positions = np.random.rand(n_actuators, 2)
modes = np.random.rand(n_actuators, n_modes)

# Configure mock to return a tuple
mock_fig = MagicMock()
mock_axes = MagicMock()
mock_plt.subplots.return_value = (mock_fig, mock_axes)

# Test basic plotting
plot_basis_modes(modes, positions, count=3)

assert mock_plt.subplots.called
assert mock_plt.show.called

@patch("aobasis.utils.plt")
def test_plot_basis_modes_save(mock_plt, tmp_path):
# Setup mock data
n_actuators = 20
n_modes = 5
positions = np.random.rand(n_actuators, 2)
modes = np.random.rand(n_actuators, n_modes)
outfile = tmp_path / "test_plot.png"

# Configure mock to return a tuple
mock_fig = MagicMock()
mock_axes = MagicMock()
mock_plt.subplots.return_value = (mock_fig, mock_axes)

# Test saving to file
plot_basis_modes(modes, positions, count=3, outfile=outfile)

assert mock_plt.subplots.called
mock_plt.savefig.assert_called_with(outfile, dpi=150)
assert mock_plt.close.called

@patch("aobasis.utils.plt")
def test_plot_basis_modes_interpolate(mock_plt):
# Setup mock data
n_actuators = 20
n_modes = 5
positions = np.random.rand(n_actuators, 2)
modes = np.random.rand(n_actuators, n_modes)

# Configure mock to return a tuple
mock_fig = MagicMock()
mock_axes = MagicMock()
mock_plt.subplots.return_value = (mock_fig, mock_axes)

# Test interpolation
plot_basis_modes(modes, positions, count=3, interpolate=True)

assert mock_plt.subplots.called
# We can't easily check if imshow was called on the axes objects without more complex mocking,
# but we can check that no errors were raised.

def test_plot_basis_modes_invalid_shape():
n_actuators = 20
n_modes = 5
positions = np.random.rand(n_actuators, 2)
modes = np.random.rand(n_actuators + 1, n_modes) # Mismatch

with pytest.raises(ValueError, match="Mode dimension 0"):
plot_basis_modes(modes, positions)