Skip to content

Commit 1a107b4

Browse files
Add Set-Like Operations to ElementArray
1 parent 8a44fbb commit 1a107b4

File tree

2 files changed

+412
-38
lines changed

2 files changed

+412
-38
lines changed

pyaml/arrays/element_array.py

Lines changed: 231 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
import fnmatch
22
import importlib
3+
from typing import Sequence
4+
5+
import numpy as np
36

47
from ..bpm.bpm import BPM
58
from ..common.element import Element
69
from ..common.exception import PyAMLException
710
from ..magnet.cfm_magnet import CombinedFunctionMagnet
811
from ..magnet.magnet import Magnet
12+
from ..magnet.serialized_magnet import SerializedMagnets
913

1014

1115
class ElementArray(list[Element]):
@@ -33,12 +37,13 @@ class ElementArray(list[Element]):
3337
3438
"""
3539

36-
def __init__(self, arrayName: str, elements: list[Element], use_aggregator=True):
40+
def __init__(self, array_name: str, elements: list[Element], use_aggregator=True):
3741
super().__init__(i for i in elements)
38-
self.__name = arrayName
42+
self.__name = array_name
43+
self.__peer = None
44+
self.__use_aggregator = use_aggregator
3945
if len(elements) > 0:
4046
self.__peer = self[0]._peer if len(self) > 0 else None
41-
self.__use_aggretator = use_aggregator
4247
if self.__peer is None or any([m._peer != self.__peer for m in self]):
4348
raise PyAMLException(
4449
f"{self.__class__.__name__} {self.get_name()}: "
@@ -64,73 +69,261 @@ def names(self) -> list[str]:
6469
"""
6570
return [e.get_name() for e in self]
6671

67-
def __create_array(self, arrName: str, eltType: type, elements: list):
68-
if len(elements) == 0:
69-
return []
72+
def __create_array(self, array_name: str, element_type: type, elements: list):
73+
if element_type is None:
74+
element_type = Element
7075

71-
if issubclass(eltType, Magnet):
76+
if issubclass(element_type, Magnet):
7277
m = importlib.import_module("pyaml.arrays.magnet_array")
73-
arrayClass = getattr(m, "MagnetArray", None)
74-
return arrayClass("", elements, self.__use_aggretator)
75-
elif issubclass(eltType, BPM):
78+
array_class = getattr(m, "MagnetArray", None)
79+
return array_class(array_name, elements, self.__use_aggregator)
80+
elif issubclass(element_type, BPM):
7681
m = importlib.import_module("pyaml.arrays.bpm_array")
77-
arrayClass = getattr(m, "BPMArray", None)
78-
return arrayClass("", elements, self.__use_aggretator)
79-
elif issubclass(eltType, CombinedFunctionMagnet):
82+
array_class = getattr(m, "BPMArray", None)
83+
return array_class(array_name, elements, self.__use_aggregator)
84+
elif issubclass(element_type, CombinedFunctionMagnet):
8085
m = importlib.import_module("pyaml.arrays.cfm_magnet_array")
81-
arrayClass = getattr(m, "CombinedFunctionMagnetArray", None)
82-
return arrayClass("", elements, self.__use_aggretator)
83-
elif issubclass(eltType, Element):
84-
return ElementArray("", elements, self.__use_aggretator)
86+
array_class = getattr(m, "CombinedFunctionMagnetArray", None)
87+
return array_class(array_name, elements, self.__use_aggregator)
88+
elif issubclass(element_type, SerializedMagnets):
89+
m = importlib.import_module("pyaml.arrays.serialized_magnet_array")
90+
array_class = getattr(m, "SerializedMagnetsArray", None)
91+
return array_class(array_name, elements, self.__use_aggregator)
92+
elif issubclass(element_type, Element):
93+
return ElementArray(array_name, elements, self.__use_aggregator)
8594
else:
86-
raise PyAMLException(f"Unsupported sliced array for type {str(eltType)}")
95+
raise PyAMLException(
96+
f"Unsupported sliced array for type {str(element_type)}"
97+
)
8798

88-
def __eval_field(self, attName: str, e: Element) -> str:
89-
funcName = "get_" + attName
90-
func = getattr(e, funcName, None)
99+
def __eval_field(self, attribute_name: str, element: Element) -> str:
100+
function_name = "get_" + attribute_name
101+
func = getattr(element, function_name, None)
91102
return func() if func is not None else ""
92103

