11import fnmatch
22import importlib
3+ from typing import Sequence
4+
5+ import numpy as np
36
47from ..bpm .bpm import BPM
58from ..common .element import Element
69from ..common .exception import PyAMLException
710from ..magnet .cfm_magnet import CombinedFunctionMagnet
811from ..magnet .magnet import Magnet
12+ from ..magnet .serialized_magnet import SerializedMagnets
913
1014
1115class ElementArray (list [Element ]):
@@ -33,12 +37,13 @@ class ElementArray(list[Element]):
3337
3438 """
3539
36- def __init__ (self , arrayName : str , elements : list [Element ], use_aggregator = True ):
40+ def __init__ (self , array_name : str , elements : list [Element ], use_aggregator = True ):
3741 super ().__init__ (i for i in elements )
38- self .__name = arrayName
42+ self .__name = array_name
43+ self .__peer = None
44+ self .__use_aggregator = use_aggregator
3945 if len (elements ) > 0 :
4046 self .__peer = self [0 ]._peer if len (self ) > 0 else None
41- self .__use_aggretator = use_aggregator
4247 if self .__peer is None or any ([m ._peer != self .__peer for m in self ]):
4348 raise PyAMLException (
4449 f"{ self .__class__ .__name__ } { self .get_name ()} : "
@@ -64,73 +69,261 @@ def names(self) -> list[str]:
6469 """
6570 return [e .get_name () for e in self ]
6671
67- def __create_array (self , arrName : str , eltType : type , elements : list ):
68- if len ( elements ) == 0 :
69- return []
72+ def __create_array (self , array_name : str , element_type : type , elements : list ):
73+ if element_type is None :
74+ element_type = Element
7075
71- if issubclass (eltType , Magnet ):
76+ if issubclass (element_type , Magnet ):
7277 m = importlib .import_module ("pyaml.arrays.magnet_array" )
73- arrayClass = getattr (m , "MagnetArray" , None )
74- return arrayClass ( "" , elements , self .__use_aggretator )
75- elif issubclass (eltType , BPM ):
78+ array_class = getattr (m , "MagnetArray" , None )
79+ return array_class ( array_name , elements , self .__use_aggregator )
80+ elif issubclass (element_type , BPM ):
7681 m = importlib .import_module ("pyaml.arrays.bpm_array" )
77- arrayClass = getattr (m , "BPMArray" , None )
78- return arrayClass ( "" , elements , self .__use_aggretator )
79- elif issubclass (eltType , CombinedFunctionMagnet ):
82+ array_class = getattr (m , "BPMArray" , None )
83+ return array_class ( array_name , elements , self .__use_aggregator )
84+ elif issubclass (element_type , CombinedFunctionMagnet ):
8085 m = importlib .import_module ("pyaml.arrays.cfm_magnet_array" )
81- arrayClass = getattr (m , "CombinedFunctionMagnetArray" , None )
82- return arrayClass ("" , elements , self .__use_aggretator )
83- elif issubclass (eltType , Element ):
84- return ElementArray ("" , elements , self .__use_aggretator )
86+ array_class = getattr (m , "CombinedFunctionMagnetArray" , None )
87+ return array_class (array_name , elements , self .__use_aggregator )
88+ elif issubclass (element_type , SerializedMagnets ):
89+ m = importlib .import_module ("pyaml.arrays.serialized_magnet_array" )
90+ array_class = getattr (m , "SerializedMagnetsArray" , None )
91+ return array_class (array_name , elements , self .__use_aggregator )
92+ elif issubclass (element_type , Element ):
93+ return ElementArray (array_name , elements , self .__use_aggregator )
8594 else :
86- raise PyAMLException (f"Unsupported sliced array for type { str (eltType )} " )
95+ raise PyAMLException (
96+ f"Unsupported sliced array for type { str (element_type )} "
97+ )
8798
88- def __eval_field (self , attName : str , e : Element ) -> str :
89- funcName = "get_" + attName
90- func = getattr (e , funcName , None )
99+ def __eval_field (self , attribute_name : str , element : Element ) -> str :
100+ function_name = "get_" + attribute_name
101+ func = getattr (element , function_name , None )
91102 return func () if func is not None else ""
92103
104+ def __ensure_compatible_operand (self , other : object ) -> "ElementArray" :
105+ """Validate the operand used for set-like operations between arrays."""
106+ if not isinstance (other , ElementArray ):
107+ raise TypeError (
108+ f"Unsupported operand type(s) for set operation: "
109+ f"'{ type (self ).__name__ } ' and '{ type (other ).__name__ } '"
110+ )
111+
112+ if len (self ) > 0 and len (other ) > 0 :
113+ if self .get_peer () is not None and other .get_peer () is not None :
114+ if self .get_peer () != other .get_peer ():
115+ raise PyAMLException (
116+ f"{ self .__class__ .__name__ } : cannot operate on arrays "
117+ "attached to different peers"
118+ )
119+ return other
120+
121+ def __auto_array (self , elements : list [Element ]):
122+ """Create the most specific array type for the given element list.
123+
124+ The target element type is the most specific common base class (nearest common
125+ ancestor) of all elements. This supports heterogeneous subclasses (e.g.,
126+ several Magnet subclasses) while still returning a MagnetArray when
127+ appropriate.
128+ """
129+ if len (elements ) == 0 :
130+ return []
131+
132+ import inspect
133+
134+ def mro_as_list (cls : type ) -> list [type ]:
135+ # inspect.getmro returns (cls, ..., object)
136+ return list (inspect .getmro (cls ))
137+
138+ # Start from the first element MRO as reference order (most specific first).
139+ common : list [type ] = mro_as_list (type (elements [0 ]))
140+
141+ # Intersect while preserving MRO order from the first element.
142+ for e in elements [1 :]:
143+ mro_set = set (mro_as_list (type (e )))
144+ common = [c for c in common if c in mro_set ]
145+ if not common :
146+ break
147+
148+ # Pick the first suitable common base within the Element hierarchy.
149+ chosen : type = Element
150+ for c in common :
151+ if c is object :
152+ continue
153+ if issubclass (c , Element ):
154+ chosen = c
155+ break
156+
157+ return self .__create_array ("" , chosen , elements )
158+
159+ def __is_bool_mask (self , other : object ) -> bool :
160+ """Return True if 'other' looks like a boolean mask (list or numpy array)."""
161+ # --- numpy boolean array ---
162+ try :
163+ if isinstance (other , np .ndarray ) and other .dtype == bool :
164+ return True
165+ except Exception :
166+ pass
167+
168+ # --- python sequence of bools (but not a string/bytes) ---
169+ if isinstance (other , Sequence ) and not isinstance (
170+ other , (str , bytes , bytearray )
171+ ):
172+ # Avoid treating ElementArray as a mask
173+ if isinstance (other , ElementArray ):
174+ return False
175+ # Accept only actual bool-like values
176+ try :
177+ return all (isinstance (v , bool ) for v in other )
178+ except TypeError :
179+ return False
180+
181+ return False
182+
183+ def __and__ (self , other : object ):
184+ """
185+ Intersection or boolean mask filtering.
186+
187+ - If other is ElementArray: intersection (based on element names)
188+ - If other is a boolean mask: keep elements where mask is True
189+ """
190+ # --- mask filtering ---
191+ if self .__is_bool_mask (other ):
192+ mask = list (other ) # works for list/tuple and numpy arrays
193+ if len (mask ) != len (self ):
194+ raise ValueError (
195+ f"{ self .__class__ .__name__ } : mask length ({ len (mask )} ) "
196+ f"does not match array length ({ len (self )} )"
197+ )
198+ res = [e for e , keep in zip (self , mask , strict = True ) if bool (keep )]
199+ return self .__auto_array (res )
200+
201+ # --- array intersection ---
202+ other_arr = self .__ensure_compatible_operand (other )
203+ other_names = {e .get_name () for e in other_arr }
204+ res = [e for e in self if e .get_name () in other_names ]
205+ return self .__auto_array (res )
206+
207+ def __rand__ (self , other : object ):
208+ # Support "array on the right" for array operands; for masks, we don't enforce
209+ # commutativity.
210+ if isinstance (other , ElementArray ):
211+ return other .__and__ (self )
212+ return NotImplemented
213+
214+ def __sub__ (self , other : object ):
215+ """
216+ Difference or boolean mask removal.
217+
218+ - If other is ElementArray: difference (based on element names)
219+ - If other is a boolean mask: remove elements where mask is True (inverse of
220+ '& mask')
221+ """
222+ # --- mask removal ---
223+ if self .__is_bool_mask (other ):
224+ mask = list (other )
225+ if len (mask ) != len (self ):
226+ raise ValueError (
227+ f"{ self .__class__ .__name__ } : mask length ({ len (mask )} ) "
228+ f"does not match array length ({ len (self )} )"
229+ )
230+ res = [e for e , remove in zip (self , mask , strict = True ) if not bool (remove )]
231+ return self .__auto_array (res )
232+
233+ # --- array difference ---
234+ other_arr = self .__ensure_compatible_operand (other )
235+ other_names = {e .get_name () for e in other_arr }
236+ res = [e for e in self if e .get_name () not in other_names ]
237+ return self .__auto_array (res )
238+
239+ def mask_by_type (self , element_type : type ) -> list [bool ]:
240+ """Return a boolean mask indicating which elements are instances of the given
241+ type.
242+
243+ Parameters
244+ ----------
245+ element_type : type
246+ The class to test against (e.g., Magnet).
247+
248+ Returns
249+ -------
250+ list[bool]
251+ A list of booleans where True indicates the element is an instance
252+ of the given type (including subclasses).
253+ """
254+ if not isinstance (element_type , type ):
255+ raise TypeError (f"{ self .__class__ .__name__ } : element_type must be a type" )
256+
257+ return [isinstance (e , element_type ) for e in self ]
258+
259+ def of_type (self , element_type : type ):
260+ """Return a new array containing only elements of the given type.
261+
262+ The resulting array is automatically typed according to the most
263+ specific common base class of the filtered elements.
264+
265+ Parameters
266+ ----------
267+ element_type : type
268+ The class to filter by (e.g., Magnet).
269+
270+ Returns
271+ -------
272+ ElementArray or specialized array
273+ An auto-typed array containing only matching elements.
274+ Returns [] if no elements match.
275+ """
276+ if not isinstance (element_type , type ):
277+ raise TypeError (f"{ self .__class__ .__name__ } : element_type must be a type" )
278+
279+ filtered = [e for e in self if isinstance (e , element_type )]
280+ return self .__auto_array (filtered )
281+
282+ def exclude_type (self , element_type ):
283+ mask = self .mask_by_type (element_type )
284+ return self - mask
285+
93286 def __getitem__ (self , key ):
94287 if isinstance (key , slice ):
95288 # Slicing
96- eltType = None
289+ element_type = None
97290 r = []
98291 for i in range (* key .indices (len (self ))):
99- if eltType is None :
100- eltType = type (self [i ])
101- elif not isinstance (self [i ], eltType ):
102- eltType = Element # Fall back to element
292+ if element_type is None :
293+ element_type = type (self [i ])
294+ elif not isinstance (self [i ], element_type ):
295+ element_type = Element # Fall back to element
103296 r .append (self [i ])
104- return self .__create_array ("" , eltType , r )
297+ return self .__create_array ("" , element_type , r )
105298
106299 elif isinstance (key , str ):
107300 fields = key .split (":" )
108301
109302 if len (fields ) <= 1 :
110303 # Selection by name
111- eltType = None
304+ element_type = None
112305 r = []
113306 for e in self :
114307 if fnmatch .fnmatch (e .get_name (), key ):
115- if eltType is None :
116- eltType = type (e )
117- elif not isinstance (e , eltType ):
118- eltType = Element # Fall back to element
308+ if element_type is None :
309+ element_type = type (e )
310+ elif not isinstance (e , element_type ):
311+ element_type = Element # Fall back to element
119312 r .append (e )
120313 else :
121314 # Selection by fields
122- eltType = None
315+ element_type = None
123316 r = []
124317 for e in self :
125318 txt = self .__eval_field (fields [0 ], e )
126319 if fnmatch .fnmatch (txt , fields [1 ]):
127- if eltType is None :
128- eltType = type (e )
129- elif not isinstance (e , eltType ):
130- eltType = Element # Fall back to element
320+ if element_type is None :
321+ element_type = type (e )
322+ elif not isinstance (e , element_type ):
323+ element_type = Element # Fall back to element
131324 r .append (e )
132325
133- return self .__create_array ("" , eltType , r )
326+ return self .__create_array ("" , element_type , r )
134327
135328 else :
136329 # Default to super selection
0 commit comments