Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions pybaselines/_algorithm_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,7 @@ def _setup_classification(self, y, weights=None, **kwargs):

return y, weight_array

def _get_function(self, method, modules):
def _get_function(self, method, modules, ensure_new=False):
"""
Tries to retrieve the indicated function from a list of modules.

Expand All @@ -785,6 +785,9 @@ def _get_function(self, method, modules):
The string name of the desired function. Case does not matter.
modules : Sequence
A sequence of modules in which to look for the method.
ensure_new : bool, optional
If True, will ensure that the output `class_object` and `func`
correspond to a new object rather than `self`. Default is False.

Returns
-------
Expand All @@ -802,14 +805,19 @@ def _get_function(self, method, modules):

"""
function_string = method.lower()
self_has = hasattr(self, function_string)
for module in modules:
if hasattr(module, function_string):
func_module = module.__name__.split('.')[-1]
# if self is a Baseline class, can just use its method
if hasattr(self, function_string):
if self_has and not ensure_new:
func = getattr(self, function_string)
class_object = self
else:
if self_has:
klass = self.__class__
else:
klass = getattr(module, '_' + func_module.capitalize())
# have to reset x ordering so that all outputs and parameters are
# correctly sorted
if self._sort_order is not None:
Expand All @@ -818,7 +826,7 @@ def _get_function(self, method, modules):
else:
x = self.x
assume_sorted = True
class_object = getattr(module, '_' + func_module.capitalize())(
class_object = klass(
x, check_finite=self._check_finite, assume_sorted=assume_sorted,
output_dtype=self._dtype
)
Expand All @@ -834,7 +842,8 @@ def _get_function(self, method, modules):

return func, func_module, class_object

def _setup_optimizer(self, y, method, modules, method_kwargs=None, copy_kwargs=True):
def _setup_optimizer(self, y, method, modules, method_kwargs=None, copy_kwargs=True,
ensure_new=False):
"""
Sets the starting parameters for doing optimizer algorithms.

Expand All @@ -853,6 +862,12 @@ def _setup_optimizer(self, y, method, modules, method_kwargs=None, copy_kwargs=T
copy_kwargs : bool, optional
If True (default), will copy the input `method_kwargs` so that the input
dictionary is not modified within the function.
ensure_new : bool, optional
If True, will ensure that the output `class_object` and `baseline_func`
correspond to a new object rather than `self`. This provides thread safety
for methods that modify internal state which is not normally expected to
change when using threading, such as the polynomial degree.
Default is False.

Returns
-------
Expand All @@ -873,7 +888,7 @@ def _setup_optimizer(self, y, method, modules, method_kwargs=None, copy_kwargs=T
Raised if method_kwargs has the 'x_data' key.

"""
baseline_func, func_module, class_object = self._get_function(method, modules)
baseline_func, func_module, class_object = self._get_function(method, modules, ensure_new)
if method_kwargs is None:
method_kws = {}
elif copy_kwargs:
Expand Down
2 changes: 1 addition & 1 deletion pybaselines/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def adaptive_minmax(self, data, poly_order=None, method='modpoly', weights=None,

"""
y, baseline_func, _, method_kws, _ = self._setup_optimizer(
data, method, [polynomial], method_kwargs, False
data, method, [polynomial], method_kwargs, False, ensure_new=True
)
sort_weights = weights is not None
weight_array = _check_optional_array(self._size, weights, check_finite=self._check_finite)
Expand Down
26 changes: 19 additions & 7 deletions pybaselines/two_d/_algorithm_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,7 +849,7 @@ def _setup_classification(self, y, weights=None):

return y, weight_array

def _get_function(self, method, modules):
def _get_function(self, method, modules, ensure_new=False):
"""
Tries to retrieve the indicated function from a list of modules.

Expand All @@ -859,6 +859,9 @@ def _get_function(self, method, modules):
The string name of the desired function. Case does not matter.
modules : Sequence
A sequence of modules in which to look for the method.
ensure_new : bool, optional
If True, will ensure that the output `class_object` and `func`
correspond to a new object rather than `self`. Default is False.

Returns
-------
Expand All @@ -876,15 +879,17 @@ def _get_function(self, method, modules):

"""
function_string = method.lower()
self_has = hasattr(self, function_string)
for module in modules:
func_module = module.__name__.split('.')[-1]
module_class = getattr(module, '_' + func_module.capitalize())
if hasattr(module_class, function_string):
# if self is a Baseline2D class, can just use its method
if hasattr(self, function_string):
if self_has and not ensure_new:
func = getattr(self, function_string)
class_object = self
else:
klass = self.__class__ if self_has else module_class
# have to reset x and z ordering so that all outputs and parameters are
# correctly sorted
if self._sort_order is None:
Expand All @@ -904,7 +909,7 @@ def _get_function(self, method, modules):
x = self.x[self._inverted_order]
z = self.z

class_object = module_class(
class_object = klass(
x, z, check_finite=self._check_finite, assume_sorted=assume_sorted,
output_dtype=self._dtype
)
Expand All @@ -920,7 +925,8 @@ def _get_function(self, method, modules):

return func, func_module, class_object

def _setup_optimizer(self, y, method, modules, method_kwargs=None, copy_kwargs=True):
def _setup_optimizer(self, y, method, modules, method_kwargs=None, copy_kwargs=True,
ensure_new=False):
"""
Sets the starting parameters for doing optimizer algorithms.

Expand All @@ -939,6 +945,12 @@ def _setup_optimizer(self, y, method, modules, method_kwargs=None, copy_kwargs=T
copy_kwargs : bool, optional
If True (default), will copy the input `method_kwargs` so that the input
dictionary is not modified within the function.
ensure_new : bool, optional
If True, will ensure that the output `class_object` and `baseline_func`
correspond to a new object rather than `self`. This provides thread safety
for methods that modify internal state which is not normally expected to
change when using threading, such as the polynomial degree.
Default is False.

Returns
-------
Expand All @@ -950,11 +962,11 @@ def _setup_optimizer(self, y, method, modules, method_kwargs=None, copy_kwargs=T
The string name of the module that contained `fit_func`.
method_kws : dict
A dictionary of keyword arguments to pass to `fit_func`.
class_object : pybaselines._algorithm_setup._Algorithm
The `_Algorithm` object which will be used for fitting.
class_object : pybaselines.two_d._algorithm_setup._Algorithm2D
The `_Algorithm2D` object which will be used for fitting.

"""
baseline_func, func_module, class_object = self._get_function(method, modules)
baseline_func, func_module, class_object = self._get_function(method, modules, ensure_new)
if method_kwargs is None:
method_kws = {}
elif copy_kwargs:
Expand Down
2 changes: 1 addition & 1 deletion pybaselines/two_d/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def adaptive_minmax(self, data, poly_order=None, method='modpoly', weights=None,

"""
y, baseline_func, _, method_kws, _ = self._setup_optimizer(
data, method, [polynomial], method_kwargs, False
data, method, [polynomial], method_kwargs, False, ensure_new=True
)
sort_weights = weights is not None
weight_array = _check_optional_array(
Expand Down
25 changes: 3 additions & 22 deletions tests/base_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import pytest

import pybaselines
from pybaselines import Baseline, Baseline2D


def ensure_deprecation(deprecation_major, deprecation_minor):
Expand Down Expand Up @@ -318,14 +319,6 @@ def changing_dataset2d(data_size=(40, 33), dataset_size=20, three_d=False):
return x, z, dataset


def dummy_wrapper(func):
"""A dummy wrapper to simulate using the _Algorithm._register wrapper function."""
@wraps(func)
def inner(*args, **kwargs):
return func(*args, **kwargs)
return inner


class DummyModule:
"""A dummy object to serve as a fake module."""

Expand All @@ -335,18 +328,6 @@ def func(*args, data=None, x_data=None, **kwargs):
raise NotImplementedError('need to set func')


class DummyAlgorithm:
"""A dummy object to serve as a fake Algorithm subclass."""

def __init__(self, *args, **kwargs):
pass

@dummy_wrapper
def func(self, data=None, *args, **kwargs):
"""Dummy function."""
raise NotImplementedError('need to set func')


def check_param_keys(expected_keys, output_keys):
"""
Ensures the output keys within the parameter dictionary matched the expected keys.
Expand Down Expand Up @@ -391,7 +372,7 @@ class BaseTester:
"""

module = DummyModule
algorithm_base = DummyAlgorithm
algorithm_base = Baseline
func_name = 'func'
checked_keys = None
required_kwargs = None
Expand Down Expand Up @@ -763,7 +744,7 @@ class BaseTester2D:
"""

module = DummyModule
algorithm_base = DummyAlgorithm
algorithm_base = Baseline2D
func_name = 'func'
checked_keys = None
required_kwargs = None
Expand Down
54 changes: 43 additions & 11 deletions tests/test_algorithm_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from numpy.testing import assert_allclose, assert_array_equal
import pytest

from pybaselines import _algorithm_setup, optimizers, polynomial, whittaker
from pybaselines import Baseline, _algorithm_setup, optimizers, polynomial, whittaker
from pybaselines._compat import dia_object
from pybaselines.utils import ParameterWarning, SortingWarning, estimate_window

Expand Down Expand Up @@ -576,16 +576,23 @@ def test_setup_misc(small_data, algorithm):
('asls', 'asls', 'whittaker')
)
)
def test_get_function(algorithm, method_and_outputs):
@pytest.mark.parametrize('ensure_new', (True, False))
def test_get_function(method_and_outputs, ensure_new):
"""Ensures _get_function gets the correct method, regardless of case."""
method, expected_func, expected_module = method_and_outputs
tested_modules = [optimizers, polynomial, whittaker]

algorithm = Baseline(np.arange(10), assume_sorted=False)
selected_func, module, class_object = algorithm._get_function(
method, tested_modules
method, tested_modules, ensure_new=ensure_new
)
assert selected_func.__name__ == expected_func
assert module == expected_module
assert isinstance(class_object, _algorithm_setup._Algorithm)
if ensure_new:
assert class_object is not algorithm
else:
assert class_object is algorithm


def test_get_function_fails_wrong_method(algorithm):
Expand All @@ -600,26 +607,35 @@ def test_get_function_fails_no_module(algorithm):
algorithm._get_function('collab_pls', [])


def test_get_function_sorting():
@pytest.mark.parametrize('ensure_new', (True, False))
def test_get_function_sorting(ensure_new):
"""Ensures the sort order is correct for the output class object."""
num_points = 10
x = np.arange(num_points)
ordering = np.arange(num_points)
algorithm = _algorithm_setup._Algorithm(x[::-1], assume_sorted=False)
func, func_module, class_object = algorithm._get_function('asls', [whittaker])
algorithm = Baseline(x[::-1], assume_sorted=False)
func, func_module, class_object = algorithm._get_function(
'asls', [whittaker], ensure_new=ensure_new
)

assert_array_equal(class_object.x, x)
assert_array_equal(class_object._sort_order, ordering[::-1])
assert_array_equal(class_object._inverted_order, ordering[::-1])
assert_array_equal(class_object._sort_order, algorithm._sort_order)
assert_array_equal(class_object._inverted_order, algorithm._inverted_order)
if ensure_new:
assert class_object is not algorithm
else:
assert class_object is algorithm


@pytest.mark.parametrize('method_kwargs', (None, {'a': 2}))
def test_setup_optimizer(small_data, algorithm, method_kwargs):
@pytest.mark.parametrize('ensure_new', (True, False))
def test_setup_optimizer(small_data, method_kwargs, ensure_new):
"""Ensures output of _setup_optimizer is correct."""
algorithm = Baseline(np.arange(len(small_data)))
y, fit_func, func_module, output_kwargs, class_object = algorithm._setup_optimizer(
small_data, 'asls', [whittaker], method_kwargs
small_data, 'asls', [whittaker], method_kwargs, ensure_new=ensure_new
)

assert isinstance(y, np.ndarray)
Expand All @@ -628,6 +644,10 @@ def test_setup_optimizer(small_data, algorithm, method_kwargs):
assert func_module == 'whittaker'
assert isinstance(output_kwargs, dict)
assert isinstance(class_object, _algorithm_setup._Algorithm)
if ensure_new:
assert class_object is not algorithm
else:
assert class_object is algorithm


@pytest.mark.parametrize('copy_kwargs', (True, False))
Expand Down Expand Up @@ -978,8 +998,12 @@ def test_override_x(algorithm):
assert algorithm._shape == (old_size,)


def test_override_x_polynomial(algorithm):
def test_override_x_polynomial():
"""Ensures the polynomial attributes are correctly reset and then returned by override_x."""
algorithm = _algorithm_setup._Algorithm(
x_data=np.arange(10), assume_sorted=True, check_finite=False
)

old_len = len(algorithm.x)
poly_order = 2
new_poly_order = 3
Expand Down Expand Up @@ -1014,8 +1038,12 @@ def test_override_x_polynomial(algorithm):
assert algorithm._polynomial.poly_order == poly_order


def test_override_x_whittaker(algorithm):
def test_override_x_whittaker():
"""Ensures the whittaker attributes are correctly reset and then returned by override_x."""
algorithm = _algorithm_setup._Algorithm(
x_data=np.arange(10), assume_sorted=True, check_finite=False
)

old_len = len(algorithm.x)
diff_order = 2
new_diff_order = 3
Expand Down Expand Up @@ -1045,8 +1073,12 @@ def test_override_x_whittaker(algorithm):
assert new_whittaker_system.penta_solver == banded_solver


def test_override_x_spline(algorithm):
def test_override_x_spline():
"""Ensures the spline attributes are correctly reset and then returned by override_x."""
algorithm = _algorithm_setup._Algorithm(
x_data=np.arange(10), assume_sorted=True, check_finite=False
)

old_len = len(algorithm.x)
spline_degree = 2
new_spline_degree = 3
Expand Down
1 change: 0 additions & 1 deletion tests/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,6 @@ class ClassificationTester(BaseTester, InputWeightsMixin):
"""Base testing class for classification functions."""

module = classification
algorithm_base = classification._Classification
checked_keys = ('mask',)
weight_keys = ('mask',)
requires_unique_x = True
Expand Down
Loading
Loading