104+
def __ensure_compatible_operand(self, other: object) -> "ElementArray":
105+
"""Validate the operand used for set-like operations between arrays."""
106+
if not isinstance(other, ElementArray):
107+
raise TypeError(
108+
f"Unsupported operand type(s) for set operation: "
109+
f"'{type(self).__name__}' and '{type(other).__name__}'"
110+
)
111+
112+
if len(self) > 0 and len(other) > 0:
113+
if self.get_peer() is not None and other.get_peer() is not None:
114+
if self.get_peer() != other.get_peer():
115+
raise PyAMLException(
116+
f"{self.__class__.__name__}: cannot operate on arrays "
117+
"attached to different peers"
118+
)
119+
return other
120+
121+
def __auto_array(self, elements: list[Element]):
122+
"""Create the most specific array type for the given element list.
123+
124+
The target element type is the most specific common base class (nearest common
125+
ancestor) of all elements. This supports heterogeneous subclasses (e.g.,
126+
several Magnet subclasses) while still returning a MagnetArray when
127+
appropriate.
128+
"""
129+
if len(elements) == 0:
130+
return []
131+
132+
import inspect
133+
134+
def mro_as_list(cls: type) -> list[type]:
135+
# inspect.getmro returns (cls, ..., object)
136+
return list(inspect.getmro(cls))
137+
138+
# Start from the first element MRO as reference order (most specific first).
139+
common: list[type] = mro_as_list(type(elements[0]))
140+
141+
# Intersect while preserving MRO order from the first element.
142+
for e in elements[1:]:
143+
mro_set = set(mro_as_list(type(e)))
144+
common = [c for c in common if c in mro_set]
145+
if not common:
146+
break
147+
148+
# Pick the first suitable common base within the Element hierarchy.
149+
chosen: type = Element
150+
for c in common:
151+
if c is object:
152+
continue
153+
if issubclass(c, Element):
154+
chosen = c
155+
break
156+
157+
return self.__create_array("", chosen, elements)
158+
159+
def __is_bool_mask(self, other: object) -> bool:
160+
"""Return True if 'other' looks like a boolean mask (list or numpy array)."""
161+
# --- numpy boolean array ---
162+
try:
163+
if isinstance(other, np.ndarray) and other.dtype == bool:
164+
return True
165+
except Exception:
166+
pass
167+
168+
# --- python sequence of bools (but not a string/bytes) ---
169+
if isinstance(other, Sequence) and not isinstance(
170+
other, (str, bytes, bytearray)
171+
):
172+
# Avoid treating ElementArray as a mask
173+
if isinstance(other, ElementArray):
174+
return False
175+
# Accept only actual bool-like values
176+
try:
177+
return all(isinstance(v, bool) for v in other)
178+
except TypeError:
179+
return False
180+
181+
return False
182+
183+
def __and__(self, other: object):
184+
"""
185+
Intersection or boolean mask filtering.
186+
187+
- If other is ElementArray: intersection (based on element names)
188+
- If other is a boolean mask: keep elements where mask is True
189+
"""
190+
# --- mask filtering ---
191+
if self.__is_bool_mask(other):
192+
mask = list(other) # works for list/tuple and numpy arrays
193+
if len(mask) != len(self):
194+
raise ValueError(
195+
f"{self.__class__.__name__}: mask length ({len(mask)}) "
196+
f"does not match array length ({len(self)})"
197+
)
198+
res = [e for e, keep in zip(self, mask, strict=True) if bool(keep)]
199+
return self.__auto_array(res)
200+
201+
# --- array intersection ---
202+
other_arr = self.__ensure_compatible_operand(other)
203+
other_names = {e.get_name() for e in other_arr}
204+
res = [e for e in self if e.get_name() in other_names]
205+
return self.__auto_array(res)
206+
207+
def __rand__(self, other: object):
208+
# Support "array on the right" for array operands; for masks, we don't enforce
209+
# commutativity.
210+
if isinstance(other, ElementArray):
211+
return other.__and__(self)
212+
return NotImplemented
213+
214+
def __sub__(self, other: object):
215+
"""
216+
Difference or boolean mask removal.
217+
218+
- If other is ElementArray: difference (based on element names)
219+
- If other is a boolean mask: remove elements where mask is True (inverse of
220+
'& mask')
221+
"""
222+
# --- mask removal ---
223+
if self.__is_bool_mask(other):
224+
mask = list(other)
225+
if len(mask) != len(self):
226+
raise ValueError(
227+
f"{self.__class__.__name__}: mask length ({len(mask)}) "
228+
f"does not match array length ({len(self)})"
229+
)
230+
res = [e for e, remove in zip(self, mask, strict=True) if not bool(remove)]
231+
return self.__auto_array(res)
232+
233+
# --- array difference ---
234+
other_arr = self.__ensure_compatible_operand(other)
235+
other_names = {e.get_name() for e in other_arr}
236+
res = [e for e in self if e.get_name() not in other_names]
237+
return self.__auto_array(res)
238+
239+
def mask_by_type(self, element_type: type) -> list[bool]:
240+
"""Return a boolean mask indicating which elements are instances of the given
241+
type.
242+
243+
Parameters
244+
----------
245+
element_type : type
246+
The class to test against (e.g., Magnet).
247+
248+
Returns
249+
-------
250+
list[bool]
251+
A list of booleans where True indicates the element is an instance
252+
of the given type (including subclasses).
253+
"""
254+
if not isinstance(element_type, type):
255+
raise TypeError(f"{self.__class__.__name__}: element_type must be a type")
256+
257+
return [isinstance(e, element_type) for e in self]
258+
259+
def of_type(self, element_type: type):
260+
"""Return a new array containing only elements of the given type.
261+
262+
The resulting array is automatically typed according to the most
263+
specific common base class of the filtered elements.
264+
265+
Parameters
266+
----------
267+
element_type : type
268+
The class to filter by (e.g., Magnet).
269+
270+
Returns
271+
-------
272+
ElementArray or specialized array
273+
An auto-typed array containing only matching elements.
274+
Returns [] if no elements match.
275+
"""
276+
if not isinstance(element_type, type):
277+
raise TypeError(f"{self.__class__.__name__}: element_type must be a type")
278+
279+
filtered = [e for e in self if isinstance(e, element_type)]
280+
return self.__auto_array(filtered)
281+
282+
def exclude_type(self, element_type):
283+
mask = self.mask_by_type(element_type)
284+
return self - mask
285+
93286
def __getitem__(self, key):
94287
if isinstance(key, slice):
95288
# Slicing
96-
eltType = None
289+
element_type = None
97290
r = []
98291
for i in range(*key.indices(len(self))):
99-
if eltType is None:
100-
eltType = type(self[i])
101-
elif not isinstance(self[i], eltType):
102-
eltType = Element # Fall back to element
292+
if element_type is None:
293+
element_type = type(self[i])
294+
elif not isinstance(self[i], element_type):
295+
element_type = Element # Fall back to element
103296
r.append(self[i])
104-
return self.__create_array("", eltType, r)
297+
return self.__create_array("", element_type, r)
105298

