diff --git a/pyaml/apidoc/gen_api.py b/pyaml/apidoc/gen_api.py index 94431248..3a4cd81a 100644 --- a/pyaml/apidoc/gen_api.py +++ b/pyaml/apidoc/gen_api.py @@ -117,6 +117,11 @@ def generate_selective_module(m): for c in all_cls: file.write(f" .. autoclass:: {c.__name__}\n") file.write(" :members:\n") + if m in ["pyaml.arrays.element_array"]: + # Include special members for operator overloading + file.write( + " :special-members: __add__, __and__, __or__, __sub__ \n" + ) file.write(" :exclude-members: model_config\n") file.write(" :undoc-members:\n") file.write(" :show-inheritance:\n\n") diff --git a/pyaml/arrays/element.py b/pyaml/arrays/element.py index 1dd83adf..6a40d027 100644 --- a/pyaml/arrays/element.py +++ b/pyaml/arrays/element.py @@ -6,25 +6,25 @@ class ConfigModel(ArrayConfigModel): - """Configuration model for Element array.""" - - ... + """Configuration model for :py:class:`.ElementArray`.""" class Element(ArrayConfig): """ - Element array confirguration + :py:class:`.ElementArray` configuration. Example ------- An element array configuration can also be created by code using - the following example:: + the following example: + + .. code-block:: python from pyaml.arrays.element import Element,ConfigModel as ElementArrayConfigModel - elemArray = Element( - ElementArrayConfigModel(name="MyArray", elements=["elt1","elt2"]) - ) + elt_cfg = Element( + ElementArrayConfigModel(name="MyArray", elements=["BPM_C04-01","SH1A-C04-H"]) + ) """ @@ -34,7 +34,20 @@ def __init__(self, cfg: ArrayConfigModel): def fill_array(self, holder: ElementHolder): """ - Fill the element array in the element holder. + + Fill the :py:class:`.ElementArray` using element holder + (:py:class:`~pyaml.lattice.simulator.Simulator` + or :py:class:`~pyaml.control.controlsystem.ControlSystem`) + and add the array to the holder. This method is called when an + :py:class:`~pyaml.accelerator.Accelerator` is loaded but can be + used to create arrays by code as shown bellow: + + .. code-block:: python + + >>> elt_cfg.fill_array(sr.design) + >>> names = sr.design.get_elements("MyArray").names() + >>> print(names) + ['BPM_C04-01', 'SH1A-C04-H'] Parameters ---------- diff --git a/pyaml/arrays/element_array.py b/pyaml/arrays/element_array.py index 034d9307..d9415729 100644 --- a/pyaml/arrays/element_array.py +++ b/pyaml/arrays/element_array.py @@ -1,11 +1,15 @@ import fnmatch import importlib +from typing import Sequence + +import numpy as np from ..bpm.bpm import BPM from ..common.element import Element from ..common.exception import PyAMLException from ..magnet.cfm_magnet import CombinedFunctionMagnet from ..magnet.magnet import Magnet +from ..magnet.serialized_magnet import SerializedMagnets class ElementArray(list[Element]): @@ -14,7 +18,7 @@ class ElementArray(list[Element]): Parameters ---------- - arrayName : str + array_name : str Array name elements : list[Element] Element list, all elements must be attached to the same instance of @@ -26,19 +30,22 @@ class ElementArray(list[Element]): Example ------- - An array can be retrieved from the configuration as in the following example:: + An array can be retrieved from the configuration as in the following example: + + .. code-block:: python - sr = Accelerator.load("acc.yaml") - elts = sr.design.get_elemens("QuadForTune") + >>> sr = Accelerator.load("acc.yaml") + >>> elements = sr.design.get_elements("QuadForTune") """ - def __init__(self, arrayName: str, elements: list[Element], use_aggregator=True): + def __init__(self, array_name: str, elements: list[Element], use_aggregator=True): super().__init__(i for i in elements) - self.__name = arrayName + self.__name = array_name + self.__peer = None + self.__use_aggregator = use_aggregator if len(elements) > 0: self.__peer = self[0]._peer if len(self) > 0 else None - self.__use_aggretator = use_aggregator if self.__peer is None or any([m._peer != self.__peer for m in self]): raise PyAMLException( f"{self.__class__.__name__} {self.get_name()}: " @@ -48,7 +55,9 @@ def __init__(self, arrayName: str, elements: list[Element], use_aggregator=True) def get_peer(self): """ - Returns the peer (Simulator or ControlSystem) of an element list + Returns the peer (:py:class:`~pyaml.lattice.simulator.Simulator` + or :py:class:`~pyaml.control.controlsystem.ControlSystem`) of + an element list """ return self.__peer @@ -64,73 +73,416 @@ def names(self) -> list[str]: """ return [e.get_name() for e in self] - def __create_array(self, arrName: str, eltType: type, elements: list): - if len(elements) == 0: - return [] + def __create_array(self, array_name: str, element_type: type, elements: list): + if element_type is None: + element_type = Element - if issubclass(eltType, Magnet): + if issubclass(element_type, Magnet): m = importlib.import_module("pyaml.arrays.magnet_array") - arrayClass = getattr(m, "MagnetArray", None) - return arrayClass("", elements, self.__use_aggretator) - elif issubclass(eltType, BPM): + array_class = getattr(m, "MagnetArray", None) + return array_class(array_name, elements, self.__use_aggregator) + elif issubclass(element_type, BPM): m = importlib.import_module("pyaml.arrays.bpm_array") - arrayClass = getattr(m, "BPMArray", None) - return arrayClass("", elements, self.__use_aggretator) - elif issubclass(eltType, CombinedFunctionMagnet): + array_class = getattr(m, "BPMArray", None) + return array_class(array_name, elements, self.__use_aggregator) + elif issubclass(element_type, CombinedFunctionMagnet): m = importlib.import_module("pyaml.arrays.cfm_magnet_array") - arrayClass = getattr(m, "CombinedFunctionMagnetArray", None) - return arrayClass("", elements, self.__use_aggretator) - elif issubclass(eltType, Element): - return ElementArray("", elements, self.__use_aggretator) + array_class = getattr(m, "CombinedFunctionMagnetArray", None) + return array_class(array_name, elements, self.__use_aggregator) + elif issubclass(element_type, SerializedMagnets): + m = importlib.import_module("pyaml.arrays.serialized_magnet_array") + array_class = getattr(m, "SerializedMagnetsArray", None) + return array_class(array_name, elements, self.__use_aggregator) + elif issubclass(element_type, Element): + return ElementArray(array_name, elements, self.__use_aggregator) else: - raise PyAMLException(f"Unsupported sliced array for type {str(eltType)}") + raise PyAMLException( + f"Unsupported sliced array for type {str(element_type)}" + ) - def __eval_field(self, attName: str, e: Element) -> str: - funcName = "get_" + attName - func = getattr(e, funcName, None) + def __eval_field(self, attribute_name: str, element: Element) -> str: + function_name = "get_" + attribute_name + func = getattr(element, function_name, None) return func() if func is not None else "" + def __ensure_compatible_operand(self, other: object) -> "ElementArray": + """Validate the operand used for set-like operations between arrays.""" + if not isinstance(other, ElementArray): + raise TypeError( + f"Unsupported operand type(s) for set operation: " + f"'{type(self).__name__}' and '{type(other).__name__}'" + ) + + if len(self) > 0 and len(other) > 0: + if self.get_peer() is not None and other.get_peer() is not None: + if self.get_peer() != other.get_peer(): + raise PyAMLException( + f"{self.__class__.__name__}: cannot operate on arrays " + "attached to different peers" + ) + return other + + def __auto_array(self, elements: list[Element]): + """Create the most specific array type for the given element list. + + The target element type is the most specific common base class (nearest common + ancestor) of all elements. This supports heterogeneous subclasses (e.g., + several Magnet subclasses) while still returning a MagnetArray when + appropriate. + """ + if len(elements) == 0: + return [] + + import inspect + + def mro_as_list(cls: type) -> list[type]: + # inspect.getmro returns (cls, ..., object) + return list(inspect.getmro(cls)) + + # Start from the first element MRO as reference order (most specific first). + common: list[type] = mro_as_list(type(elements[0])) + + # Intersect while preserving MRO order from the first element. + for e in elements[1:]: + mro_set = set(mro_as_list(type(e))) + common = [c for c in common if c in mro_set] + if not common: + break + + # Pick the first suitable common base within the Element hierarchy. + chosen: type = Element + for c in common: + if c is object: + continue + if issubclass(c, Element): + chosen = c + break + + return self.__create_array("", chosen, elements) + + def __is_bool_mask(self, other: object) -> bool: + """Return True if 'other' looks like a boolean mask (list or numpy array).""" + # --- numpy boolean array --- + try: + if isinstance(other, np.ndarray) and other.dtype == bool: + return True + except Exception: + pass + + # --- python sequence of bools (but not a string/bytes) --- + if isinstance(other, Sequence) and not isinstance( + other, (str, bytes, bytearray) + ): + # Avoid treating ElementArray as a mask + if isinstance(other, ElementArray): + return False + # Accept only actual bool-like values + try: + return all(isinstance(v, bool) for v in other) + except TypeError: + return False + + return False + + def __and__(self, other: object): + """ + Intersection or boolean mask filtering. + + This operator has two distinct behaviors depending on the type of + ``other``. + + 1) Array intersection + If ``other`` is an ElementArray, the result contains elements + whose names are present in both arrays. + + Example + ------- + + .. code-block:: python + + >>> cell1 = sr.live.get_elements("C01") + >>> sexts = sr.live.get_magnets("SEXT") + >>> cell1_sext = cell1 & sexts + + 2) Boolean mask filtering + If ``other`` is a boolean mask (list[bool] or numpy.ndarray of bool), + elements are kept where the mask is True. + + Example + ------- + .. code-block:: python + + >>> mask = cell1.mask_by_type(Magnet) + >>> magnets = cell1 & mask + + Returns + ------- + Array + + The result is automatically typed according to the most specific + common base class of the remaining elements which can be: + :py:class:`.BPMArray` or :py:class:`.MagnetArray` or + :py:class:`.CombinedFunctionMagnetArray` or + :py:class:`.SerializedMagnetsArray` or + :py:class:`.ElementArray`. + + """ + # --- mask filtering --- + if self.__is_bool_mask(other): + mask = list(other) # works for list/tuple and numpy arrays + if len(mask) != len(self): + raise ValueError( + f"{self.__class__.__name__}: mask length ({len(mask)}) " + f"does not match array length ({len(self)})" + ) + res = [e for e, keep in zip(self, mask, strict=True) if bool(keep)] + return self.__auto_array(res) + + # --- array intersection --- + other_arr = self.__ensure_compatible_operand(other) + other_names = {e.get_name() for e in other_arr} + res = [e for e in self if e.get_name() in other_names] + return self.__auto_array(res) + + def __rand__(self, other: object): + # Support "array on the right" for array operands; for masks, we don't enforce + # commutativity. + if isinstance(other, ElementArray): + return other.__and__(self) + return NotImplemented + + def __sub__(self, other: object): + """ + Difference or boolean mask removal. + + This operator has two behaviors depending on the type of ``other``. + + 1) Array difference + If ``other`` is an ElementArray, the result contains elements + whose names are present in ``self`` but not in ``other``. + + Example + ------- + + .. code-block:: python + + >>> hvcorr = sr.live.get_magnets("HVCORR") + >>> hcorr = sr.live.get_magnets("HCORR") + >>> vcorr_only = hvcorr - hcorr + + 2) Boolean mask removal + If ``other`` is a boolean mask (list[bool] or numpy.ndarray of bool), + elements are removed where the mask is True. + This is the inverse of ``& mask``. + + Example + ------- + + .. code-block:: python + + >>> mask = cell1.mask_by_type(Magnet) + >>> non_magnets = cell1 - mask + + Returns + ------- + Array + + The result is automatically typed according to the most specific + common base class of the remaining elements which can be: + :py:class:`.BPMArray` or :py:class:`.MagnetArray` or + :py:class:`.CombinedFunctionMagnetArray` or + :py:class:`.SerializedMagnetsArray` or + :py:class:`.ElementArray`. + + """ + # --- mask removal --- + if self.__is_bool_mask(other): + mask = list(other) + if len(mask) != len(self): + raise ValueError( + f"{self.__class__.__name__}: mask length ({len(mask)}) " + f"does not match array length ({len(self)})" + ) + res = [e for e, remove in zip(self, mask, strict=True) if not bool(remove)] + return self.__auto_array(res) + + # --- array difference --- + other_arr = self.__ensure_compatible_operand(other) + other_names = {e.get_name() for e in other_arr} + res = [e for e in self if e.get_name() not in other_names] + return self.__auto_array(res) + + def __or__(self, other: object): + """ + Union between two ElementArray instances. + + Elements are combined using their names as identity. + Order is stable: elements from ``self`` first, followed by + elements from ``other`` that are not already present. + + Example + ------- + + .. code-block:: python + + >>> hcorr = sr.live.get_magnets("HCORR") + >>> vcorr = sr.live.get_magnets("VCORR") + >>> all_corr = hcorr | vcorr + + Returns + ------- + Array + + The result is automatically typed according to the most specific + common base class of the remaining elements which can be: + :py:class:`.BPMArray` or :py:class:`.MagnetArray` or + :py:class:`.CombinedFunctionMagnetArray` or + :py:class:`.SerializedMagnetsArray` or + :py:class:`.ElementArray`. + + """ + other_arr = self.__ensure_compatible_operand(other) + + seen: set[str] = set() + res: list[Element] = [] + + for e in self: + name = e.get_name() + if name not in seen: + res.append(e) + seen.add(name) + + for e in other_arr: + name = e.get_name() + if name not in seen: + res.append(e) + seen.add(name) + + return self.__auto_array(res) + + def __ror__(self, other: object): + if isinstance(other, ElementArray): + return other.__or__(self) + return NotImplemented + + def __add__(self, other: object): + """ + Alias for the union operator ``|``. + + Example + ------- + + .. code-block:: python + + >>> all_corr = hcorr + vcorr + + Returns + ------- + Array + + The result is automatically typed according to the most specific + common base class of the remaining elements which can be: + :py:class:`.BPMArray` or :py:class:`.MagnetArray` or + :py:class:`.CombinedFunctionMagnetArray` or + :py:class:`.SerializedMagnetsArray` or + :py:class:`.ElementArray`. + + """ + return self.__or__(other) + + def __radd__(self, other: object): + if isinstance(other, ElementArray): + return other.__add__(self) + return NotImplemented + + def mask_by_type(self, element_type: type) -> list[bool]: + """Return a boolean mask indicating which elements are instances of the given + type. + + Parameters + ---------- + element_type : type + The class to test against (e.g., Magnet). + + Returns + ------- + list[bool] + A list of booleans where True indicates the element is an instance + of the given type (including subclasses). + """ + if not isinstance(element_type, type): + raise TypeError(f"{self.__class__.__name__}: element_type must be a type") + + return [isinstance(e, element_type) for e in self] + + def of_type(self, element_type: type): + """Return a new array containing only elements of the given type. + + The resulting array is automatically typed according to the most + specific common base class of the filtered elements. + + Parameters + ---------- + element_type : type + The class to filter by (e.g., Magnet). + + Returns + ------- + ElementArray or specialized array + An auto-typed array containing only matching elements. + Returns [] if no elements match. + """ + if not isinstance(element_type, type): + raise TypeError(f"{self.__class__.__name__}: element_type must be a type") + + filtered = [e for e in self if isinstance(e, element_type)] + return self.__auto_array(filtered) + + def exclude_type(self, element_type): + mask = self.mask_by_type(element_type) + return self - mask + def __getitem__(self, key): if isinstance(key, slice): # Slicing - eltType = None + element_type = None r = [] for i in range(*key.indices(len(self))): - if eltType is None: - eltType = type(self[i]) - elif not isinstance(self[i], eltType): - eltType = Element # Fall back to element + if element_type is None: + element_type = type(self[i]) + elif not isinstance(self[i], element_type): + element_type = Element # Fall back to element r.append(self[i]) - return self.__create_array("", eltType, r) + return self.__create_array("", element_type, r) elif isinstance(key, str): fields = key.split(":") if len(fields) <= 1: # Selection by name - eltType = None + element_type = None r = [] for e in self: if fnmatch.fnmatch(e.get_name(), key): - if eltType is None: - eltType = type(e) - elif not isinstance(e, eltType): - eltType = Element # Fall back to element + if element_type is None: + element_type = type(e) + elif not isinstance(e, element_type): + element_type = Element # Fall back to element r.append(e) else: # Selection by fields - eltType = None + element_type = None r = [] for e in self: txt = self.__eval_field(fields[0], e) if fnmatch.fnmatch(txt, fields[1]): - if eltType is None: - eltType = type(e) - elif not isinstance(e, eltType): - eltType = Element # Fall back to element + if element_type is None: + element_type = type(e) + elif not isinstance(e, element_type): + element_type = Element # Fall back to element r.append(e) - return self.__create_array("", eltType, r) + return self.__create_array("", element_type, r) else: # Default to super selection diff --git a/tests/test_arrays_ops.py b/tests/test_arrays_ops.py new file mode 100644 index 00000000..cdb41594 --- /dev/null +++ b/tests/test_arrays_ops.py @@ -0,0 +1,229 @@ +import numpy as np +import pytest + +from pyaml.accelerator import Accelerator +from pyaml.arrays.element_array import ElementArray +from pyaml.arrays.magnet_array import MagnetArray +from pyaml.configuration.factory import Factory +from pyaml.magnet.magnet import Magnet + + +@pytest.mark.parametrize( + "install_test_package", + [{"name": "tango-pyaml", "path": "tests/dummy_cs/tango-pyaml"}], + indirect=True, +) +def test_element_array_and_array_intersection_is_autotyped(install_test_package): + sr: Accelerator = Accelerator.load("tests/config/sr.yaml") + sr.design.get_lattice().disable_6d() + + hcorr = sr.live.get_magnets("HCORR") + hvcorr = sr.live.get_magnets("HVCORR") + + inter = hvcorr & hcorr + + assert isinstance(inter, MagnetArray) + assert inter.names() == hcorr.names() + + Factory.clear() + + +@pytest.mark.parametrize( + "install_test_package", + [{"name": "tango-pyaml", "path": "tests/dummy_cs/tango-pyaml"}], + indirect=True, +) +def test_element_array_and_mask_filters_and_is_autotyped_list_mask( + install_test_package, +): + sr: Accelerator = Accelerator.load("tests/config/sr.yaml") + sr.design.get_lattice().disable_6d() + + # "ElArray" is a mixed ElementArray in the dummy config (see existing tests) + elts = sr.design.get_elements("ElArray") + assert isinstance(elts, ElementArray) + assert len(elts) > 0 + + mask = [isinstance(e, Magnet) for e in elts] + res = elts & mask + + # Only magnets are kept -> result should be MagnetArray + assert isinstance(res, MagnetArray) + assert all(isinstance(e, Magnet) for e in res) + assert len(res) == sum(mask) + + Factory.clear() + + +@pytest.mark.parametrize( + "install_test_package", + [{"name": "tango-pyaml", "path": "tests/dummy_cs/tango-pyaml"}], + indirect=True, +) +def test_element_array_and_mask_filters_and_is_autotyped_numpy_mask( + install_test_package, +): + sr: Accelerator = Accelerator.load("tests/config/sr.yaml") + sr.design.get_lattice().disable_6d() + + elts = sr.design.get_elements("ElArray") + assert len(elts) > 0 + + mask_list = [isinstance(e, Magnet) for e in elts] + mask_np = np.array(mask_list, dtype=bool) + + res = elts & mask_np + + assert isinstance(res, MagnetArray) + assert all(isinstance(e, Magnet) for e in res) + assert len(res) == int(mask_np.sum()) + + Factory.clear() + + +@pytest.mark.parametrize( + "install_test_package", + [{"name": "tango-pyaml", "path": "tests/dummy_cs/tango-pyaml"}], + indirect=True, +) +def test_element_array_sub_mask_removes_true_inverse_of_and(install_test_package): + sr: Accelerator = Accelerator.load("tests/config/sr.yaml") + sr.design.get_lattice().disable_6d() + + elts = sr.design.get_elements("ElArray") + assert len(elts) > 0 + + # Keep only magnets with '& mask' + is_magnet = [isinstance(e, Magnet) for e in elts] + only_magnets = elts & is_magnet + assert all(isinstance(e, Magnet) for e in only_magnets) + + # Remove magnets with '- mask' (inverse operation) + without_magnets = elts - is_magnet + + # Result may be ElementArray or a more specific type depending on what's left, + # but it must not contain magnets. + if isinstance(without_magnets, list): + remaining = without_magnets + else: + remaining = list(without_magnets) + + assert all(not isinstance(e, Magnet) for e in remaining) + assert len(remaining) + len(only_magnets) == len(elts) + + Factory.clear() + + +@pytest.mark.parametrize( + "install_test_package", + [{"name": "tango-pyaml", "path": "tests/dummy_cs/tango-pyaml"}], + indirect=True, +) +def test_element_array_mask_length_mismatch_raises_for_and_and_sub( + install_test_package, +): + sr: Accelerator = Accelerator.load("tests/config/sr.yaml") + sr.design.get_lattice().disable_6d() + + elts = sr.design.get_elements("ElArray") + assert len(elts) > 0 + + bad_mask = [True] * (len(elts) - 1) + + with pytest.raises(ValueError): + _ = elts & bad_mask + + with pytest.raises(ValueError): + _ = elts - bad_mask + + Factory.clear() + + +@pytest.mark.parametrize( + "install_test_package", + [{"name": "tango-pyaml", "path": "tests/dummy_cs/tango-pyaml"}], + indirect=True, +) +def test_mask_by_type_returns_correct_boolean_mask(install_test_package): + sr: Accelerator = Accelerator.load("tests/config/sr.yaml") + sr.design.get_lattice().disable_6d() + + elts = sr.design.get_elements("ElArray") + mask = elts.mask_by_type(Magnet) + + assert isinstance(mask, list) + assert len(mask) == len(elts) + assert all(isinstance(v, bool) for v in mask) + + # Check semantic correctness + for e, m in zip(elts, mask, strict=True): + assert m == isinstance(e, Magnet) + + Factory.clear() + + +@pytest.mark.parametrize( + "install_test_package", + [{"name": "tango-pyaml", "path": "tests/dummy_cs/tango-pyaml"}], + indirect=True, +) +def test_filter_by_type_returns_autotyped_array(install_test_package): + sr: Accelerator = Accelerator.load("tests/config/sr.yaml") + sr.design.get_lattice().disable_6d() + + elts = sr.design.get_elements("ElArray") + filtered = elts.of_type(Magnet) + + if len(filtered) == 0: + assert filtered == [] + else: + assert isinstance(filtered, MagnetArray) + assert all(isinstance(e, Magnet) for e in filtered) + + +@pytest.mark.parametrize( + "install_test_package", + [{"name": "tango-pyaml", "path": "tests/dummy_cs/tango-pyaml"}], + indirect=True, +) +def test_element_array_or_union_is_unique_stable_and_autotyped(install_test_package): + sr: Accelerator = Accelerator.load("tests/config/sr.yaml") + sr.design.get_lattice().disable_6d() + + hcorr = sr.live.get_magnets("HCORR") + vcorr = sr.live.get_magnets("VCORR") + + u = hcorr | vcorr + + # auto-typed + assert isinstance(u, MagnetArray) + + # stable order: all hcorr first, then vcorr + assert u.names() == hcorr.names() + vcorr.names() + + # uniqueness: if you union with itself, no duplicates + uu = hcorr | hcorr + assert uu.names() == hcorr.names() + + Factory.clear() + + +@pytest.mark.parametrize( + "install_test_package", + [{"name": "tango-pyaml", "path": "tests/dummy_cs/tango-pyaml"}], + indirect=True, +) +def test_element_array_add_is_alias_of_union(install_test_package): + sr: Accelerator = Accelerator.load("tests/config/sr.yaml") + sr.design.get_lattice().disable_6d() + + hcorr = sr.live.get_magnets("HCORR") + vcorr = sr.live.get_magnets("VCORR") + + u1 = hcorr | vcorr + u2 = hcorr + vcorr + + assert isinstance(u2, MagnetArray) + assert u2.names() == u1.names() + + Factory.clear()