Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions bec_lib/bec_lib/lmfit_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,32 @@ def serialize_lmfit_params(params: Parameters) -> dict:
return {v.name: serialize_param_object(v) for v in params}


def deserialize_param_object(obj: dict) -> Parameter:
def deserialize_param_object(obj: dict) -> Parameters:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess the type hint could benefit from the additional dict description

obj: dict[str, dict | Parameter] 

same in the doc string

"""
Deserialize dictionary representation of lmfit.Parameter object.

Args:
obj (dict): Dictionary representation of the parameters

Returns:
Parameter: Parameter object
Parameters: Parameters object
"""
param = Parameters()
for k, v in obj.items():
v.pop("name")
param.add(k, **v)
if isinstance(v, Parameter):
param.add(
k,
value=v.value,
vary=v.vary,
min=v.min,
max=v.max,
expr=v.expr,
brute_step=v.brute_step,
)
continue
if isinstance(v, dict):
v.pop("name", None)
v_copy = v.copy()
v_copy.pop("name", None)
param.add(k, **v_copy)
return param
16 changes: 16 additions & 0 deletions bec_lib/tests/test_lmfit_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,19 @@ def test_serialize_lmfit_params():

obj = deserialize_param_object(result)
assert obj == params

# `name` is optional for deserialization (key is the param name)
result_without_names = {
k: {kk: vv for kk, vv in v.items() if kk != "name"} for k, v in result.items()
}
obj = deserialize_param_object(result_without_names)
assert obj == params


def test_deserialize_param_object_accepts_parameter_objects():
params = lmfit.Parameters()
params.add("a", value=1.0, vary=True, min=-2.0, max=3.0)
params.add("b", value=2.0, vary=False)

