Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ examples:
uv run python examples/write.py

.PHONY: test
test: pytest ruff examples
test: pytest mypy ruff examples


# Build targets (used from CI)
Expand Down
10 changes: 6 additions & 4 deletions core/src/splipy_core/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from numpy import floating, int_
from numpy import floating, int_, integer
from numpy.typing import NDArray

def snap_point(knots: NDArray[floating], eval_pt: float, tolerance: float) -> float: ...
def snap_points(knots: NDArray[floating], eval_pts: NDArray[floating], tolerance: float) -> None: ...
type Scalar = float | floating | int | integer

def snap_point(knots: NDArray[floating], eval_pt: Scalar, tolerance: Scalar) -> Scalar: ...
def snap_points(knots: NDArray[floating], eval_pts: NDArray[floating], tolerance: Scalar) -> None: ...
def evaluate(
knots: NDArray[floating],
order: int,
eval_pts: NDArray[floating],
periodic: int,
tolerance: float,
tolerance: Scalar,
d: int,
from_right: bool = True,
) -> tuple[
Expand Down
65 changes: 30 additions & 35 deletions src/splipy/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
from scipy.sparse import csr_matrix

from . import state
from .utils import ensure_listlike_old
from .utils import ensure_listlike

if TYPE_CHECKING:
from .typing import ArrayLike, FloatArray, Scalar
from splipy.typing import Knots, Params

from .typing import FloatArray, Int, Scalar

__all__ = ["BSplineBasis"]

Expand All @@ -41,7 +43,7 @@ class BSplineBasis:
def __init__(
self,
order: int = 2,
knots: ArrayLike | None = None,
knots: Knots | None = None,
periodic: int = -1,
) -> None:
"""Construct a B-Spline basis with a given order and knot vector.
Expand Down Expand Up @@ -88,24 +90,24 @@ def num_functions(self) -> int:
.. warning:: This is different from :func:`splipy.BSplineBasis.__len__`."""
return len(self.knots) - self.order - (self.periodic + 1)

def start(self) -> float:
def start(self) -> Scalar:
"""Start point of parametric domain. For open knot vectors, this is the
first knot.

:return: Knot number *p*, where *p* is the spline order
:rtype: float
"""
return float(self.knots.flat[self.order - 1])
return self.knots.flat[self.order - 1]

def end(self) -> float:
def end(self) -> Scalar:
"""End point of parametric domain. For open knot vectors, this is the
last knot.

:return: Knot number *n*--*p*, where *p* is the spline order and *n* is
the number of knots
:rtype: Float
"""
return float(self.knots.flat[-self.order])
return self.knots.flat[-self.order]

def greville_all(self) -> FloatArray:
"""Fetch all greville points, also known as knot averages:
Expand All @@ -118,7 +120,7 @@ def greville_all(self) -> FloatArray:
operator = np.ones((self.order - 1,), dtype=np.float64) / (self.order - 1)
return np.convolve(self.knots[1 : -1 - (self.periodic + 1)], operator, mode="valid")

def greville_single(self, index: int) -> float:
def greville_single(self, index: Int) -> Scalar:
"""Fetch a greville point, also known as a knot averages:

.. math:: \\sum_{j=i+1}^{i+p-1} \\frac{t_j}{p-1}
Expand All @@ -128,15 +130,15 @@ def greville_single(self, index: int) -> float:
:return: A Greville point
:rtype: float
"""
return float(np.sum(self.knots[index + 1 : index + self.order]) / (self.order - 1))
return np.sum(self.knots[index + 1 : index + self.order]) / (self.order - 1)

@overload
def greville(self, index: int) -> float: ...
def greville(self, index: Int) -> Scalar: ...

@overload
def greville(self) -> FloatArray: ...

def greville(self, index: int | None = None) -> float | FloatArray:
def greville(self, index: Int | None = None) -> Scalar | FloatArray:
"""Fetch greville points, also known as knot averages:

.. math:: \\sum_{j=i+1}^{i+p-1} \\frac{t_j}{p-1}
Expand All @@ -151,15 +153,15 @@ def greville(self, index: int | None = None) -> float | FloatArray:
@overload
def evaluate(
self,
t: ArrayLike | Scalar,
t: Params | Scalar,
d: int = 0,
from_right: bool = ...,
) -> npt.NDArray[np.double]: ...

@overload
def evaluate(
self,
t: ArrayLike | Scalar,
t: Params | Scalar,
d: int = 0,
from_right: bool = ...,
sparse: Literal[False] = ...,
Expand All @@ -168,15 +170,15 @@ def evaluate(
@overload
def evaluate(
self,
t: ArrayLike | Scalar,
t: Params | Scalar,
d: int = 0,
from_right: bool = ...,
sparse: Literal[True] = ...,
) -> csr_matrix[np.float64]: ...

def evaluate(
self,
t: ArrayLike | Scalar,
t: Params | Scalar,
d: int = 0,
from_right: bool = True,
sparse: bool = False,
Expand Down Expand Up @@ -226,7 +228,7 @@ def evaluate_old(self, t, d=0, from_right=True, sparse=False): # type: ignore[n
:rtype: numpy.array
"""
# for single-value input, wrap it into a list so it don't crash on the loop below
t = ensure_listlike_old(t)
t = ensure_listlike(t)
self.snap(t)

p = self.order # knot vector order
Expand Down Expand Up @@ -264,20 +266,20 @@ def evaluate_old(self, t, d=0, from_right=True, sparse=False): # type: ignore[n
for j in range(p - q - 1, p):
k = mu - p + j # 'i'-index in global knot vector (ref Hughes book pg.21)
if j != p - q - 1:
M[j] = M[j] * float(evalT - self.knots[k]) / (self.knots[k + q] - self.knots[k])
M[j] = M[j] * (evalT - self.knots[k]) / (self.knots[k + q] - self.knots[k])

if j != p - 1:
M[j] = M[j] + M[j + 1] * float(self.knots[k + q + 1] - evalT) / (
M[j] = M[j] + M[j + 1] * (self.knots[k + q + 1] - evalT) / (
self.knots[k + q + 1] - self.knots[k + 1]
)

for q in range(p - d, p):
for j in range(p - q - 1, p):
k = mu - p + j # 'i'-index in global knot vector (ref Hughes book pg.21)
if j != p - q - 1:
M[j] = M[j] * float(q) / (self.knots[k + q] - self.knots[k])
M[j] = M[j] * q / (self.knots[k + q] - self.knots[k])
if j != p - 1:
M[j] = M[j] - M[j + 1] * float(q) / (self.knots[k + q + 1] - self.knots[k + 1])
M[j] = M[j] - M[j + 1] * q / (self.knots[k + q + 1] - self.knots[k + 1])

data[i * p : (i + 1) * p] = M
indices[i * p : (i + 1) * p] = np.arange(mu - p, mu) % n
Expand Down Expand Up @@ -332,9 +334,6 @@ def reparam(self, start: Scalar = 0, end: Scalar = 1) -> None:

:raises ValueError: If *end* ≤ *start*
"""
start = float(start)
end = float(end)

if end <= start:
raise ValueError("end must be larger than start")
self.normalize()
Expand All @@ -357,8 +356,6 @@ def knot_continuity(self, knot: Scalar) -> int:
knots.
:rtype: int or float
"""
knot = float(knot)

if self.periodic >= 0:
if knot < self.start() or knot > self.end():
knot = (knot - self.start()) % (self.end() - self.start()) + self.start()
Expand All @@ -375,7 +372,7 @@ def knot_continuity(self, knot: Scalar) -> int:
raise NotAKnotError
return self.order - (hi - lo) - 1

def continuity(self, knot: Scalar) -> int | float:
def continuity(self, knot: Scalar) -> int | Scalar:
"""Get the continuity of the basis functions at a given point.

:return: *p*--*m*--1 at a knot with multiplicity *m*, or ``inf``
Expand Down Expand Up @@ -510,8 +507,6 @@ def insert_knot(self, new_knot: Scalar) -> FloatArray:
:rtype: numpy.array
:raises ValueError: If the new knot is outside the domain
"""
new_knot = float(new_knot)

if self.periodic >= 0:
if new_knot < self.start() or new_knot > self.end():
new_knot = (new_knot - self.start()) % (self.end() - self.start()) + self.start()
Expand Down Expand Up @@ -598,7 +593,7 @@ def matches(self, bspline: BSplineBasis, reverse: bool = False) -> bool:
atol=state.knot_tolerance,
)

def snap_point(self, t: float) -> float:
def snap_point(self, t: Scalar) -> Scalar:
"""Snap evaluation point to knots if it is sufficiently close
as given in by state.state.knot_tolerance.

Expand Down Expand Up @@ -647,24 +642,24 @@ def __len__(self) -> int:
"""Returns the number of knots in this basis."""
return len(self.knots)

def __getitem__(self, i: int) -> float:
def __getitem__(self, i: int) -> Scalar:
"""Returns the knot at a given index."""
return float(self.knots[i])
return self.knots[i] # type: ignore[no-any-return]

def __iadd__(self, a: Scalar) -> Self:
self.knots += float(a)
self.knots += a
return self

def __isub__(self, a: Scalar) -> Self:
self.knots -= float(a)
self.knots -= a
return self

def __imul__(self, a: Scalar) -> Self:
self.knots *= float(a)
self.knots *= a
return self

def __itruediv__(self, a: Scalar) -> Self:
self.knots /= float(a)
self.knots /= a
return self

__ifloordiv__ = __itruediv__ # integer division (should not distinguish)
Expand Down
Loading