|
17 | 17 | from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload |
18 | 18 |
|
19 | 19 | import numpy as np |
20 | | -from typing_extensions import Self |
| 20 | +from typing_extensions import Self, TypedDict |
21 | 21 |
|
22 | 22 | from manim import config |
23 | 23 | from manim.constants import * |
|
55 | 55 | from manim.utils.space_ops import angle_of_vector |
56 | 56 |
|
57 | 57 | if TYPE_CHECKING: |
58 | | - import numpy.typing as npt |
59 | | - |
60 | 58 | from manim.mobject.mobject import Mobject |
61 | 59 | from manim.typing import ManimFloat, Point2D, Point3D, Vector3D |
62 | 60 |
|
63 | 61 | LineType = TypeVar("LineType", bound=Line) |
64 | 62 |
|
65 | 63 |
|
| 64 | +class _MatmulConfig(TypedDict): |
| 65 | + """A dictionary for configuring the __matmul__/__rmatmul__ operation. |
| 66 | +
|
| 67 | + Parameters |
| 68 | + ---------- |
| 69 | + method: The method to call |
| 70 | + unpack: whether to unpack the parameter given to __matmul__/__rmatmul__ |
| 71 | + """ |
| 72 | + |
| 73 | + method: str |
| 74 | + unpack: bool |
| 75 | + |
| 76 | + |
66 | 77 | class CoordinateSystem: |
67 | 78 | r"""Abstract base class for Axes and NumberPlane. |
68 | 79 |
|
@@ -1793,20 +1804,29 @@ def construct(self): |
1793 | 1804 |
|
1794 | 1805 | return T_label_group |
1795 | 1806 |
|
1796 | | - _matmul_method = "coords_to_point" |
1797 | | - _rmatmul_method = "point_to_coords" |
| 1807 | + _matmul_config: _MatmulConfig = { |
| 1808 | + "method": "coords_to_point", |
| 1809 | + "unpack": True, |
| 1810 | + } |
| 1811 | + _rmatmul_config: _MatmulConfig = {"method": "point_to_coords", "unpack": False} |
1798 | 1812 |
|
1799 | | - def __matmul__(self, coord: Sequence[float] | Mobject | npt.NDArray[np.float64]): |
| 1813 | + def __matmul__(self, coord): |
1800 | 1814 | if isinstance(coord, Mobject): |
1801 | 1815 | coord = coord.get_center() |
1802 | | - method = getattr(self, self._matmul_method) |
| 1816 | + method = getattr(self, self._matmul_config["method"]) |
1803 | 1817 | assert callable(method) |
1804 | | - return method(*coord) |
| 1818 | + return ( |
| 1819 | + method(*coord) if self._matmul_config.get("unpack", True) else method(coord) |
| 1820 | + ) |
1805 | 1821 |
|
1806 | | - def __rmatmul__(self, point: Point3D): |
1807 | | - method = getattr(self, self._rmatmul_method) |
| 1822 | + def __rmatmul__(self, point): |
| 1823 | + method = getattr(self, self._rmatmul_config["method"]) |
1808 | 1824 | assert callable(method) |
1809 | | - return method(point) |
| 1825 | + return ( |
| 1826 | + method(*point) |
| 1827 | + if self._rmatmul_config.get("unpack", False) |
| 1828 | + else method(point) |
| 1829 | + ) |
1810 | 1830 |
|
1811 | 1831 |
|
1812 | 1832 | class Axes(VGroup, CoordinateSystem, metaclass=ConvertToOpenGL): |
@@ -2954,8 +2974,11 @@ def construct(self): |
2954 | 2974 | self.add(polarplane_pi) |
2955 | 2975 | """ |
2956 | 2976 |
|
2957 | | - _matmul_method = "polar_to_point" |
2958 | | - _rmatmul_method = "point_to_polar" |
| 2977 | + _matmul_config = { |
| 2978 | + "method": "polar_to_point", |
| 2979 | + "unpack": True, |
| 2980 | + } |
| 2981 | + _rmatmul_config = {"method": "point_to_polar", "unpack": False} |
2959 | 2982 |
|
2960 | 2983 | def __init__( |
2961 | 2984 | self, |
@@ -3327,6 +3350,10 @@ def construct(self): |
3327 | 3350 |
|
3328 | 3351 | """ |
3329 | 3352 |
|
| 3353 | + _matmul_config = {"method": "number_to_point", "unpack": False} |
| 3354 | + |
| 3355 | + _rmatmul_config = {"method": "point_to_number", "unpack": False} |
| 3356 | + |
3330 | 3357 | def __init__(self, **kwargs: Any) -> None: |
3331 | 3358 | super().__init__( |
3332 | 3359 | **kwargs, |
|
0 commit comments