diff --git a/manim/mobject/geometry/shape_matchers.py b/manim/mobject/geometry/shape_matchers.py index b546dfb4f3..e72cf2775d 100644 --- a/manim/mobject/geometry/shape_matchers.py +++ b/manim/mobject/geometry/shape_matchers.py @@ -132,7 +132,7 @@ def set_style(self, fill_opacity: float, **kwargs: Any) -> Self: # type: ignore def get_fill_color(self) -> ManimColor: # The type of the color property is set to Any using the property decorator # vectorized_mobject.py#L571 - temp_color: ManimColor = self.color + temp_color: ManimColor = self.color # type: ignore[has-type] return temp_color diff --git a/manim/mobject/graphing/coordinate_systems.py b/manim/mobject/graphing/coordinate_systems.py index fa07c7fd53..7d2ab7dbde 100644 --- a/manim/mobject/graphing/coordinate_systems.py +++ b/manim/mobject/graphing/coordinate_systems.py @@ -147,7 +147,7 @@ def __init__( self.y_length = y_length self.num_sampled_graph_points_per_tick = 10 - def coords_to_point(self, *coords: ManimFloat): + def coords_to_point(self, *coords: float): raise NotImplementedError() def point_to_coords(self, point: Point3D): @@ -1836,12 +1836,19 @@ def construct(self): return T_label_group - def __matmul__(self, coord: Point3D | Mobject): + def __matmul__(self, coord: Iterable[float] | Mobject): if isinstance(coord, Mobject): coord = coord.get_center() return self.coords_to_point(*coord) def __rmatmul__(self, point: Point3D): + """Perform a point-to-coords action for a coordinate scene. + + .. warning:: + + This will not work with NumPy arrays or other objects that + implement ``__matmul__``. + """ return self.point_to_coords(point) @@ -3327,6 +3334,12 @@ def get_radian_label(self, number, font_size: float = 24, **kwargs: Any) -> Math return MathTex(string, font_size=font_size, **kwargs) + def __matmul__(self, coord: Point2D): + return self.polar_to_point(*coord) + + def __rmatmul__(self, point: Point2D): + return self.point_to_polar(point) + class ComplexPlane(NumberPlane): """A :class:`~.NumberPlane` specialized for use with complex numbers. @@ -3399,6 +3412,12 @@ def p2n(self, point: Point3D) -> complex: """Abbreviation for :meth:`point_to_number`.""" return self.point_to_number(point) + def __matmul__(self, coord: float | complex): + return self.number_to_point(coord) + + def __rmatmul__(self, point: Point3D): + return self.point_to_number(point) + def _get_default_coordinate_values(self) -> list[float | complex]: """Generate a list containing the numerical values of the plane's labels. diff --git a/tests/module/mobject/graphing/test_coordinate_system.py b/tests/module/mobject/graphing/test_coordinate_system.py index 470d9d0074..06ecfd6409 100644 --- a/tests/module/mobject/graphing/test_coordinate_system.py +++ b/tests/module/mobject/graphing/test_coordinate_system.py @@ -3,6 +3,7 @@ import math import numpy as np +import numpy.testing as nt import pytest from manim import ( @@ -14,6 +15,7 @@ Circle, ComplexPlane, Dot, + NumberLine, NumberPlane, PolarPlane, ThreeDAxes, @@ -192,3 +194,46 @@ def test_input_to_graph_point(): # test the line_graph implementation position = np.around(ax.input_to_graph_point(x=PI, graph=line_graph), decimals=4) np.testing.assert_array_equal(position, (2.6928, 1.2876, 0)) + + +def test_matmul_operations(): + ax = Axes() + nt.assert_equal(ax @ (1, 2), ax.coords_to_point(1, 2)) + # should work with mobjects too, using their center + mob = Dot().move_to((1, 2, 0)) + nt.assert_equal(ax @ mob, ax.coords_to_point(1, 2)) + + # other coordinate systems like PolarPlane and ComplexPlane should override __matmul__ indirectly + polar = PolarPlane() + nt.assert_equal(polar @ (1, 2), polar.polar_to_point(1, 2)) + + complx = ComplexPlane() + nt.assert_equal(complx @ (1 + 2j), complx.number_to_point(1 + 2j)) + + # Numberline doesn't inherit from CoordinateSystem, but it should still work + n = NumberLine() + nt.assert_equal(n @ 3, n.number_to_point(3)) + + +def test_rmatmul_operations(): + point = (1, 2, 0) + + ax = Axes() + nt.assert_equal(point @ ax, ax.point_to_coords(point)) + + polar = PolarPlane() + assert point @ polar == polar.point_to_polar(point) + + complx = ComplexPlane() + nt.assert_equal(point @ complx, complx.point_to_number(point)) + + n = NumberLine() + point = n @ 4 + + nt.assert_equal( + tuple(point) @ n, # ndarray overrides __matmul__ + n.point_to_number(point), + ) + + mob = Dot().move_to(point) + nt.assert_equal(mob @ n, n.point_to_number(mob.get_center())) diff --git a/tests/test_graphical_units/test_coordinate_systems.py b/tests/test_graphical_units/test_coordinate_systems.py index 7d6dad67af..2d9b5cd947 100644 --- a/tests/test_graphical_units/test_coordinate_systems.py +++ b/tests/test_graphical_units/test_coordinate_systems.py @@ -1,6 +1,20 @@ from __future__ import annotations -from manim import * +from manim import ( + BLUE, + GREEN, + ORANGE, + RED, + UL, + YELLOW, + Axes, + LogBase, + NumberPlane, + ThreeDAxes, + ThreeDScene, + VGroup, + np, +) from manim.utils.testing.frames_comparison import frames_comparison __module_test__ = "coordinate_system"