obj = deserialize_param_object({"a": params["a"], "b": params["b"]})
assert obj == params
88 changes: 79 additions & 9 deletions bec_server/bec_server/data_processing/lmfit1d_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(self, model: str, *args, continuous: bool = False, **kwargs):
self.device_y = None
self.signal_y = None
self.parameters = None
self._parameter_override_names = []
self.current_scan_item = None
self.finished_id = None
self.model = getattr(lmfit.models, model)()
Expand Down Expand Up @@ -169,6 +170,7 @@ def configure(
data_y: np.ndarray = None,
x_min: float = None,
x_max: float = None,
parameters: dict | None = None,
amplitude: lmfit.Parameter = None,
center: lmfit.Parameter = None,
sigma: lmfit.Parameter = None,
Expand All @@ -195,15 +197,59 @@ def configure(

self.oversample = oversample

self.parameters = {}
raw_parameters: dict = {}
if parameters:
if isinstance(parameters, lmfit.Parameters):
raw_parameters.update({name: param for name, param in parameters.items()})
elif isinstance(parameters, dict):
raw_parameters.update(parameters)
else:
raise DAPError(
f"Invalid parameters type {type(parameters)}. Expected dict or lmfit.Parameters."
)
if amplitude:
self.parameters["amplitude"] = amplitude
raw_parameters["amplitude"] = amplitude
if center:
self.parameters["center"] = center
raw_parameters["center"] = center
if sigma:
self.parameters["sigma"] = sigma

self.parameters = deserialize_param_object(self.parameters)
raw_parameters["sigma"] = sigma

override_params = deserialize_param_object(raw_parameters)
if len(override_params) > 0:
valid_names = set(getattr(self.model, "param_names", []))
if valid_names:
invalid_names = set(override_params.keys()) - valid_names
for name in invalid_names:
logger.warning(
f"Ignoring unknown lmfit parameter '{name}' for model '{self.model.__class__.__name__}'."
)
override_params.pop(name, None)

self._parameter_override_names = list(override_params.keys())
if len(override_params) > 0:
# If `params=` is provided to lmfit, it must contain ALL parameters.
# Start from model defaults and apply overrides on top.
full_params = self.model.make_params()
for name, override in override_params.items():
full_params[name].set(
value=override.value,
vary=override.vary,
min=override.min,
max=override.max,
expr=override.expr,
brute_step=getattr(override, "brute_step", None),
)
self.parameters = full_params
logger.info(
f"Configured lmfit model={self.model.__class__.__name__} with override_params={serialize_lmfit_params(override_params)}"
)
else:
self.parameters = None
if parameters or amplitude or center or sigma:
logger.info(
f"No usable lmfit parameter overrides after validation for model={self.model.__class__.__name__} "
f"(input_keys={list(raw_parameters.keys())})"
)

if data_x is not None and data_y is not None:
self.data = {
Expand Down Expand Up @@ -342,9 +388,12 @@ def get_data_from_current_scan(
"scan_data": True,
}

def process(self) -> tuple[dict, dict]:
def process(self) -> tuple[dict, dict] | None:
"""
Process data and return the result.

Returns:
tuple[dict, dict]: Processed data and metadata if successful, None otherwise.
"""
# get the data
if not self.data:
Expand All @@ -354,10 +403,31 @@ def process(self) -> tuple[dict, dict]:
y = self.data["y"]

# fit the data
model_name = self.model.__class__.__name__
if self.parameters:
result = self.model.fit(y, x=x, params=self.parameters)
logger.info(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these logs could be potentially very verbose. Maybe put them on debug?

f"Running lmfit fit: model={model_name} points={len(x)} fixed/override_params={self._parameter_override_names}"
)
else:
result = self.model.fit(y, x=x)
logger.info(f"Running lmfit fit: model={model_name} points={len(x)} params=<default>")

try:
if self.parameters:
result = self.model.fit(y, x=x, params=self.parameters)
else:
result = self.model.fit(y, x=x)
except Exception as exc: # pylint: disable=broad-except
if self.parameters is not None:
try:
params_str = serialize_lmfit_params(self.parameters)
except Exception as ser_exc:
params_str = f"<serialization failed: {ser_exc}>"
else:
params_str = "<None>"
logger.warning(
f"lmfit fit failed: model={model_name} points={len(x)} parameters={params_str} error={exc}"
)
return

# if the fit was only on a subset of the data, add the original x values to the output
if self.data["x_lim"] or self.oversample != 1:
Expand Down
48 changes: 48 additions & 0 deletions bec_server/tests/tests_data_processing/test_lmfit1d_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,54 @@ def test_LmfitService1D_configure_selected_devices(lmfit_service):
get_data.assert_called_once()


def test_LmfitService1D_configure_accepts_generic_parameters_and_filters_invalid(lmfit_service):
x = np.linspace(-1.0, 1.0, 15)
y = np.exp(-(x**2))
lmfit_service.configure(
data_x=x,
data_y=y,
parameters={"amplitude": {"value": 1.0, "vary": False}, "frequency": {"value": 2.0}},
)
assert lmfit_service.parameters["amplitude"].value == 1.0
assert lmfit_service.parameters["amplitude"].vary is False
assert "frequency" not in lmfit_service.parameters


def test_LmfitService1D_configure_accepts_lmfit_parameters_object(lmfit_service):
x = np.linspace(-1.0, 1.0, 15)
y = np.exp(-(x**2))
params = lmfit.models.GaussianModel().make_params()
params["amplitude"].set(value=1.0, vary=False)
lmfit_service.configure(data_x=x, data_y=y, parameters=params)
assert lmfit_service.parameters["amplitude"].value == 1.0
assert lmfit_service.parameters["amplitude"].vary is False


def test_LmfitService1D_configure_invalid_parameters_type_raises(lmfit_service):
x = np.linspace(-1.0, 1.0, 15)
y = np.exp(-(x**2))
with pytest.raises(DAPError):
lmfit_service.configure(data_x=x, data_y=y, parameters=["amplitude", 1.0]) # type: ignore[arg-type]


def test_LmfitService1D_configure_parameters_work_for_sine_model():
if not hasattr(lmfit.models, "SineModel"):
pytest.skip("lmfit.models.SineModel not available in this environment")
service = LmfitService1D(model="SineModel", continuous=False, client=mock.MagicMock())
x = np.linspace(0.0, 2.0 * np.pi, 25)
y = np.sin(x)
service.configure(
data_x=x,
data_y=y,
parameters={"frequency": {"value": 1.0, "vary": False}, "center": {"value": 0.0}},
)
assert service.parameters["frequency"].value == 1.0
assert service.parameters["frequency"].vary is False
assert "center" not in service.parameters
assert "amplitude" in service.parameters
assert "shift" in service.parameters


def test_LmfitService1D_get_model(lmfit_service):
model = lmfit_service.get_model("GaussianModel")
assert model.__name__ == "GaussianModel"
Expand Down
Loading