Skip to content

Commit 610f736

Browse files
Fixes for ComplexPlane
1 parent 47b6a77 commit 610f736

File tree

2 files changed

+44
-15
lines changed

2 files changed

+44
-15
lines changed

manim/mobject/graphing/coordinate_systems.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload
1818

1919
import numpy as np
20-
from typing_extensions import Self
20+
from typing_extensions import Self, TypedDict
2121

2222
from manim import config
2323
from manim.constants import *
@@ -55,14 +55,25 @@
5555
from manim.utils.space_ops import angle_of_vector
5656

5757
if TYPE_CHECKING:
58-
import numpy.typing as npt
59-
6058
from manim.mobject.mobject import Mobject
6159
from manim.typing import ManimFloat, Point2D, Point3D, Vector3D
6260

6361
LineType = TypeVar("LineType", bound=Line)
6462

6563

64+
class _MatmulConfig(TypedDict):
65+
"""A dictionary for configuring the __matmul__/__rmatmul__ operation.
66+
67+
Parameters
68+
----------
69+
method: The method to call
70+
unpack: whether to unpack the parameter given to __matmul__/__rmatmul__
71+
"""
72+
73+
method: str
74+
unpack: bool
75+
76+
6677
class CoordinateSystem:
6778
r"""Abstract base class for Axes and NumberPlane.
6879
@@ -1793,20 +1804,29 @@ def construct(self):
17931804

17941805
return T_label_group
17951806

1796-
_matmul_method = "coords_to_point"
1797-
_rmatmul_method = "point_to_coords"
1807+
_matmul_config: _MatmulConfig = {
1808+
"method": "coords_to_point",
1809+
"unpack": True,
1810+
}
1811+
_rmatmul_config: _MatmulConfig = {"method": "point_to_coords", "unpack": False}
17981812

1799-
def __matmul__(self, coord: Sequence[float] | Mobject | npt.NDArray[np.float64]):
1813+
def __matmul__(self, coord):
18001814
if isinstance(coord, Mobject):
18011815
coord = coord.get_center()
1802-
method = getattr(self, self._matmul_method)
1816+
method = getattr(self, self._matmul_config["method"])
18031817
assert callable(method)
1804-
return method(*coord)
1818+
return (
1819+
method(*coord) if self._matmul_config.get("unpack", True) else method(coord)
1820+
)
18051821

1806-
def __rmatmul__(self, point: Point3D):
1807-
method = getattr(self, self._rmatmul_method)
1822+
def __rmatmul__(self, point):
1823+
method = getattr(self, self._rmatmul_config["method"])
18081824
assert callable(method)
1809-
return method(point)
1825+
return (
1826+
method(*point)
1827+
if self._rmatmul_config.get("unpack", False)
1828+
else method(point)
1829+
)
18101830

18111831

18121832
class Axes(VGroup, CoordinateSystem, metaclass=ConvertToOpenGL):
@@ -2954,8 +2974,11 @@ def construct(self):
29542974
self.add(polarplane_pi)
29552975
"""
29562976

2957-
_matmul_method = "polar_to_point"
2958-
_rmatmul_method = "point_to_polar"
2977+
_matmul_config = {
2978+
"method": "polar_to_point",
2979+
"unpack": True,
2980+
}
2981+
_rmatmul_config = {"method": "point_to_polar", "unpack": False}
29592982

29602983
def __init__(
29612984
self,
@@ -3327,6 +3350,10 @@ def construct(self):
33273350
33283351
"""
33293352

3353+
_matmul_config = {"method": "number_to_point", "unpack": False}
3354+
3355+
_rmatmul_config = {"method": "point_to_number", "unpack": False}
3356+
33303357
def __init__(self, **kwargs: Any) -> None:
33313358
super().__init__(
33323359
**kwargs,

tests/module/mobject/graphing/test_coordinate_system.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,11 +202,13 @@ def test_matmul_operations():
202202
mob = Dot().move_to((1, 2, 0))
203203
assert (ax @ mob == ax.coords_to_point(1, 2)).all()
204204

205-
# other coordinate systems like PolarPlane should override __matmul__ indirectly
205+
# other coordinate systems like PolarPlane and ComplexPlane should override __matmul__ indirectly
206206
polar = PolarPlane()
207-
# radius, azimuthal angle
208207
assert (polar @ (1, 2) == polar.polar_to_point(1, 2)).all()
209208

209+
complx = ComplexPlane()
210+
assert (complx @ (1 + 2j) == complx.number_to_point(1 + 2j)).all()
211+
210212
# Numberline doesn't inherit from CoordinateSystem, but it should still work
211213
n = NumberLine()
212214
assert (n @ 3 == n.number_to_point(3)).all()

0 commit comments

Comments
 (0)