diff --git a/bec_lib/bec_lib/lmfit_serializer.py b/bec_lib/bec_lib/lmfit_serializer.py index dee651cf2..88d88210d 100644 --- a/bec_lib/bec_lib/lmfit_serializer.py +++ b/bec_lib/bec_lib/lmfit_serializer.py @@ -52,7 +52,7 @@ def serialize_lmfit_params(params: Parameters) -> dict: return {v.name: serialize_param_object(v) for v in params} -def deserialize_param_object(obj: dict) -> Parameter: +def deserialize_param_object(obj: dict) -> Parameters: """ Deserialize dictionary representation of lmfit.Parameter object. @@ -60,10 +60,24 @@ def deserialize_param_object(obj: dict) -> Parameter: obj (dict): Dictionary representation of the parameters Returns: - Parameter: Parameter object + Parameters: Parameters object """ param = Parameters() for k, v in obj.items(): - v.pop("name") - param.add(k, **v) + if isinstance(v, Parameter): + param.add( + k, + value=v.value, + vary=v.vary, + min=v.min, + max=v.max, + expr=v.expr, + brute_step=v.brute_step, + ) + continue + if isinstance(v, dict): + v.pop("name", None) + v_copy = v.copy() + v_copy.pop("name", None) + param.add(k, **v_copy) return param diff --git a/bec_lib/tests/test_lmfit_serializer.py b/bec_lib/tests/test_lmfit_serializer.py index e73613d48..c5c487c6a 100644 --- a/bec_lib/tests/test_lmfit_serializer.py +++ b/bec_lib/tests/test_lmfit_serializer.py @@ -33,3 +33,19 @@ def test_serialize_lmfit_params(): obj = deserialize_param_object(result) assert obj == params + + # `name` is optional for deserialization (key is the param name) + result_without_names = { + k: {kk: vv for kk, vv in v.items() if kk != "name"} for k, v in result.items() + } + obj = deserialize_param_object(result_without_names) + assert obj == params + + +def test_deserialize_param_object_accepts_parameter_objects(): + params = lmfit.Parameters() + params.add("a", value=1.0, vary=True, min=-2.0, max=3.0) + params.add("b", value=2.0, vary=False) + + obj = deserialize_param_object({"a": params["a"], "b": params["b"]}) + assert obj == params diff --git a/bec_server/bec_server/data_processing/lmfit1d_service.py b/bec_server/bec_server/data_processing/lmfit1d_service.py index ab16dbfc8..260f5b6a4 100644 --- a/bec_server/bec_server/data_processing/lmfit1d_service.py +++ b/bec_server/bec_server/data_processing/lmfit1d_service.py @@ -41,6 +41,7 @@ def __init__(self, model: str, *args, continuous: bool = False, **kwargs): self.device_y = None self.signal_y = None self.parameters = None + self._parameter_override_names = [] self.current_scan_item = None self.finished_id = None self.model = getattr(lmfit.models, model)() @@ -169,6 +170,7 @@ def configure( data_y: np.ndarray = None, x_min: float = None, x_max: float = None, + parameters: dict | None = None, amplitude: lmfit.Parameter = None, center: lmfit.Parameter = None, sigma: lmfit.Parameter = None, @@ -195,15 +197,59 @@ def configure( self.oversample = oversample - self.parameters = {} + raw_parameters: dict = {} + if parameters: + if isinstance(parameters, lmfit.Parameters): + raw_parameters.update({name: param for name, param in parameters.items()}) + elif isinstance(parameters, dict): + raw_parameters.update(parameters) + else: + raise DAPError( + f"Invalid parameters type {type(parameters)}. Expected dict or lmfit.Parameters." + ) if amplitude: - self.parameters["amplitude"] = amplitude + raw_parameters["amplitude"] = amplitude if center: - self.parameters["center"] = center + raw_parameters["center"] = center if sigma: - self.parameters["sigma"] = sigma - - self.parameters = deserialize_param_object(self.parameters) + raw_parameters["sigma"] = sigma + + override_params = deserialize_param_object(raw_parameters) + if len(override_params) > 0: + valid_names = set(getattr(self.model, "param_names", [])) + if valid_names: + invalid_names = set(override_params.keys()) - valid_names + for name in invalid_names: + logger.warning( + f"Ignoring unknown lmfit parameter '{name}' for model '{self.model.__class__.__name__}'." + ) + override_params.pop(name, None) + + self._parameter_override_names = list(override_params.keys()) + if len(override_params) > 0: + # If `params=` is provided to lmfit, it must contain ALL parameters. + # Start from model defaults and apply overrides on top. + full_params = self.model.make_params() + for name, override in override_params.items(): + full_params[name].set( + value=override.value, + vary=override.vary, + min=override.min, + max=override.max, + expr=override.expr, + brute_step=getattr(override, "brute_step", None), + ) + self.parameters = full_params + logger.info( + f"Configured lmfit model={self.model.__class__.__name__} with override_params={serialize_lmfit_params(override_params)}" + ) + else: + self.parameters = None + if parameters or amplitude or center or sigma: + logger.info( + f"No usable lmfit parameter overrides after validation for model={self.model.__class__.__name__} " + f"(input_keys={list(raw_parameters.keys())})" + ) if data_x is not None and data_y is not None: self.data = { @@ -342,9 +388,12 @@ def get_data_from_current_scan( "scan_data": True, } - def process(self) -> tuple[dict, dict]: + def process(self) -> tuple[dict, dict] | None: """ Process data and return the result. + + Returns: + tuple[dict, dict]: Processed data and metadata if successful, None otherwise. """ # get the data if not self.data: @@ -354,10 +403,31 @@ def process(self) -> tuple[dict, dict]: y = self.data["y"] # fit the data + model_name = self.model.__class__.__name__ if self.parameters: - result = self.model.fit(y, x=x, params=self.parameters) + logger.info( + f"Running lmfit fit: model={model_name} points={len(x)} fixed/override_params={self._parameter_override_names}" + ) else: - result = self.model.fit(y, x=x) + logger.info(f"Running lmfit fit: model={model_name} points={len(x)} params=") + + try: + if self.parameters: + result = self.model.fit(y, x=x, params=self.parameters) + else: + result = self.model.fit(y, x=x) + except Exception as exc: # pylint: disable=broad-except + if self.parameters is not None: + try: + params_str = serialize_lmfit_params(self.parameters) + except Exception as ser_exc: + params_str = f"" + else: + params_str = "" + logger.warning( + f"lmfit fit failed: model={model_name} points={len(x)} parameters={params_str} error={exc}" + ) + return # if the fit was only on a subset of the data, add the original x values to the output if self.data["x_lim"] or self.oversample != 1: diff --git a/bec_server/tests/tests_data_processing/test_lmfit1d_service.py b/bec_server/tests/tests_data_processing/test_lmfit1d_service.py index 9e16ba0cb..ff40bc319 100644 --- a/bec_server/tests/tests_data_processing/test_lmfit1d_service.py +++ b/bec_server/tests/tests_data_processing/test_lmfit1d_service.py @@ -161,6 +161,54 @@ def test_LmfitService1D_configure_selected_devices(lmfit_service): get_data.assert_called_once() +def test_LmfitService1D_configure_accepts_generic_parameters_and_filters_invalid(lmfit_service): + x = np.linspace(-1.0, 1.0, 15) + y = np.exp(-(x**2)) + lmfit_service.configure( + data_x=x, + data_y=y, + parameters={"amplitude": {"value": 1.0, "vary": False}, "frequency": {"value": 2.0}}, + ) + assert lmfit_service.parameters["amplitude"].value == 1.0 + assert lmfit_service.parameters["amplitude"].vary is False + assert "frequency" not in lmfit_service.parameters + + +def test_LmfitService1D_configure_accepts_lmfit_parameters_object(lmfit_service): + x = np.linspace(-1.0, 1.0, 15) + y = np.exp(-(x**2)) + params = lmfit.models.GaussianModel().make_params() + params["amplitude"].set(value=1.0, vary=False) + lmfit_service.configure(data_x=x, data_y=y, parameters=params) + assert lmfit_service.parameters["amplitude"].value == 1.0 + assert lmfit_service.parameters["amplitude"].vary is False + + +def test_LmfitService1D_configure_invalid_parameters_type_raises(lmfit_service): + x = np.linspace(-1.0, 1.0, 15) + y = np.exp(-(x**2)) + with pytest.raises(DAPError): + lmfit_service.configure(data_x=x, data_y=y, parameters=["amplitude", 1.0]) # type: ignore[arg-type] + + +def test_LmfitService1D_configure_parameters_work_for_sine_model(): + if not hasattr(lmfit.models, "SineModel"): + pytest.skip("lmfit.models.SineModel not available in this environment") + service = LmfitService1D(model="SineModel", continuous=False, client=mock.MagicMock()) + x = np.linspace(0.0, 2.0 * np.pi, 25) + y = np.sin(x) + service.configure( + data_x=x, + data_y=y, + parameters={"frequency": {"value": 1.0, "vary": False}, "center": {"value": 0.0}}, + ) + assert service.parameters["frequency"].value == 1.0 + assert service.parameters["frequency"].vary is False + assert "center" not in service.parameters + assert "amplitude" in service.parameters + assert "shift" in service.parameters + + def test_LmfitService1D_get_model(lmfit_service): model = lmfit_service.get_model("GaussianModel") assert model.__name__ == "GaussianModel"