Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pybaselines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
__version__ = '1.2.1.post1.dev0'

# import utils first since it is imported by other modules; likewise, import
# optimizers and api last since they import the other modules
# api last since it imports the other modules
from . import (
utils, classification, misc, morphological, polynomial, spline, whittaker, smooth,
optimizers, api
Expand Down
118 changes: 57 additions & 61 deletions pybaselines/_algorithm_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import numpy as np

from ._banded_utils import PenalizedSystem
from ._nd.optimizers import _OptimizerHelper
from ._spline_utils import PSpline, SplineBasis
from ._validation import (
_check_array, _check_half_window, _check_optional_array, _check_scalar_variable,
Expand Down Expand Up @@ -857,75 +858,61 @@ def _setup_classification(self, y, weights=None, **kwargs):

return y, weight_array

def _get_function(self, method, modules, ensure_new=False):
def _spawn_fitter(self, method, ensure_new=False):
"""
Tries to retrieve the indicated function from a list of modules.
Creates an appropriate fitting object for the indicated method.

Parameters
----------
method : str
The string name of the desired function. Case does not matter.
modules : Sequence
A sequence of modules in which to look for the method.
The string name of the desired method.
ensure_new : bool, optional
If True, will ensure that the output `class_object` and `func`
If True, will ensure that the output `class_object`
correspond to a new object rather than `self`.

Returns
-------
func : Callable
The corresponding function.
func_module : str
The module that `func` belongs to.
class_object : pybaselines._algorithm_setup._Algorithm
The `_Algorithm` object which will be used for fitting.

Raises
------
AttributeError
Raised if no matching function is found within the modules.
Raised if `method` is not an available Baseline method.

"""
function_string = method.lower()
self_has = hasattr(self, function_string)
for module in modules:
if hasattr(module, function_string):
func_module = module.__name__.split('.')[-1]
# if self is a Baseline class, can just use its method
if self_has and not ensure_new:
func = getattr(self, function_string)
class_object = self
else:
if self_has:
klass = self.__class__
else:
klass = getattr(module, '_' + func_module.capitalize())
# have to reset x ordering so that all outputs and parameters are
# correctly sorted
if self._sort_order is not None:
x = self.x[self._inverted_order]
assume_sorted = False
else:
x = self.x
assume_sorted = True
class_object = klass(
x, check_finite=self._check_finite, assume_sorted=assume_sorted,
output_dtype=self._dtype
)
class_object.banded_solver = self.banded_solver
func = getattr(class_object, function_string)
break
else: # in case no break
mod_names = [module.__name__ for module in modules]
raise AttributeError((
f'unknown method "{method}" or method is not within the allowed '
f'modules: {mod_names}'
))

return func, func_module, class_object

def _setup_optimizer(self, y, method, modules, method_kwargs=None, copy_kwargs=True,
ensure_new=False):
self_has = hasattr(self, method)

# if self is a Baseline class, can just use its method
if self_has and not ensure_new:
class_object = self
else:
if self_has:
klass = self.__class__
else:
# just directly use Baseline rather than the individual private classes
from .api import Baseline
if not hasattr(Baseline, method):
raise AttributeError(f'{method} is not a valid method')
klass = Baseline
# have to reset x ordering so that all outputs and parameters are
# correctly sorted
if self._sort_order is not None:
x = self.x[self._inverted_order]
assume_sorted = False
else:
x = self.x
assume_sorted = True
class_object = klass(
x, check_finite=self._check_finite, assume_sorted=assume_sorted,
output_dtype=self._dtype
)
class_object.banded_solver = self.banded_solver

return class_object

def _setup_optimizer(self, y, method, method_param=None, method_kwargs=None, copy_kwargs=True,
ensure_new=False, needed_params=None):
"""
Sets the starting parameters for doing optimizer algorithms.

Expand All @@ -936,8 +923,13 @@ def _setup_optimizer(self, y, method, modules, method_kwargs=None, copy_kwargs=T
array by :meth:`~._Algorithm._handle_io`.
method : str
The string name of the desired function, like 'asls'. Case does not matter.
modules : Sequence[module, ...]
The modules to search for the indicated `method` function.
method_param : dict, optional
A dictionary indicating potential parameter keys to use, with the default having
a key of None. For example, a `method_param` of {'method1': 'a', None: ('b', 'c')}
would specify that parameter 'a' should be used for `method`='method1'; otherwise,
either 'b' or 'c' could be potential parameters, which would then be filtered by
looking at the signature of the indicated method. Default is None, which indicates
that the optimizer method being used does not require any parameter key.
method_kwargs : dict, optional
A dictionary of keyword arguments to pass to the fitting function. Default
is None, which uses an empty dictionary.
Expand All @@ -950,27 +942,31 @@ def _setup_optimizer(self, y, method, modules, method_kwargs=None, copy_kwargs=T
thread safety for methods which would modify internal state not typically
assumed to change when using threading, such as changing polynomial degrees.
Default is False.
needed_params : Iterable, optional
An iterature of other necessary parameter keys that the method must have in its
signature. For example ['weights', 'tol'] would error if either 'weights' or 'tol'
are not valid inputs. Default is None.

Returns
-------
y : numpy.ndarray, shape (N,)
The y-values of the measured data, converted to a numpy array.
baseline_func : Callable
The function for fitting the baseline.
func_module : str
The string name of the module that contained `fit_func`.
optimizer_obj : _OptimizerHelper
The object containing the fitting object to use and all relevant fields
for optimizer-type methods.
method_kws : dict
A dictionary of keyword arguments to pass to `fit_func`.
class_object : pybaselines._algorithm_setup._Algorithm
The `_Algorithm` object which will be used for fitting.

Raises
------
KeyError
Raised if method_kwargs has the 'x_data' key.

"""
baseline_func, func_module, class_object = self._get_function(method, modules, ensure_new)
optimizer_obj = _OptimizerHelper(
method, self, ensure_new=ensure_new, method_param=method_param,
needed_params=needed_params
)
if method_kwargs is None:
method_kws = {}
elif copy_kwargs:
Expand All @@ -981,7 +977,7 @@ def _setup_optimizer(self, y, method, modules, method_kwargs=None, copy_kwargs=T
if 'x_data' in method_kws:
raise KeyError('"x_data" should not be within the method keyword arguments')

return y, baseline_func, func_module, method_kws, class_object
return y, optimizer_obj, method_kws

def _setup_misc(self, y):
"""
Expand Down
2 changes: 1 addition & 1 deletion pybaselines/_nd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@

"""

from . import morphological, pls, polynomial
from . import morphological, optimizers, pls, polynomial
Loading