diff --git a/.gitignore b/.gitignore index 34d7f904..4c3dd440 100644 --- a/.gitignore +++ b/.gitignore @@ -65,3 +65,6 @@ examples/.* test/flatfile2.txt .vscode/ + +# Local file: Misalign Analysis sub-module +apsuite/optics_analysis/misalign_analysis/Default_Buttons.pickle diff --git a/apsuite/__init__.py b/apsuite/__init__.py index f62abb29..0341eda2 100644 --- a/apsuite/__init__.py +++ b/apsuite/__init__.py @@ -5,5 +5,5 @@ __version__ = _f.read().strip() __all__ = ( - 'commisslib', 'loco', 'optics_analysis', 'optimization', - 'trackcpp_utils') + 'commisslib', 'dynap', 'loco', 'optics_analysis', 'optimization', + 'orbcorr', 'trackcpp_utils') diff --git a/apsuite/commisslib/meas_bpms_signals.py b/apsuite/commisslib/meas_bpms_signals.py index c6ff7286..9aa7a22a 100644 --- a/apsuite/commisslib/meas_bpms_signals.py +++ b/apsuite/commisslib/meas_bpms_signals.py @@ -20,44 +20,44 @@ def __init__(self): self.trigbpm_delay = None self.trigbpm_nrpulses = 1 self.do_pulse_evg = True - self._timing_event = 'Study' + self._timing_event = "Study" self.event_delay = None - self.event_mode = 'External' + self.event_mode = "External" self.timeout = 40 self.nrpoints_before = 0 self.nrpoints_after = 20000 - self.acq_rate = 'FAcq' + self.acq_rate = "FAcq" self.acq_repeat = False - self.signals2acq = 'XY' + self.signals2acq = "XY" def __str__(self): """.""" - ftmp = '{0:26s} = {1:9.6f} {2:s}\n'.format - dtmp = '{0:26s} = {1:9d} {2:s}\n'.format - stmp = '{0:26s} = {1:9} {2:s}\n'.format - stg = '' + ftmp = "{0:26s} = {1:9.6f} {2:s}\n".format + dtmp = "{0:26s} = {1:9d} {2:s}\n".format + stmp = "{0:26s} = {1:9} {2:s}\n".format + stg = "" dly = self.trigbpm_delay if dly is None: stg += stmp( 'trigbpm_delay', 'same', '(current value will not be changed)') else: - stg += ftmp('trigbpm_delay', dly, '[us]') - stg += dtmp('trigbpm_nrpulses', self.trigbpm_nrpulses, '') - stg += stmp('do_pulse_evg', str(self.do_pulse_evg), '') - stg += stmp('timing_event', self.timing_event, '') + stg += ftmp("trigbpm_delay", dly, "[us]") + stg += dtmp("trigbpm_nrpulses", self.trigbpm_nrpulses, "") + stg += stmp("do_pulse_evg", str(self.do_pulse_evg), "") + stg += stmp("timing_event", self.timing_event, "") dly = self.event_delay if dly is None: stg += stmp( 'event_delay', 'same', '(current value will not be changed)') else: - stg += ftmp('event_delay', dly, '[us]') - stg += stmp('event_mode', self.event_mode, '') - stg += ftmp('timeout', self.timeout, '[s]') - stg += dtmp('nrpoints_before', self.nrpoints_before, '') - stg += dtmp('nrpoints_after', self.nrpoints_after, '') - stg += stmp('acq_rate', self.acq_rate, '') - stg += dtmp('acq_repeat', self.acq_repeat, '') - stg += stmp('signals2acq', str(self.signals2acq), '') + stg += ftmp("event_delay", dly, "[us]") + stg += stmp("event_mode", self.event_mode, "") + stg += ftmp("timeout", self.timeout, "[s]") + stg += dtmp("nrpoints_before", self.nrpoints_before, "") + stg += dtmp("nrpoints_after", self.nrpoints_after, "") + stg += stmp("acq_rate", self.acq_rate, "") + stg += dtmp("acq_repeat", self.acq_repeat, "") + stg += stmp("signals2acq", str(self.signals2acq), "") return stg @property @@ -75,8 +75,8 @@ def from_dict(self, params_dict): """.""" dic = dict() for key, val in params_dict.items(): - if key.startswith('orbit_'): # compatibility with old data - key = key.replace('orbit_', '') + if key.startswith("orbit_"): # compatibility with old data + key = key.replace("orbit_", "") dic[key] = val return super().from_dict(dic) @@ -84,8 +84,8 @@ def from_dict(self, params_dict): class AcqBPMsSignals(_BaseClass): """.""" - BPM_TRIGGER = 'SI-Fam:TI-BPM' - PSM_TRIGGER = 'SI-Fam:TI-BPM-PsMtm' + BPM_TRIGGER = "SI-Fam:TI-BPM" + PSM_TRIGGER = "SI-Fam:TI-BPM-PsMtm" def __init__(self, isonline=True, ispost_mortem=False): """.""" @@ -113,42 +113,44 @@ def load_and_apply(self, fname: str): ret = super().load_and_apply(fname) data = dict() for key, val in self.data.items(): - if key.startswith('bpms_'): # compatibility with old data - key = key.replace('bpms_', '') + if key.startswith("bpms_"): # compatibility with old data + key = key.replace("bpms_", "") data[key] = val self.data = data return ret def create_devices(self): """.""" - self.devices['currinfo'] = CurrInfoSI() - self.devices['fambpms'] = FamBPMs( - devname=FamBPMs.DEVICES.SI, ispost_mortem=self._ispost_mortem, - props2init='acq') - self.devices['tune'] = Tune(Tune.DEVICES.SI) + self.devices["currinfo"] = CurrInfoSI() + self.devices["fambpms"] = FamBPMs( + devname=FamBPMs.DEVICES.SI, + ispost_mortem=self._ispost_mortem, + props2init="acq", + ) + self.devices["tune"] = Tune(Tune.DEVICES.SI) trigname = self.BPM_TRIGGER if self._ispost_mortem: trigname = self.PSM_TRIGGER - self.devices['trigbpm'] = Trigger(trigname) - self.devices['evt_study'] = Event('Study') - self.devices['evg'] = EVG() - self.devices['rfgen'] = RFGen() + self.devices["trigbpm"] = Trigger(trigname) + self.devices["evt_study"] = Event("Study") + self.devices["evg"] = EVG() + self.devices["rfgen"] = RFGen() def get_timing_state(self): """.""" - trigbpm = self.devices['trigbpm'] + trigbpm = self.devices["trigbpm"] state = dict() - state['trigbpm_source'] = trigbpm.source - state['trigbpm_nrpulses'] = trigbpm.nr_pulses - state['trigbpm_delay'] = trigbpm.delay + state["trigbpm_source"] = trigbpm.source + state["trigbpm_nrpulses"] = trigbpm.nr_pulses + state["trigbpm_delay"] = trigbpm.delay if self.params.do_pulse_evg: - state['evg_nrpulses'] = self.devices['evg'].nrpulses + state["evg_nrpulses"] = self.devices["evg"].nrpulses evt = self._get_event(self.params.timing_event) if evt is not None: - state['evt_delay'] = evt.delay - state['evt_mode'] = evt.mode + state["evt_delay"] = evt.delay + state["evt_mode"] = evt.mode return state def recover_timing_state(self, state): @@ -159,27 +161,28 @@ def prepare_timing(self, state=None): """.""" state = dict() if state is None else state - trigbpm = self.devices['trigbpm'] - dly = state.get('trigbpm_delay', self.params.trigbpm_delay) + trigbpm = self.devices["trigbpm"] + dly = state.get("trigbpm_delay", self.params.trigbpm_delay) if dly is not None: trigbpm.delay = dly trigbpm.nr_pulses = state.get( - 'trigbpm_nrpulses', self.params.trigbpm_nrpulses) - src = state.get('trigbpm_source', self.params.timing_event) + "trigbpm_nrpulses", self.params.trigbpm_nrpulses + ) + src = state.get("trigbpm_source", self.params.timing_event) trigbpm.source = src evt = self._get_event(self.params.timing_event) if evt is not None: - dly = state.get('evt_delay', self.params.event_delay) + dly = state.get("evt_delay", self.params.event_delay) if dly is not None: evt.delay = dly - evt.mode = state.get('evt_mode', self.params.event_mode) + evt.mode = state.get("evt_mode", self.params.event_mode) nrpul = 1 if self.params.do_pulse_evg else None - nrpul = state.get('evg_nrpulses', nrpul) + nrpul = state.get("evg_nrpulses", nrpul) if nrpul is not None: - evg = self.devices['evg'] + evg = self.devices["evg"] evg.set_nrpulses(nrpul) evg.cmd_update_events() @@ -188,30 +191,32 @@ def trigger_timing_signal(self): if not self.params.do_pulse_evg: return evt = self._get_event(self.params.timing_event) - if evt is not None and evt.mode_str == 'External': + if evt is not None and evt.mode_str == "External": evt.cmd_external_trigger() else: - self.devices['evg'].cmd_turn_on_injection() + self.devices["evg"].cmd_turn_on_injection() def prepare_bpms_acquisition(self): """.""" - fambpms = self.devices['fambpms'] + fambpms = self.devices["fambpms"] prms = self.params fambpms.mturn_signals2acq = self.params.signals2acq return fambpms.config_mturn_acquisition( nr_points_after=prms.nrpoints_after, nr_points_before=prms.nrpoints_before, - acq_rate=prms.acq_rate, repeat=prms.acq_repeat) + acq_rate=prms.acq_rate, + repeat=prms.acq_repeat, + ) def acquire_data(self): """.""" - fambpms = self.devices['fambpms'] + fambpms = self.devices["fambpms"] ret = self.prepare_bpms_acquisition() tag = self._bpm_tag(idx=abs(int(ret))-1) if ret < 0: - print(tag + ' did not finish last acquisition.') + print(tag + " did not finish last acquisition.") elif ret > 0: - print(tag + ' is not ready for acquisition.') + print(tag + " is not ready for acquisition.") fambpms.reset_mturn_initial_state() self.trigger_timing_signal() @@ -234,21 +239,22 @@ def acquire_data(self): def get_data(self): """Get Orbit and auxiliary data.""" - fbpms = self.devices['fambpms'] + fbpms = self.devices["fambpms"] mturn_orbit = fbpms.get_mturn_signals() data = dict() - data['ispost_mortem'] = self._ispost_mortem - data['timestamp'] = _time.time() - rf_freq = self.devices['rfgen'].frequency - data['rf_frequency'] = rf_freq - data['stored_current'] = self.devices['currinfo'].current + data["ispost_mortem"] = self._ispost_mortem + data["timestamp"] = _time.time() + rf_freq = self.devices["rfgen"].frequency + data["rf_frequency"] = rf_freq + data["stored_current"] = self.devices["currinfo"].current if list(self.params.signals2acq) != list(fbpms.mturn_signals2acq): - raise ValueError('signals2acq was not configured properly.') + raise ValueError("signals2acq was not configured properly.") elif len(mturn_orbit) != len(fbpms.mturn_signals2acq): raise ValueError( - 'Lenght of signals2acq does not match signals acquired.') + "Lenght of signals2acq does not match signals acquired." + ) for i, sig in enumerate(self.params.signals2acq): sig = sig.lower() name = 'sumdata' @@ -260,23 +266,24 @@ def get_data(self): name = 'posq' data[name] = mturn_orbit[i] - tune = self.devices['tune'] - data['tunex'], data['tuney'] = tune.tunex, tune.tuney + tune = self.devices["tune"] + data["tunex"], data["tuney"] = tune.tunex, tune.tuney bpm0 = fbpms.devices[0] - data['acq_rate'] = bpm0.acq_channel_str - data['sampling_frequency'] = fbpms.get_sampling_frequency(rf_freq) - data['nrsamples_pre'] = bpm0.acq_nrsamples_pre - data['nrsamples_post'] = bpm0.acq_nrsamples_post - data['trig_delay_raw'] = self.devices['trigbpm'].delay_raw - data['switching_mode'] = bpm0.switching_mode_str - data['switching_frequency'] = fbpms.get_switching_frequency(rf_freq) - data['tunex_enable'] = tune.enablex - data['tuney_enable'] = tune.enabley + data["acq_rate"] = bpm0.acq_channel_str + data["sampling_frequency"] = fbpms.get_sampling_frequency(rf_freq) + data["nrsamples_pre"] = bpm0.acq_nrsamples_pre + data["nrsamples_post"] = bpm0.acq_nrsamples_post + data["trig_delay_raw"] = self.devices["trigbpm"].delay_raw + data["switching_mode"] = bpm0.switching_mode_str + data["switching_frequency"] = fbpms.get_switching_frequency(rf_freq) + data["tunex_enable"] = tune.enablex + data["tuney_enable"] = tune.enabley return data @staticmethod def filter_data_frequencies( - orb, fmin, fmax, fsampling, keep_within_range=True): + orb, fmin, fmax, fsampling, keep_within_range=True + ): """Filter acquisition matrix considering a frequency range. Args: @@ -292,7 +299,7 @@ def filter_data_frequencies( """ dft = _sp_fft.rfft(orb, axis=0) - freq = _sp_fft.rfftfreq(orb.shape[0], d=1/fsampling) + freq = _sp_fft.rfftfreq(orb.shape[0], d=1 / fsampling) if keep_within_range: idcs = (freq < fmin) | (freq > fmax) dft[idcs] = 0 @@ -316,7 +323,7 @@ def filter_switching_cycles(orb, freq_sampling, freq_switching): """ # Calculate the number of samples per switching cycle - sw_sample_size = round(freq_sampling/freq_switching) + sw_sample_size = round(freq_sampling / freq_switching) osiz = orb.shape[0] nr_sws = osiz // sw_sample_size siz = nr_sws * sw_sample_size @@ -330,7 +337,7 @@ def filter_switching_cycles(orb, freq_sampling, freq_switching): # Replicate the switching signature to match the size of original data sw_pert = _np.tile(sw_sig, (1, nr_sws)) if osiz > siz: - sw_pert = _np.hstack([sw_pert, sw_sig[:, :osiz-siz]]) + sw_pert = _np.hstack([sw_pert, sw_sig[:, : osiz - siz]]) # Subtract the replicated switching signature from the original data return orb - sw_pert.T @@ -404,18 +411,18 @@ def calc_hilbert_transform(data, axis=0): return _sp_sig.hilbert(data, axis=axis) def _bpm_tag(self, idx): - names = self.devices['fambpms'].bpm_names - return f'{names[idx]:s} (idx={idx:d})' + names = self.devices["fambpms"].bpm_names + return f"{names[idx]:s} (idx={idx:d})" def _get_event(self, evtname): if evtname not in _HLTimeSearch.get_configurable_hl_events(): - print('WARN:Event is not configurable.') + print("WARN:Event is not configurable.") return None - stg = f'evt_{evtname.lower():s}' + stg = f"evt_{evtname.lower():s}" evt = self.devices.get(stg, Event(evtname)) if evt.wait_for_connection(timeout=10): self.devices[stg] = evt else: - print('ERR:Event not connected.') + print("ERR:Event not connected.") return None return evt diff --git a/apsuite/optics_analysis/misalign_analysis/__init__.py b/apsuite/optics_analysis/misalign_analysis/__init__.py new file mode 100644 index 00000000..8f85e73c --- /dev/null +++ b/apsuite/optics_analysis/misalign_analysis/__init__.py @@ -0,0 +1,22 @@ +"""Misalign analysis package.""" + +from . import fitting, functions as functions, si_data +from .base import Base, delete_default_base, get_default_base, \ + save_default_base, set_model +from .buttons import Button +from .si_data import get_model + +del base, buttons + +__all__ = [ + "Base", + "Button", + "functions", + "fitting", + "si_data", + "get_model", + "set_model", + "get_default_base", + "save_default_base", + "delete_default_base", +] diff --git a/apsuite/optics_analysis/misalign_analysis/about.md b/apsuite/optics_analysis/misalign_analysis/about.md new file mode 100644 index 00000000..43fc8365 --- /dev/null +++ b/apsuite/optics_analysis/misalign_analysis/about.md @@ -0,0 +1,82 @@ +# Pynel package for vertical dispersion analysis + +The pynel package has 2 main objects: Base and Button. These objects have properties that envolves the analysis of vertical dispersion function of the SIRIUS storage ring and its signatures associated to magnets tranversal and rotation misalignments. + +## Object Button + +The Button object associates one kind of error (transversal or rotation misalignment: x, y, roll, pitch and yaw) to one magnet of the SIRIUS ring and store the vertical dispersion signature caused by the magnet and the error choosen. + +#### The creation +The creation of a Button follows 2 possible ways. +- 1st. Passing 3 arguments: the magnet name, the magnet sector and the error associated. \ +Example: ``` qfa_sect5_dx = Button(sect=5, name="QFA", dtype="dx")``` -> creates a quadrupole QFA Button located in the 5th sector with tranversal horizontal misalignment error. +- 2nd. Passing 2 arguments: the magnet indices in the SIRIUS _"pymodels"_ model, and the error associated. \ +Example: ``` sfa1_sect1_dr = Button(indices=[74], dtype="dx")``` -> creates a sextupole SFA1 Button located in the 1st sector with rotation roll error. +For the 2nd option to create a Button, its necessary to check if Pymodels is up-to-date. + +#### About the arguments +- ```name```: the creation of any Button requires its name (when not passing "indices" arg) that is any magnet family name that exists in the SIRIUS ring: _"B1, B2, BC, Q1, Q2 ... QFA ... QDB2 ... SFA1 ... SFP2 ... SDB3"_. \ +If a inexistent family name is passed (or any other random thing) the Button will still be generated, but invalid, that means the Button will not compute and/or store a Vertical Dispersion Signature. \ +Obs.: In sectors that have more then one magnet of the same family (like dipoles B1 and B2 or quadupoles Q1, Q2...) the ```name``` argument can specify the precise magnet wanted: ```b1_1_sect7_dy = Button(7, "B1_1", "dy")```. + +- ```sect```: the creation of any Button requires its sector (when not passing "indices" arg). The sectors are restrict to integer numbers from 1 to 20 (the real sectors in SIRIUS). \ +Like the ```name``` argument, if a non-existent sector is passed the Button will be generated, but invalid. + +- ```indices```: the creation of any Button can be made by passing the indices of the magnet in the model. Is recommended to be careful when creating dipoles Buttons by its indices or when working with a refined model (non single integer as its indices). + +- ```dtype```: the creation of any Button always requires its error associated. The error are restricted to: + - ```dx``` - Horizontal misalignment error + - ```dy``` - Vertical misalignment error + - ```dr``` - Rotation roll error + - ```drp``` - Rotation pitch error + - ```dry``` - Rotation yaw error + +#### About the properties +- ```.bname``` = the magnet family (button) name +- ```.fantasy_name``` = the magnet specified name +- ```.sect``` = the magnet sector location +- ```.sectype``` = the magnet sector type (like: "HighBetaA to LowBetaB") +- ```.dtype``` = the error associated +- ```.signature``` = the vertical dispersion signature of the magnet with the error associated +- ```.indices``` = the indices of the magnet in the model + +## Object Base +The Base object generate Buttons and construct a VDRM matrix (Vertical Dispersion Response Matrix) of the signatures of the buttons. Base objects can be use to easily creates sets of specified kind of magnets or errors or sectors and study with more detail the behavior of the Vertical Dispersion in these combinations. + +#### The creation +The creation of a Base follows 2 possible ways. +- 1st. Passing 3 arguments: the elements (magnets names/families), the sectors and the errors. \ +Example: ```base = Base(sects=[1,7,13,20], elements=["SDB3", "BC"], dtypes=["dy", "dr"])``` -> creates a Base with the dipoles BC and sextupoles SDB3 in the sectors 1, 7, 13 and 20 with the vertical misalignment and rotation roll errors. \ +Obs.: if any magnet not exists in the sector, the button will be discarded. (Except when controlling the ```default_valids``` arg) +- 2nd. Passing 1 arg: already generated buttons. Example: ``` base = Base(buttons=list_of_buttons)``` -> creates a Base with the buttons of the ```list_of_buttons```. + +#### About the arguments +- ```elements```: the magnets family/fantasy names -> The valid "elements" follows the Buttons valid "names"/```bnames``` +- ```sects``` : list of integers. The valid sectors are the integers from 1 to 20. +- ```dtypes``` : the misalignment/rotation errors. The valid "dtypes" follows the Buttons valid ```dtype```(s) +- ```buttons``` : should be a list of Button objects +- ```default_valids```: allows to control wether Button is considered valid or invalid. And if the Button will be or not discarded in the gen process. +- ```force_rebuild``` : force the calculation of the vertical dispersion signature of the buttons +- ```func```: specify the function to be calculated as the buttons signatures. valid "funcs" are: `vertical_disp` or `testfunc`. The "testfunc" simply sets the buttons signatures as a zero-arrays (used for sandbox creation of buttons) + +#### About the properties +- ```.magnets``` = the buttons magnet families +- ```.named_magnets``` = the buttons magnet fantasy (specified) names +- ```.sectors``` = the sectors of the Base +- ```.sector_types``` = the sector types (like: "HighBetaA to LowBetaB") of the sectors of the Base +- ```.dtypes``` = the errors of the Base +- ```.buttons``` = the buttons of the Base +- ```.resp_mat``` = the constructed VDRM (Vertical Dispersion Response Matrix) + +## The "fitting" module + +The fitting module contains functions to fit vertical dispertion function (like real data collected in the machine) in _pymodels_ models. +Obs.: this module is outdated (~ september 19, 2023). The functions shouldnt work as expected. + +## The "misc_functions"/"functions" module + +Contains functions to work with Base and Buttons objects and deal with vertical dispertion fittings and analysis. + +## The "std_si_Data" module + +Contains saved data of the Standard SIRIUS model. \ No newline at end of file diff --git a/apsuite/optics_analysis/misalign_analysis/base.py b/apsuite/optics_analysis/misalign_analysis/base.py new file mode 100644 index 00000000..26fef79d --- /dev/null +++ b/apsuite/optics_analysis/misalign_analysis/base.py @@ -0,0 +1,325 @@ +"""Module 'base' for the class object 'Base': a collection of 'Buttons'.""" + +from copy import deepcopy as _dpcopy +from itertools import product + +import numpy as _np +from mathphys.functions import load_pickle, save_pickle + +from . import buttons +from .si_data import si_elems, si_sectors, std_misaligment_types + +_STD_ELEMS = si_elems() +_STD_TYPES = std_misaligment_types() +_STD_SECTS = si_sectors() +_D_BUTTONS_FILE = "Default_Buttons.pickle" +_DEFAULT_BUTTONS = [] + + +def get_default_base(): + """.""" + __load_default_buttons() + return globals()["_DEFAULT_BUTTONS"] + + +def set_model(model=None): + """.""" + buttons.buttons_set_model(model) + + +def __load_default_buttons(): + try: + globals()["_DEFAULT_BUTTONS"] = load_pickle(_D_BUTTONS_FILE) + except FileNotFoundError: + globals()["_DEFAULT_BUTTONS"] = [] + + +def update_default_base(): + """Update the Default Base.""" + __load_default_buttons() + + +class Base: + """Base object for misalign analysis. + + Args: + elems (list[str], str): List of magnets families names.\ + The valid options are the SIRIUS's dipoles, quadrupoles and\ + sextupoles. + + sects (list[int], int): List of sectors. Defaults to "all": \ + list of 1 to 20. + + dtypes (list[str], str): List of misalignmente types. Defaults to \ + "all": ['dx', 'dy', 'dr', 'drp', 'dry']. + + buttons (list[Button], optional): List of Button objects. + + func (str): The analysis function. Defaults to "vertical_disp". \ + The valid options are: 'vertical_disp' and 'testfunc'. + + use_root_buttons (bool, optional): Use default pre-saved buttons. + + About: + Passing args 'elems' + 'sects' require the arg 'buttons' be None. + + The arg 'elems' can be (str) 'all', that contains all the magnets. + + The arg 'func' can be (str) 'testfunc': the button signature will\ + be all zero arrays. + + The 'use_root_buttons' arg sets if the Base creation will or not \ + use pre saved buttons and its signatures. + """ + + def __init__( + self, + elems="all", + sects="all", + dtypes="all", + buttons=None, + func="vertical_disp", + use_root_buttons=True, + ): + """.""" + self._func = None + self._use_root = None + + self._func, self._use_root = self.__handle_input( + elems, sects, buttons, func, use_root_buttons + ) + + if buttons is None: + self.__force_init(elems, sects, dtypes) + + else: + self.__handle_buttons(buttons) + + self._matrix = self.__make_matrix() + + def __handle_input(self, elems, sects, buttons, func, use_root_buttons): + """.""" + if func not in ("testfunc", "vertical_disp", "twiss"): + raise ValueError("invalid arg: func") + + if use_root_buttons not in (True, False): + raise ValueError("invalid arg: use_root_Buttons") + + if use_root_buttons: + update_default_base() + + if any(f is not None for f in [elems, sects]) and buttons is not None: + raise ValueError("too much args: (buttons) and (elems or sects)") + + return func, use_root_buttons + + def __force_init(self, elems, sects, dtypes): + """.""" + # reading dtypes + self._dtypes = dtypes + if isinstance(dtypes, (list, tuple)) and all( + i in _STD_TYPES for i in self._dtypes + ): + self._dtypes = sorted( + list(set(self._dtypes)), key=lambda x: _STD_TYPES.index(x) + ) + elif self._dtypes in _STD_TYPES: + self._dtypes = [self._dtypes] + elif self._dtypes == "all": + self._dtypes = _STD_TYPES + else: + raise ValueError("invalid arg: dtypes") + + # reading sects + self._sects = sects + if ( + isinstance(self._sects, (int, _np.integer)) + and 0 < self._sects <= 20 + ): + self._sects = [self._sects] + elif ( + isinstance(self._sects, (list, tuple, _np.ndarray)) + and all(isinstance(i, (int, _np.integer)) for i in self._sects) + and all(0 < i <= 20 for i in self._sects) + ): + self._sects = sorted(list(set(self._sects))) + elif self._sects == "all": + self._sects = _STD_SECTS + else: + raise ValueError("invalid arg: sects") + + # reading elems + self._elems = elems + if isinstance(self._elems, (list, tuple)) and all( + i in _STD_ELEMS for i in self._elems + ): + self._elems = sorted( + list(set(self._elems)), key=lambda x: _STD_ELEMS.index(x) + ) + elif self._elems in _STD_ELEMS: + self._elems = [self._elems] + elif self._elems == "all": + self._elems = _STD_ELEMS + else: + raise ValueError("invalid arg: elems") + + # gen buttons + self._buttons = self.__generate_buttons() + + def __search_button_in_default_base(self, button): + for b in globals()["_DEFAULT_BUTTONS"]: + if ( + (button.dtype == b.dtype) + and (button.indices == b.indices) + and (button.fantasy_name == b.fantasy_name) + and (self._func == b.func) + ): + return True, _dpcopy(b.signature) + return False, None + + def __generate_buttons(self): + all_buttons = [] + for dtype, sect, elem in product( + self._dtypes, self._sects, self._elems + ): + if self._use_root and self._func in [ + "vertical_disp", + "twiss", + ]: + temp_button = buttons.Button( + elem=elem, dtype=dtype, sect=sect, func="testfunc" + ).flatten() + for tb in temp_button: + flag, sig = self.__search_button_in_default_base(tb) + if flag: + tb._func = self._func + tb._signature = sig + all_buttons += [tb] + else: + tb_new = buttons.Button( + indices=tb.indices, + dtype=tb.dtype, + func=self._func, + ) + all_buttons += [tb_new] + else: + temp_button = buttons.Button( + elem=elem, dtype=dtype, sect=sect, func=self._func + ).flatten() + all_buttons += temp_button + return all_buttons + + def __handle_buttons(self, buttons): + # reading buttons + self._buttons = buttons + + if isinstance(self._buttons, (list, tuple)) and all( + isinstance(i, buttons.Button) for i in self._buttons + ): + self._buttons = self._buttons + elif isinstance(self._buttons, buttons.Button): + self._buttons = [self._buttons] + else: + raise ValueError("invalid arg: buttons") + + self._sects = [] + self._elems = [] + self._dtypes = [] + for button in self._buttons: + self._sects.append(button.sect) + self._elems.append(button.elem) + self._dtypes.append(button.dtype) + self._sects = sorted(list(set(self._sects))) + self._elems = sorted( + list(set(self._elems)), key=lambda x: _STD_ELEMS.index(x) + ) + self._dtypes = sorted( + list(set(self._dtypes)), key=lambda x: _STD_TYPES.index(x) + ) + + def __make_matrix(self): + matrix = _np.array([b.signature for b in self._buttons]).T + return matrix + + @property + def buttons(self): + """Returns the Base buttons list.""" + return self._buttons + + @property + def resp_mat(self): + """Returns the Base response matrix.""" + return self._matrix + + @property + def sectors(self): + """Returns the sectors presents in the Base.""" + return self._sects + + @property + def magnets(self): + """Returns the magnets (elements) presents in the Base.""" + return self._elems + + @property + def dtypes(self): + """Returns the misalignment types presents in the Base.""" + return self._dtypes + + def __len__(self): + """Base length (number of Buttons).""" + return len(self._buttons) + + def __eq__(self, other) -> bool: + """Comparison method.""" + if isinstance(other, Base): + for b in other.buttons(): + if b not in self.buttons(): + return False + return True + return False + + +def save_default_base(base): + """Save the input Base and its Buttons. + + Args: + base (Base): Base object to be saved as Default Base. + + About: + The save process only saves the Buttons of the input Base. Only \ + new and unsaved Buttons will be saved. Only Buttons with vertical_disp\ + signatures will be saved. + """ + update_default_base() + if base._func in ["vertical_disp", "twiss"]: + c = 0 + for b in base.buttons: + if b not in globals()["_DEFAULT_BUTTONS"]: + c += 1 + globals()["_DEFAULT_BUTTONS"].append(b) + save_pickle( + globals()["_DEFAULT_BUTTONS"], _D_BUTTONS_FILE, overwrite=True + ) + update_default_base() + if c == 0: + print("No new Buttons.") + else: + print(f"New {c} Buttons added to Default Base.") + else: + print("Nothing saved.") + + +def delete_default_base(): + """Restore the Default Base to an empty list.""" + save_pickle([], _D_BUTTONS_FILE, overwrite=True) + update_default_base() + print("Default Base/Buttons deleted!") + + +__all__ = ( + "Base", + "save_default_base", + "delete_default_base", + "update_default_base", +) diff --git a/apsuite/optics_analysis/misalign_analysis/buttons.py b/apsuite/optics_analysis/misalign_analysis/buttons.py new file mode 100644 index 00000000..8dedafa0 --- /dev/null +++ b/apsuite/optics_analysis/misalign_analysis/buttons.py @@ -0,0 +1,437 @@ +"""Module 'buttons' for the class Object Button.""" + +from copy import deepcopy as _deepcopy +from importlib import reload + +import numpy as _np +import pyaccel.optics as _opt + +from apsuite.orbcorr import OrbitCorr as _OrbitCorr + +from . import functions, si_data + +_OC_MODEL = None +_OC = None +_INIT_KICKS = None +_JAC = None +_SI_SPOS = None +_fam = None +_sects_dict = None + +_anly_funcs = ["vertical_disp", "testfunc", "twiss"] +_DELTAS = si_data.std_misaligment_tolerance() +_STD_TYPES = si_data.std_misaligment_types() +_STD_ELEMS = si_data.si_elems() + + +def buttons_set_model(model=None): + """.""" + # print("entered buttons set model") + si_data.__set_model(model) + reload(functions) + # print("buttons -> updating model") + mod = si_data.get_model() + # print("buttons -> model is", type(mod)) + globals()["_OC_MODEL"] = mod + globals()["_OC"] = _OrbitCorr(globals()["_OC_MODEL"], "SI") + globals()["_OC"].params.maxnriters = 30 + globals()["_OC"].params.convergencetol = 1e-9 + globals()["_OC"].params.use6dorb = True + globals()["_JAC"] = globals()["_OC"].get_jacobian_matrix() + functions.rmk_orbit_corr( + globals()["_OC"], jacobian_matrix=globals()["_JAC"] + ) + globals()["_INIT_KICKS"] = globals()["_OC"].get_kicks() + globals()["_SI_SPOS"] = si_data.si_spos() + globals()["_fam"] = si_data.si_famdata() + globals()["_sects_dict"] = { + fam_name: [ + int(s[:2]) + for i, s in enumerate(globals()["_fam"][fam_name]["subsection"]) + ] + for fam_name in _STD_ELEMS + } + # print("sects_dict is", type(_sects_dict)) + # print("buttons -> model updated") + + +class Button: + """Button object for misalignment analysis. + + Args: + elem (str): Magnet's family name.\ + The valid options are the SIRIUS's dipoles, quadrupoles and \ + sextupoles family names.\ + The names can be followed by postfixes like: '_1' or '_2' to \ + specify the magnet. + + sect (int): The magnet sector location.\ + The valid options are the 1 to 20 sectors of SIRIUS storage ring. + + dtype (str): The misalignmente type.\ + Valid options are: 'dx'(horizontal), 'dy' (vertical), 'dr' \ + (rotation roll), 'drp' (rotation pitch) and 'dry' (totation yaw). + + func (str): The analysis function. Defaults to "vertical_disp".\ + The valid options are: 'vertical_disp' and 'testfunc'. + + indices ((int, list[int])): The index or indices of the magnet in\ + the SIRIUS pymodels's model. Its important to have the latest\ + model always up-to-date. + + """ + + def __init__( + self, + elem=None, + sect=None, + dtype=None, + func="vertical_disp", + indices=None, + ): + """.""" + if _OC_MODEL is None: + fstr = "Creating new model." + fstr += " To switch model, use function:" + fstr += " 'apsuite.optics_analysis.misalign_analysis.set_model()'" + print(fstr) + buttons_set_model() + + self._elem = elem + self._sect = sect + self._indices = indices + + if dtype in _STD_TYPES: + self._dtype = dtype + else: + raise ValueError("Invalid dtype") + + if func not in _anly_funcs: + raise ValueError("Invalid func") + self._func = func + + if indices is not None: + if any(f is not None for f in [elem, sect]): + raise ValueError("Too much args") + else: + self._indices = indices + elif indices is None and all(f is not None for f in [elem, sect]): + self._sect = sect + self._elem = elem + else: + raise ValueError("Missing input args") + + self.__force_init__() + + self._is_valid = True # if self.fantasy_name == [] else False + + self._signature = self.__calc_signature() + + def __str__(self) -> str: + """.""" + return f"({self._sect}, {self._dtype}, {self._fantasy_name})" + + def __repr__(self) -> str: + """.""" + return self.__str__() + + def __eq__(self, other) -> bool: + """.""" + try: + if ( + (self._dtype == other.dtype) + and (self._indices == other.indices) + and (self._fantasy_name == other.fantasy_name) + and (self._func == other.func) + ): + return True + return False + except Exception: + return False + + def __calc_test_signature(self): + """.""" + if isinstance(self._fantasy_name, list): + return [_np.zeros(160) for i in self._fantasy_name] + else: + return _np.zeros(160) + + def __calc_vdisp_signature(self): + """.""" + deltafunc = functions._SET_FUNCS[self._dtype] + disp = [] + delta = _DELTAS[self._dtype][self._elem[0]] + loop = [self.indices] + flag = 1 + if isinstance(self._fantasy_name, list): + loop = self.indices + flag = -1 + + for ind in loop: + disp_0 = functions.calc_vdisp(_OC_MODEL) + deltafunc(_OC_MODEL, indices=ind, values=delta) + functions.rmk_orbit_corr(_OC, _JAC) + disp_p = functions.calc_vdisp(_OC_MODEL) + disp.append(((disp_p - disp_0) / delta).ravel()) + deltafunc(_OC_MODEL, indices=ind, values=0.0) + _OC.set_kicks(_INIT_KICKS) + + if flag == 1: + return disp[0] + return disp + + def __calc_twiss_signature(self): + """.""" + deltafunc = functions._SET_FUNCS[self._dtype] + twiss = [] + delta = _DELTAS[self._dtype][self._elem[0]] + loop = [self.indices] + flag = 1 + if isinstance(self._fantasy_name, list): + loop = self.indices + flag = -1 + + for ind in loop: + twisspart = _opt.twiss.TwissArray(len(si_data.si_bpmidx())) + + deltafunc(_OC_MODEL, indices=ind, values=+delta) + functions.rmk_orbit_corr(_OC, _JAC) + twiss_p = _opt.twiss.calc_twiss( + _OC_MODEL, indices=si_data.si_bpmidx() + )[0] + + deltafunc(_OC_MODEL, indices=ind, values=-delta) + functions.rmk_orbit_corr(_OC, _JAC) + twiss_n = _opt.twiss.calc_twiss( + _OC_MODEL, indices=si_data.si_bpmidx() + )[0] + + for key in twisspart.dtype.fields.keys(): + twisspart[key] = (twiss_p[key] - twiss_n[key]) / (2 * delta) + + twiss.append(twisspart) + + deltafunc(_OC_MODEL, indices=ind, values=0.0) + _OC.set_kicks(_INIT_KICKS) + + if flag == 1: + return twiss[0] + return twiss + + def __calc_signature(self): + """.""" + if self._func == "testfunc": + return self.__calc_test_signature() + elif self._func == "vertical_disp": + return self.__calc_vdisp_signature() + else: + return self.__calc_twiss_signature() + + @property + def func(self): + """.""" + return self._func + + @property + def is_valid(self): + """.""" + return self._is_valid + + @property + def dtype(self): + """.""" + return self._dtype + + @property + def elem(self): + """.""" + return self._elem + + @property + def sect(self): + """.""" + return self._sect + + @property + def indices(self): + """.""" + return self._indices + + @property + def signature(self): + """.""" + return self._signature + + @property + def spos(self): + """.""" + return self._spos + + @property + def fantasy_name(self): + """.""" + return self._fantasy_name + + def flatten(self): + """.""" + if not isinstance(self, Button): + print("arg is not a Button object") + return + if isinstance(self.fantasy_name, list): + buttons = [] + for i in range(len(self.fantasy_name)): + b = _deepcopy(self) + b._signature = self.signature[i] + b._fantasy_name = self.fantasy_name[i] + b._indices = self.indices[i] + buttons.append(b) + return buttons + else: + return [self] + + def __force_init__(self): + """.""" + # Extract elements + elem, sect, indices = (self._elem, self._sect, self._indices) + + # Handle sector + sect = self._handle_sector(sect) + + # Handle elem + elem, fixpos = ( + self._handle_elem(elem, sect) if elem is not None else (None, -1) + ) + + # Handle indices + elem, sect, indices, split_flag = self._handle_indices( + elem, sect, indices + ) + + # Handle fantasy name + fantasy_name = self._handle_fantasy_name( + elem, sect, indices, split_flag + ) + + # Update attributes + self._update_attributes(elem, sect, fantasy_name, indices, fixpos) + + def _handle_elem(self, elem, sect): + """.""" + fixpos = -1 + elem = elem.split("_") + if len(elem) == 1: + elem = elem[0] + fixpos = -1 + else: + fixpos = int(elem[1]) + elem = elem[0] + if fixpos > _sects_dict[elem].count(sect) or fixpos <= 0: + raise ValueError("invalid postfix number") + + return elem, fixpos + + def _handle_sector(self, sect) -> int: + """.""" + if sect is not None: + if ( + not isinstance(sect, (_np.integer, int)) + or sect < 1 + or sect > 20 + ): + raise ValueError("problem with sect") + else: + return int(sect) + else: + return sect + + def _handle_indices(self, elem, sect, indices): + """.""" + split_flag = False + if indices is None: + indices = [ + _fam[elem]["index"][i] + for i, s in enumerate(_sects_dict[elem]) + if s == sect + ] + if len(indices) == 1: + if isinstance(indices[0], (list, tuple, _np.ndarray)): + indices = indices[0] + else: + split_flag = True + else: + if isinstance(indices, (int, _np.integer)): + indices = [indices] + elif isinstance(indices, (_np.ndarray, list, tuple)) and all( + isinstance(i, (_np.integer, int)) for i in indices + ): + pass + else: + raise ValueError("indices passed in wrong format") + + elem, sect, indices, split_flag = self._process_indices( + elem, sect, indices + ) + + return elem, sect, indices, split_flag + + def _process_indices(self, elem, sect, indices): + """.""" + found_elems = [ + fname + for fname in list( + set([_OC_MODEL[int(idx)].fam_name for idx in indices]) + ) + if fname in _STD_ELEMS + ] + if len(found_elems) != 1: + raise ValueError("invalid indices") + elem = found_elems.pop() + indices = [ + ind + for ind in _fam[elem]["index"] + if all(i in ind for i in indices) + ] + sect = [ + int(_fam[elem]["subsection"][i][:2]) + for i, f in enumerate(_fam[elem]["index"]) + if f in indices + ] + flag = True + if len(indices) == 1: + indices = indices[0] + sect = sect[0] + flag = False + + return elem, sect, indices, flag + + def _handle_fantasy_name(self, elem, sect, indices, split_flag): + """Handle fantasy name logic.""" + if split_flag is True: + fantasy_name = [ + elem + "_" + str(i + 1) for i in range(len(indices)) + ] + elif _sects_dict[elem].count(sect) > 1: + c = 0 + for ind, sec in zip(_fam[elem]["index"], _fam[elem]["subsection"]): + # print("ind:", ind, "\nsec:", sec, "\nindices:", indices) + if int(sec[:2]) == sect: + c += 1 + if ind == indices: + break + fantasy_name = elem + "_" + str(c) + else: + fantasy_name = elem + return fantasy_name + + def _update_attributes(self, elem, sect, fantasy_name, indices, fixpos): + """Update instance attributes.""" + self._elem = elem + self._sect = sect + self._fantasy_name = fantasy_name + self._indices = indices + self._spos = [_SI_SPOS[i] for i in indices] + + if fixpos != -1: + self._fantasy_name = self._fantasy_name[fixpos - 1] + self._indices = self._indices[fixpos - 1] + self._spos = self._spos[fixpos - 1] diff --git a/apsuite/optics_analysis/misalign_analysis/fitting.py b/apsuite/optics_analysis/misalign_analysis/fitting.py new file mode 100644 index 00000000..c69aec0f --- /dev/null +++ b/apsuite/optics_analysis/misalign_analysis/fitting.py @@ -0,0 +1,95 @@ +"""Fitting module to run dispersion fitting and analisys.""" + +import numpy as _np +from pymodels import si as _si + +from apsuite.orbcorr import OrbitCorr as _OrbitCorr + +from .functions import calc_disp as _disp, calc_pinv as _pinv, \ + rmk_orbit_corr as _correct_orbit, set_errors as _set_errors + + +def fit( + base, + dispy_meta, + nr_iters=5, + orbcorr_obj=None, + orbcorr_jac=None, + svals="auto", + svd_cut=1e-3, + model=None, +): + """.""" + imat, goal, oc, jac = _handle_input( + base, + dispy_meta, + nr_iters, + orbcorr_obj, + orbcorr_jac, + svals, + svd_cut, + model, + ) + + dispy, deltas = _fitting_loop(base, imat, goal, oc, jac, nr_iters) + + return dispy, deltas, oc.respm.model + + +def _handle_input( + base, dispy_meta, nr_iters, orbcorr_obj, orbcorr_jac, svals, cut, model +): + mat = base.resp_mat + imat = _pinv(mat, svals=svals, cut=cut) + + if ( + isinstance(dispy_meta, (list, tuple, _np.ndarray)) + and len(dispy_meta) == 160 + ): + dispy_meta = _np.array(dispy_meta) + else: + raise ValueError("Invalid Dispy") + + nr_iters = int(nr_iters) + + if all(i is not None for i in [model, orbcorr_obj]): + raise ValueError("too much args: model and orbcorr_obj") + + if model is not None: + mod = model + oc = _OrbitCorr(model, "SI") + elif model is None and orbcorr_obj is None: + mod = _si.create_accelerator() + oc = _OrbitCorr(mod, "SI") + else: + oc = orbcorr_obj + mod = oc.respm.model + + if orbcorr_jac is not None: + jac = orbcorr_jac + else: + jac = oc.get_jacobian_matrix() + + return imat, dispy_meta, oc, jac + + +def _fitting_loop(base, imat, dispy_meta, oc, jac, nr_iters): + count = 0 + mod = oc.respm.model + fulldeltas = _np.zeros(len(base)) + for _ in range(nr_iters): + disp = _disp(mod) + ddispy = dispy_meta - disp[160:] + deltas = _np.dot(imat, ddispy) + fulldeltas += deltas + _set_errors(mod, base, fulldeltas) + try: + _correct_orbit(oc, jac) + except Exception: + fulldeltas -= deltas + _set_errors(mod, base, fulldeltas) + _correct_orbit(oc, jac) + break + count += 1 + finaldispy = _disp(mod)[160:] + return finaldispy, fulldeltas, count diff --git a/apsuite/optics_analysis/misalign_analysis/functions.py b/apsuite/optics_analysis/misalign_analysis/functions.py new file mode 100644 index 00000000..478fbb4f --- /dev/null +++ b/apsuite/optics_analysis/misalign_analysis/functions.py @@ -0,0 +1,224 @@ +"""Miscellaneous functions to work with 'Button' and 'Base' objects.""" + +import numpy as _np +import pyaccel as _pyaccel + +_SET_FUNCS = { + "dx": _pyaccel.lattice.set_error_misalignment_x, + "dy": _pyaccel.lattice.set_error_misalignment_y, + "dr": _pyaccel.lattice.set_error_rotation_roll, + "drp": _pyaccel.lattice.set_error_rotation_pitch, + "dry": _pyaccel.lattice.set_error_rotation_yaw, +} + +_GET_FUNCS = { + "dx": _pyaccel.lattice.get_error_misalignment_x, + "dy": _pyaccel.lattice.get_error_misalignment_y, + "dr": _pyaccel.lattice.get_error_rotation_roll, + "drp": _pyaccel.lattice.get_error_rotation_pitch, + "dry": _pyaccel.lattice.get_error_rotation_yaw, +} + +_ADD_FUNCS = { + "dx": _pyaccel.lattice.add_error_misalignment_x, + "dy": _pyaccel.lattice.add_error_misalignment_y, + "dr": _pyaccel.lattice.add_error_rotation_roll, + "drp": _pyaccel.lattice.add_error_rotation_pitch, + "dry": _pyaccel.lattice.add_error_rotation_yaw, +} + + +def get_error(model, button): + """.""" + return _GET_FUNCS[button.dtype](model, indices=[button.indices[0]]) + + +def get_errors(model, base): + """.""" + errors = [] + for button in base.buttons: + errors.append(get_error(model, button)) + return _np.array(errors) + + +def set_error(model, button, error): + """.""" + if isinstance(error, (_np.int_, _np.float64, float, int)): + _SET_FUNCS[button.dtype](model, indices=button.indices, values=error) + elif len(error) == len(button.indices): + _SET_FUNCS[button.dtype](model, indices=button.indices, values=error) + else: + raise ValueError("problem with deltas") + + +def set_errors(model, base, errors): + """.""" + if len(errors) != len(base): + raise ValueError('"errors" size is incompatible with "base" size') + for i, button in enumerate(base.buttons): + set_error(model, button, errors[i]) + + +def add_delta_error(model, button, delta): + """.""" + if isinstance(delta, (_np.int_, _np.float64, float, int)): + _ADD_FUNCS[button.dtype](model, indices=button.indices, values=delta) + elif len(delta) == len(button.indices): + _ADD_FUNCS[button.dtype](model, indices=button.indices, values=delta) + else: + raise ValueError("problem with delta") + + +def add_delta_errors(model, base, deltas): + """.""" + if len(deltas) != len(base): + raise ValueError('"deltas" size is incompatible with "base" size') + for i, button in enumerate(base.buttons): + add_delta_error(model, button, deltas[i]) + + +def remove_delta_errors(model, base, deltas): + """.""" + for i, button in enumerate(base.buttons): + _ADD_FUNCS[button.dtype]( + model, indices=button.indices, values=-deltas[i] + ) + + +def add_error_ksl(lattice, indices, values): + """.""" + if isinstance(values, list): + pass + elif isinstance(values, (int, float)): + values = [values] + else: + raise ValueError("values in wrong format") + for i, ind in enumerate(indices): + lattice[ind].KsL += values[i] + + +def calc_rms(vec): + """.""" + return float((_np.mean(vec * vec)) ** 0.5) + + +def calc_vdisp(model, indices="bpm"): + """.""" + disp = calc_disp(model=model, indices=indices) + return disp[int(len(disp) / 2) :] + + +def calc_hdisp(model, indices="bpm"): + """.""" + disp = calc_disp(model=model, indices=indices) + return disp[: int(len(disp) / 2)] + + +def calc_disp(model, indices="bpm"): + """.""" + if indices not in ["bpm", "closed", "open"]: + raise ValueError( + 'Invalid indices parameter: \ + should be "bpm" or "open" or "closed"!' + ) + if indices == "bpm": + indices = _pyaccel.lattice.find_indices(model, "fam_name", "BPM") + orbp = _pyaccel.tracking.find_orbit4( + model, indices=indices, energy_offset=+1e-6 + ) + orbn = _pyaccel.tracking.find_orbit4( + model, indices=indices, energy_offset=-1e-6 + ) + return _np.hstack( + [ + (orbp[0, :] - orbn[0, :]) / (2e-6), + (orbp[2, :] - orbn[2, :]) / (2e-6), + ] + ) + + +def calc_pinv(matrix, **kwargs): + """.""" + svals = "auto" + cut = 5e-3 + return_svd = False + if "svals" in kwargs: + svals = kwargs["svals"] + if "cut" in kwargs: + cut = kwargs["cut"] + if "return_svd" in kwargs: + return_svd = kwargs["return_svd"] + u, smat, vh = _np.linalg.svd(matrix, full_matrices=False) + if isinstance(svals, (_np.integer, int)): + ismat = _np.zeros_like(smat) + ismat += 1 / smat + ismat[svals:] = 0 + elif isinstance(svals, str): + if svals == "all": + ismat = 1 / smat + if svals == "auto": + ismat = _np.array( + [1 / s if s >= cut * smat[0] else 0 for s in smat] + ) + else: + raise ValueError('"svals" should be int or string: "all" or "auto"') + imat = vh.T @ _np.diag(ismat) @ u.T + if return_svd: + return imat, u, smat, vh, len(_np.nonzero(ismat)[0]) + else: + return imat + + +def rmk_orbit_corr(orbcorr_obj, jacobian_matrix=None, goal_orbit=None): + """.""" + if goal_orbit is None: + nbpm = len(orbcorr_obj.respm.bpm_idx) + goal_orbit = _np.zeros(2 * nbpm, dtype=float) + + jmat = jacobian_matrix + if jmat is None: + jmat = orbcorr_obj.get_jacobian_matrix() + + ismat = orbcorr_obj.get_inverse_matrix(jmat) + + orb = orbcorr_obj.get_orbit() + dorb = orb - goal_orbit + bestfigm = orbcorr_obj.get_figm(dorb) + maxit = 0 + for _ in range(orbcorr_obj.params.maxnriters): + dkicks = -1 * _np.dot(ismat, dorb) + kicks, saturation_flag = orbcorr_obj._process_kicks(dkicks) + if saturation_flag: + return orbcorr_obj.CORR_STATUS.SaturationFail, maxit + orbcorr_obj.set_kicks(kicks) + maxit += 1 + orb = orbcorr_obj.get_orbit() + dorb = orb - goal_orbit + figm = orbcorr_obj.get_figm(dorb) + diff_figm = _np.abs(bestfigm - figm) + if figm < bestfigm: + bestfigm = figm + if diff_figm < orbcorr_obj.params.convergencetol: + if bestfigm <= orbcorr_obj.params.orbrmswarnthres: + return orbcorr_obj.CORR_STATUS.Sucess, maxit + else: + return orbcorr_obj.CORR_STATUS.OrbRMSWarning, maxit + if orbcorr_obj.params.updatejacobian: + jmat = orbcorr_obj.get_jacobian_matrix() + ismat = orbcorr_obj.get_inverse_matrix(jmat) + return orbcorr_obj.CORR_STATUS.ConvergenceFail, maxit + + +__all__ = ( + "calc_pinv", + "calc_rms", + "calc_disp", + "calc_vdisp", + "calc_hdisp", + "get_error", + "get_errors", + "set_error", + "set_errors", + "add_delta_error", + "add_delta_errors", +) diff --git a/apsuite/optics_analysis/misalign_analysis/si_data.py b/apsuite/optics_analysis/misalign_analysis/si_data.py new file mode 100644 index 00000000..82b05687 --- /dev/null +++ b/apsuite/optics_analysis/misalign_analysis/si_data.py @@ -0,0 +1,150 @@ +"""Standart "PyModels Sirius model" data collection.""" + +from copy import deepcopy as _deepcopy + +import pymodels as _pymodels +from pyaccel.lattice import find_indices as _find_indices, \ + find_spos as _find_spos + + +def std_misaligment_tolerance(): + """Default misalignment and rotation expected error.""" + return { + "dx": {"B": 40e-6, "Q": 40e-6, "S": 40e-6}, + "dy": {"B": 40e-6, "Q": 40e-6, "S": 40e-6}, + "dr": {"B": 0.3e-3, "Q": 0.3e-3, "S": 0.3e-3}, + "drp": {"B": 0.3e-3, "Q": 0.3e-3, "S": 0.3e-3}, + "dry": {"B": 0.3e-3, "Q": 0.3e-3, "S": 0.3e-3}, + } + + +def std_misaligment_types(): + """Default misalignment types. + + 'dr' : Rotation roll misalignment (theta) + 'dx' : Tranverse-horizontal misalignment (X) + 'dy' : Tranverse-vertical misalignment (Y) + 'drp' : Rotation pitch misalignment + 'dry' : Rotation yaw misalignment + """ + return ["dx", "dy", "dr", "drp", "dry"] + + +def si_sectors(): + """Default sectors of SIRIUS ring.""" + return list(range(1, 21)) + + +def si_dipoles(): + """Default SIRIUS dipoles families.""" + return ["B1", "B2", "BC"] + + +def si_quadrupoles(): + """Default SIRIUS quadrupoles families.""" + return [ + "QFA", + "QDA", + "QFB", + "QDB1", + "QDB2", + "QFP", + "QDP1", + "QDP2", + "Q1", + "Q2", + "Q3", + "Q4", + ] + + +def si_sextupoles(): + """Default SIRIUS sextupoles names.""" + return [ + "SDA0", + "SDA1", + "SDA2", + "SDA3", + "SFA0", + "SFA1", + "SFA2", + "SDB0", + "SDB1", + "SDB2", + "SDB3", + "SFB0", + "SFB1", + "SFB2", + "SDP0", + "SDP1", + "SDP2", + "SDP3", + "SFP0", + "SFP1", + "SFP2", + ] + + +def si_elems(): + """Default SIRIUS elements families.""" + return si_dipoles() + si_quadrupoles() + si_sextupoles() + + +__model = None +__bpmidx = None +__spos = None + + +def __update_model(): + globals()["__model"].radiation_on = True + globals()["__model"].vchamber_on = True + globals()["__model"].cavity_on = True + globals()["__spos"] = _find_spos(globals()["__model"], indices="closed") + globals()["__bpmidx"] = _find_indices( + globals()["__model"], "fam_name", "BPM" + ) + # print("si_data -> model is", type(globals()["__model"])) + + +def __set_model(model=None): + """.""" + # print("entered si_data set model") + if model is None: + # print("si_data model -> none") + globals()["__model"] = _pymodels.si.create_accelerator() + else: + # print("si_data model -> model") + globals()["__model"] = _deepcopy(model) + __update_model() + + +def get_model(): + """SIRIUS model. + + Returns: + pyaccel.accelerator: standart SIRIUS model + """ + if __model is None: + __set_model() + return __model + + +def si_spos(): + """Default SIRIUS longitudinal coordinates of the lattice elements.""" + if __model is None: + __set_model() + return _deepcopy(__spos) + + +def si_famdata(): + """Default SIRIUS families data.""" + if __model is None: + __set_model() + return _deepcopy(_pymodels.si.families.get_family_data(__model)) + + +def si_bpmidx(): + """Default SIRIUS BPM's indices.""" + if __model is None: + __set_model() + return _deepcopy(__bpmidx)