106299
elif isinstance(key, str):
107300
fields = key.split(":")
108301

109302
if len(fields) <= 1:
110303
# Selection by name
111-
eltType = None
304+
element_type = None
112305
r = []
113306
for e in self:
114307
if fnmatch.fnmatch(e.get_name(), key):
115-
if eltType is None:
116-
eltType = type(e)
117-
elif not isinstance(e, eltType):
118-
eltType = Element # Fall back to element
308+
if element_type is None:
309+
element_type = type(e)
310+
elif not isinstance(e, element_type):
311+
element_type = Element # Fall back to element
119312
r.append(e)
120313
else:
121314
# Selection by fields
122-
eltType = None
315+
element_type = None
123316
r = []
124317
for e in self:
125318
txt = self.__eval_field(fields[0], e)
126319
if fnmatch.fnmatch(txt, fields[1]):
127-
if eltType is None:
128-
eltType = type(e)
129-
elif not isinstance(e, eltType):
130-
eltType = Element # Fall back to element
320+
if element_type is None:
321+
element_type = type(e)
322+
elif not isinstance(e, element_type):
323+
element_type = Element # Fall back to element
131324
r.append(e)
132325

133-
return self.__create_array("", eltType, r)
326+
return self.__create_array("", element_type, r)
134327

135328
else:
136329
# Default to super selection

0 commit comments

Comments
 (0)