From f4241c5fdbbe5ec43b45dd7fb51eb132fc0ec32f Mon Sep 17 00:00:00 2001 From: wyzula-jan Date: Mon, 15 Dec 2025 22:43:58 +0100 Subject: [PATCH 1/2] fix(dap): lmfit1d can pass any parameters with kwargs --- bec_lib/bec_lib/lmfit_serializer.py | 18 ++++- bec_lib/tests/test_lmfit_serializer.py | 7 ++ .../data_processing/lmfit1d_service.py | 78 +++++++++++++++++-- .../test_lmfit1d_service.py | 31 ++++++++ 4 files changed, 123 insertions(+), 11 deletions(-) diff --git a/bec_lib/bec_lib/lmfit_serializer.py b/bec_lib/bec_lib/lmfit_serializer.py index dee651cf2..c2e167f51 100644 --- a/bec_lib/bec_lib/lmfit_serializer.py +++ b/bec_lib/bec_lib/lmfit_serializer.py @@ -52,7 +52,7 @@ def serialize_lmfit_params(params: Parameters) -> dict: return {v.name: serialize_param_object(v) for v in params} -def deserialize_param_object(obj: dict) -> Parameter: +def deserialize_param_object(obj: dict) -> Parameters: """ Deserialize dictionary representation of lmfit.Parameter object. @@ -64,6 +64,18 @@ def deserialize_param_object(obj: dict) -> Parameter: """ param = Parameters() for k, v in obj.items(): - v.pop("name") - param.add(k, **v) + if isinstance(v, Parameter): + param.add( + k, + value=v.value, + vary=v.vary, + min=v.min, + max=v.max, + expr=v.expr, + brute_step=v.brute_step, + ) + continue + if isinstance(v, dict): + v.pop("name", None) + param.add(k, **v) return param diff --git a/bec_lib/tests/test_lmfit_serializer.py b/bec_lib/tests/test_lmfit_serializer.py index e73613d48..2b04e9d8b 100644 --- a/bec_lib/tests/test_lmfit_serializer.py +++ b/bec_lib/tests/test_lmfit_serializer.py @@ -33,3 +33,10 @@ def test_serialize_lmfit_params(): obj = deserialize_param_object(result) assert obj == params + + # `name` is optional for deserialization (key is the param name) + result_without_names = { + k: {kk: vv for kk, vv in v.items() if kk != "name"} for k, v in result.items() + } + obj = deserialize_param_object(result_without_names) + assert obj == params diff --git a/bec_server/bec_server/data_processing/lmfit1d_service.py b/bec_server/bec_server/data_processing/lmfit1d_service.py index ab16dbfc8..1bfb1cdd2 100644 --- a/bec_server/bec_server/data_processing/lmfit1d_service.py +++ b/bec_server/bec_server/data_processing/lmfit1d_service.py @@ -41,6 +41,7 @@ def __init__(self, model: str, *args, continuous: bool = False, **kwargs): self.device_y = None self.signal_y = None self.parameters = None + self._parameter_override_names = [] self.current_scan_item = None self.finished_id = None self.model = getattr(lmfit.models, model)() @@ -169,6 +170,7 @@ def configure( data_y: np.ndarray = None, x_min: float = None, x_max: float = None, + parameters: dict | None = None, amplitude: lmfit.Parameter = None, center: lmfit.Parameter = None, sigma: lmfit.Parameter = None, @@ -195,15 +197,61 @@ def configure( self.oversample = oversample - self.parameters = {} + raw_parameters: dict = {} + if parameters: + if isinstance(parameters, lmfit.Parameters): + raw_parameters.update(parameters) + elif isinstance(parameters, dict): + raw_parameters.update(parameters) + else: + raise DAPError( + f"Invalid parameters type {type(parameters)}. Expected dict or lmfit.Parameters." + ) if amplitude: - self.parameters["amplitude"] = amplitude + raw_parameters["amplitude"] = amplitude if center: - self.parameters["center"] = center + raw_parameters["center"] = center if sigma: - self.parameters["sigma"] = sigma - - self.parameters = deserialize_param_object(self.parameters) + raw_parameters["sigma"] = sigma + + override_params = deserialize_param_object(raw_parameters) + if override_params: + valid_names = set(getattr(self.model, "param_names", [])) + if valid_names: + invalid_names = set(override_params.keys()) - valid_names + for name in invalid_names: + logger.warning( + f"Ignoring unknown lmfit parameter '{name}' for model '{self.model.__class__.__name__}'." + ) + override_params.pop(name, None) + + self._parameter_override_names = list(getattr(override_params, "keys", lambda: [])()) + if override_params: + # If `params=` is provided to lmfit, it must contain ALL parameters. + # Start from model defaults and apply overrides on top. + full_params = self.model.make_params() + for name, override in override_params.items(): + if name not in full_params: + continue + full_params[name].set( + value=override.value, + vary=override.vary, + min=override.min, + max=override.max, + expr=override.expr, + brute_step=getattr(override, "brute_step", None), + ) + self.parameters = full_params + logger.info( + f"Configured lmfit model={self.model.__class__.__name__} with override_params={serialize_lmfit_params(override_params)}" + ) + else: + self.parameters = None + if parameters or amplitude or center or sigma: + logger.info( + f"No usable lmfit parameter overrides after validation for model={self.model.__class__.__name__} " + f"(input_keys={list(raw_parameters.keys())})" + ) if data_x is not None and data_y is not None: self.data = { @@ -354,10 +402,24 @@ def process(self) -> tuple[dict, dict]: y = self.data["y"] # fit the data + model_name = self.model.__class__.__name__ if self.parameters: - result = self.model.fit(y, x=x, params=self.parameters) + logger.info( + f"Running lmfit fit: model={model_name} points={len(x)} fixed/override_params={self._parameter_override_names}" + ) else: - result = self.model.fit(y, x=x) + logger.info(f"Running lmfit fit: model={model_name} points={len(x)} params=") + + try: + if self.parameters: + result = self.model.fit(y, x=x, params=self.parameters) + else: + result = self.model.fit(y, x=x) + except Exception as exc: # pylint: disable=broad-except + logger.warning( + f"lmfit fit failed: model={model_name} points={len(x)} parameters={serialize_lmfit_params(self.parameters)} error={exc}" + ) + raise # if the fit was only on a subset of the data, add the original x values to the output if self.data["x_lim"] or self.oversample != 1: diff --git a/bec_server/tests/tests_data_processing/test_lmfit1d_service.py b/bec_server/tests/tests_data_processing/test_lmfit1d_service.py index 9e16ba0cb..9b56fc06f 100644 --- a/bec_server/tests/tests_data_processing/test_lmfit1d_service.py +++ b/bec_server/tests/tests_data_processing/test_lmfit1d_service.py @@ -161,6 +161,37 @@ def test_LmfitService1D_configure_selected_devices(lmfit_service): get_data.assert_called_once() +def test_LmfitService1D_configure_accepts_generic_parameters_and_filters_invalid(lmfit_service): + x = np.linspace(-1.0, 1.0, 15) + y = np.exp(-(x**2)) + lmfit_service.configure( + data_x=x, + data_y=y, + parameters={"amplitude": {"value": 1.0, "vary": False}, "frequency": {"value": 2.0}}, + ) + assert lmfit_service.parameters["amplitude"].value == 1.0 + assert lmfit_service.parameters["amplitude"].vary is False + assert "frequency" not in lmfit_service.parameters + + +def test_LmfitService1D_configure_parameters_work_for_sine_model(): + if not hasattr(lmfit.models, "SineModel"): + pytest.skip("lmfit.models.SineModel not available in this environment") + service = LmfitService1D(model="SineModel", continuous=False, client=mock.MagicMock()) + x = np.linspace(0.0, 2.0 * np.pi, 25) + y = np.sin(x) + service.configure( + data_x=x, + data_y=y, + parameters={"frequency": {"value": 1.0, "vary": False}, "center": {"value": 0.0}}, + ) + assert service.parameters["frequency"].value == 1.0 + assert service.parameters["frequency"].vary is False + assert "center" not in service.parameters + assert "amplitude" in service.parameters + assert "shift" in service.parameters + + def test_LmfitService1D_get_model(lmfit_service): model = lmfit_service.get_model("GaussianModel") assert model.__name__ == "GaussianModel" From 3398749e9a2eb315afba12b25bf482bf143b3949 Mon Sep 17 00:00:00 2001 From: wyzula-jan Date: Tue, 16 Dec 2025 11:21:45 +0100 Subject: [PATCH 2/2] fix(dap): lmfit1d can pass any parameters with kwargs; cleanup --- bec_lib/bec_lib/lmfit_serializer.py | 6 +++-- bec_lib/tests/test_lmfit_serializer.py | 9 +++++++ .../data_processing/lmfit1d_service.py | 26 ++++++++++++------- .../test_lmfit1d_service.py | 17 ++++++++++++ 4 files changed, 47 insertions(+), 11 deletions(-) diff --git a/bec_lib/bec_lib/lmfit_serializer.py b/bec_lib/bec_lib/lmfit_serializer.py index c2e167f51..88d88210d 100644 --- a/bec_lib/bec_lib/lmfit_serializer.py +++ b/bec_lib/bec_lib/lmfit_serializer.py @@ -60,7 +60,7 @@ def deserialize_param_object(obj: dict) -> Parameters: obj (dict): Dictionary representation of the parameters Returns: - Parameter: Parameter object + Parameters: Parameters object """ param = Parameters() for k, v in obj.items(): @@ -77,5 +77,7 @@ def deserialize_param_object(obj: dict) -> Parameters: continue if isinstance(v, dict): v.pop("name", None) - param.add(k, **v) + v_copy = v.copy() + v_copy.pop("name", None) + param.add(k, **v_copy) return param diff --git a/bec_lib/tests/test_lmfit_serializer.py b/bec_lib/tests/test_lmfit_serializer.py index 2b04e9d8b..c5c487c6a 100644 --- a/bec_lib/tests/test_lmfit_serializer.py +++ b/bec_lib/tests/test_lmfit_serializer.py @@ -40,3 +40,12 @@ def test_serialize_lmfit_params(): } obj = deserialize_param_object(result_without_names) assert obj == params + + +def test_deserialize_param_object_accepts_parameter_objects(): + params = lmfit.Parameters() + params.add("a", value=1.0, vary=True, min=-2.0, max=3.0) + params.add("b", value=2.0, vary=False) + + obj = deserialize_param_object({"a": params["a"], "b": params["b"]}) + assert obj == params diff --git a/bec_server/bec_server/data_processing/lmfit1d_service.py b/bec_server/bec_server/data_processing/lmfit1d_service.py index 1bfb1cdd2..260f5b6a4 100644 --- a/bec_server/bec_server/data_processing/lmfit1d_service.py +++ b/bec_server/bec_server/data_processing/lmfit1d_service.py @@ -200,7 +200,7 @@ def configure( raw_parameters: dict = {} if parameters: if isinstance(parameters, lmfit.Parameters): - raw_parameters.update(parameters) + raw_parameters.update({name: param for name, param in parameters.items()}) elif isinstance(parameters, dict): raw_parameters.update(parameters) else: @@ -215,7 +215,7 @@ def configure( raw_parameters["sigma"] = sigma override_params = deserialize_param_object(raw_parameters) - if override_params: + if len(override_params) > 0: valid_names = set(getattr(self.model, "param_names", [])) if valid_names: invalid_names = set(override_params.keys()) - valid_names @@ -225,14 +225,12 @@ def configure( ) override_params.pop(name, None) - self._parameter_override_names = list(getattr(override_params, "keys", lambda: [])()) - if override_params: + self._parameter_override_names = list(override_params.keys()) + if len(override_params) > 0: # If `params=` is provided to lmfit, it must contain ALL parameters. # Start from model defaults and apply overrides on top. full_params = self.model.make_params() for name, override in override_params.items(): - if name not in full_params: - continue full_params[name].set( value=override.value, vary=override.vary, @@ -390,9 +388,12 @@ def get_data_from_current_scan( "scan_data": True, } - def process(self) -> tuple[dict, dict]: + def process(self) -> tuple[dict, dict] | None: """ Process data and return the result. + + Returns: + tuple[dict, dict]: Processed data and metadata if successful, None otherwise. """ # get the data if not self.data: @@ -416,10 +417,17 @@ def process(self) -> tuple[dict, dict]: else: result = self.model.fit(y, x=x) except Exception as exc: # pylint: disable=broad-except + if self.parameters is not None: + try: + params_str = serialize_lmfit_params(self.parameters) + except Exception as ser_exc: + params_str = f"" + else: + params_str = "" logger.warning( - f"lmfit fit failed: model={model_name} points={len(x)} parameters={serialize_lmfit_params(self.parameters)} error={exc}" + f"lmfit fit failed: model={model_name} points={len(x)} parameters={params_str} error={exc}" ) - raise + return # if the fit was only on a subset of the data, add the original x values to the output if self.data["x_lim"] or self.oversample != 1: diff --git a/bec_server/tests/tests_data_processing/test_lmfit1d_service.py b/bec_server/tests/tests_data_processing/test_lmfit1d_service.py index 9b56fc06f..ff40bc319 100644 --- a/bec_server/tests/tests_data_processing/test_lmfit1d_service.py +++ b/bec_server/tests/tests_data_processing/test_lmfit1d_service.py @@ -174,6 +174,23 @@ def test_LmfitService1D_configure_accepts_generic_parameters_and_filters_invalid assert "frequency" not in lmfit_service.parameters +def test_LmfitService1D_configure_accepts_lmfit_parameters_object(lmfit_service): + x = np.linspace(-1.0, 1.0, 15) + y = np.exp(-(x**2)) + params = lmfit.models.GaussianModel().make_params() + params["amplitude"].set(value=1.0, vary=False) + lmfit_service.configure(data_x=x, data_y=y, parameters=params) + assert lmfit_service.parameters["amplitude"].value == 1.0 + assert lmfit_service.parameters["amplitude"].vary is False + + +def test_LmfitService1D_configure_invalid_parameters_type_raises(lmfit_service): + x = np.linspace(-1.0, 1.0, 15) + y = np.exp(-(x**2)) + with pytest.raises(DAPError): + lmfit_service.configure(data_x=x, data_y=y, parameters=["amplitude", 1.0]) # type: ignore[arg-type] + + def test_LmfitService1D_configure_parameters_work_for_sine_model(): if not hasattr(lmfit.models, "SineModel"): pytest.skip("lmfit.models.SineModel not available in this environment")