diff --git a/pybaselines/_algorithm_setup.py b/pybaselines/_algorithm_setup.py index 5c4d49b..889a9f8 100644 --- a/pybaselines/_algorithm_setup.py +++ b/pybaselines/_algorithm_setup.py @@ -775,7 +775,7 @@ def _setup_classification(self, y, weights=None, **kwargs): return y, weight_array - def _get_function(self, method, modules): + def _get_function(self, method, modules, ensure_new=False): """ Tries to retrieve the indicated function from a list of modules. @@ -785,6 +785,9 @@ def _get_function(self, method, modules): The string name of the desired function. Case does not matter. modules : Sequence A sequence of modules in which to look for the method. + ensure_new : bool, optional + If True, will ensure that the output `class_object` and `func` + correspond to a new object rather than `self`. Returns ------- @@ -802,14 +805,19 @@ def _get_function(self, method, modules): """ function_string = method.lower() + self_has = hasattr(self, function_string) for module in modules: if hasattr(module, function_string): func_module = module.__name__.split('.')[-1] # if self is a Baseline class, can just use its method - if hasattr(self, function_string): + if self_has and not ensure_new: func = getattr(self, function_string) class_object = self else: + if self_has: + klass = self.__class__ + else: + klass = getattr(module, '_' + func_module.capitalize()) # have to reset x ordering so that all outputs and parameters are # correctly sorted if self._sort_order is not None: @@ -818,7 +826,7 @@ def _get_function(self, method, modules): else: x = self.x assume_sorted = True - class_object = getattr(module, '_' + func_module.capitalize())( + class_object = klass( x, check_finite=self._check_finite, assume_sorted=assume_sorted, output_dtype=self._dtype ) @@ -834,7 +842,8 @@ def _get_function(self, method, modules): return func, func_module, class_object - def _setup_optimizer(self, y, method, modules, method_kwargs=None, copy_kwargs=True): + def _setup_optimizer(self, y, method, modules, method_kwargs=None, copy_kwargs=True, + ensure_new=False): """ Sets the starting parameters for doing optimizer algorithms. @@ -853,6 +862,12 @@ def _setup_optimizer(self, y, method, modules, method_kwargs=None, copy_kwargs=T copy_kwargs : bool, optional If True (default), will copy the input `method_kwargs` so that the input dictionary is not modified within the function. + ensure_new : bool, optional + If True, will ensure that the output `class_object` and `baseline_func` + correspond to a new object rather than `self`. This is to ensure + thread safety for methods which would modify internal state not typically + assumed to change when using threading, such as changing polynomial degrees. + Default is False. Returns ------- @@ -873,7 +888,7 @@ def _setup_optimizer(self, y, method, modules, method_kwargs=None, copy_kwargs=T Raised if method_kwargs has the 'x_data' key. """ - baseline_func, func_module, class_object = self._get_function(method, modules) + baseline_func, func_module, class_object = self._get_function(method, modules, ensure_new) if method_kwargs is None: method_kws = {} elif copy_kwargs: diff --git a/pybaselines/optimizers.py b/pybaselines/optimizers.py index 38a59a6..3db648c 100644 --- a/pybaselines/optimizers.py +++ b/pybaselines/optimizers.py @@ -478,7 +478,7 @@ def adaptive_minmax(self, data, poly_order=None, method='modpoly', weights=None, """ y, baseline_func, _, method_kws, _ = self._setup_optimizer( - data, method, [polynomial], method_kwargs, False + data, method, [polynomial], method_kwargs, False, ensure_new=True ) sort_weights = weights is not None weight_array = _check_optional_array(self._size, weights, check_finite=self._check_finite) diff --git a/pybaselines/two_d/_algorithm_setup.py b/pybaselines/two_d/_algorithm_setup.py index 328cabc..41271e6 100644 --- a/pybaselines/two_d/_algorithm_setup.py +++ b/pybaselines/two_d/_algorithm_setup.py @@ -849,7 +849,7 @@ def _setup_classification(self, y, weights=None): return y, weight_array - def _get_function(self, method, modules): + def _get_function(self, method, modules, ensure_new=False): """ Tries to retrieve the indicated function from a list of modules. @@ -859,6 +859,9 @@ def _get_function(self, method, modules): The string name of the desired function. Case does not matter. modules : Sequence A sequence of modules in which to look for the method. + ensure_new : bool, optional + If True, will ensure that the output `class_object` and `func` + correspond to a new object rather than `self`. Returns ------- @@ -876,15 +879,17 @@ def _get_function(self, method, modules): """ function_string = method.lower() + self_has = hasattr(self, function_string) for module in modules: func_module = module.__name__.split('.')[-1] module_class = getattr(module, '_' + func_module.capitalize()) if hasattr(module_class, function_string): # if self is a Baseline2D class, can just use its method - if hasattr(self, function_string): + if self_has and not ensure_new: func = getattr(self, function_string) class_object = self else: + klass = self.__class__ if self_has else module_class # have to reset x and z ordering so that all outputs and parameters are # correctly sorted if self._sort_order is None: @@ -904,7 +909,7 @@ def _get_function(self, method, modules): x = self.x[self._inverted_order] z = self.z - class_object = module_class( + class_object = klass( x, z, check_finite=self._check_finite, assume_sorted=assume_sorted, output_dtype=self._dtype ) @@ -920,7 +925,8 @@ def _get_function(self, method, modules): return func, func_module, class_object - def _setup_optimizer(self, y, method, modules, method_kwargs=None, copy_kwargs=True): + def _setup_optimizer(self, y, method, modules, method_kwargs=None, copy_kwargs=True, + ensure_new=False): """ Sets the starting parameters for doing optimizer algorithms. @@ -939,6 +945,12 @@ def _setup_optimizer(self, y, method, modules, method_kwargs=None, copy_kwargs=T copy_kwargs : bool, optional If True (default), will copy the input `method_kwargs` so that the input dictionary is not modified within the function. + ensure_new : bool, optional + If True, will ensure that the output `class_object` and `baseline_func` + correspond to a new object rather than `self`. This is to ensure + thread safety for methods which would modify internal state not typically + assumed to change when using threading, such as changing polynomial degrees. + Default is False. Returns ------- @@ -950,11 +962,11 @@ def _setup_optimizer(self, y, method, modules, method_kwargs=None, copy_kwargs=T The string name of the module that contained `fit_func`. method_kws : dict A dictionary of keyword arguments to pass to `fit_func`. - class_object : pybaselines._algorithm_setup._Algorithm - The `_Algorithm` object which will be used for fitting. + class_object : pybaselines.two_d._algorithm_setup._Algorithm2D + The `_Algorithm2D` object which will be used for fitting. """ - baseline_func, func_module, class_object = self._get_function(method, modules) + baseline_func, func_module, class_object = self._get_function(method, modules, ensure_new) if method_kwargs is None: method_kws = {} elif copy_kwargs: diff --git a/pybaselines/two_d/optimizers.py b/pybaselines/two_d/optimizers.py index 4ded2d7..bf337f8 100644 --- a/pybaselines/two_d/optimizers.py +++ b/pybaselines/two_d/optimizers.py @@ -215,7 +215,7 @@ def adaptive_minmax(self, data, poly_order=None, method='modpoly', weights=None, """ y, baseline_func, _, method_kws, _ = self._setup_optimizer( - data, method, [polynomial], method_kwargs, False + data, method, [polynomial], method_kwargs, False, ensure_new=True ) sort_weights = weights is not None weight_array = _check_optional_array( diff --git a/tests/base_tests.py b/tests/base_tests.py index 9f3cdaf..ec5c247 100644 --- a/tests/base_tests.py +++ b/tests/base_tests.py @@ -15,6 +15,7 @@ import pytest import pybaselines +from pybaselines import Baseline, Baseline2D def ensure_deprecation(deprecation_major, deprecation_minor): @@ -318,14 +319,6 @@ def changing_dataset2d(data_size=(40, 33), dataset_size=20, three_d=False): return x, z, dataset -def dummy_wrapper(func): - """A dummy wrapper to simulate using the _Algorithm._register wrapper function.""" - @wraps(func) - def inner(*args, **kwargs): - return func(*args, **kwargs) - return inner - - class DummyModule: """A dummy object to serve as a fake module.""" @@ -335,18 +328,6 @@ def func(*args, data=None, x_data=None, **kwargs): raise NotImplementedError('need to set func') -class DummyAlgorithm: - """A dummy object to serve as a fake Algorithm subclass.""" - - def __init__(self, *args, **kwargs): - pass - - @dummy_wrapper - def func(self, data=None, *args, **kwargs): - """Dummy function.""" - raise NotImplementedError('need to set func') - - def check_param_keys(expected_keys, output_keys): """ Ensures the output keys within the parameter dictionary matched the expected keys. @@ -391,7 +372,7 @@ class BaseTester: """ module = DummyModule - algorithm_base = DummyAlgorithm + algorithm_base = Baseline func_name = 'func' checked_keys = None required_kwargs = None @@ -763,7 +744,7 @@ class BaseTester2D: """ module = DummyModule - algorithm_base = DummyAlgorithm + algorithm_base = Baseline2D func_name = 'func' checked_keys = None required_kwargs = None diff --git a/tests/test_algorithm_setup.py b/tests/test_algorithm_setup.py index dc83013..1a32b8f 100644 --- a/tests/test_algorithm_setup.py +++ b/tests/test_algorithm_setup.py @@ -10,7 +10,7 @@ from numpy.testing import assert_allclose, assert_array_equal import pytest -from pybaselines import _algorithm_setup, optimizers, polynomial, whittaker +from pybaselines import Baseline, _algorithm_setup, optimizers, polynomial, whittaker from pybaselines._compat import dia_object from pybaselines.utils import ParameterWarning, SortingWarning, estimate_window @@ -576,16 +576,23 @@ def test_setup_misc(small_data, algorithm): ('asls', 'asls', 'whittaker') ) ) -def test_get_function(algorithm, method_and_outputs): +@pytest.mark.parametrize('ensure_new', (True, False)) +def test_get_function(method_and_outputs, ensure_new): """Ensures _get_function gets the correct method, regardless of case.""" method, expected_func, expected_module = method_and_outputs tested_modules = [optimizers, polynomial, whittaker] + + algorithm = Baseline(np.arange(10), assume_sorted=False) selected_func, module, class_object = algorithm._get_function( - method, tested_modules + method, tested_modules, ensure_new=ensure_new ) assert selected_func.__name__ == expected_func assert module == expected_module assert isinstance(class_object, _algorithm_setup._Algorithm) + if ensure_new: + assert class_object is not algorithm + else: + assert class_object is algorithm def test_get_function_fails_wrong_method(algorithm): @@ -600,26 +607,35 @@ def test_get_function_fails_no_module(algorithm): algorithm._get_function('collab_pls', []) -def test_get_function_sorting(): +@pytest.mark.parametrize('ensure_new', (True, False)) +def test_get_function_sorting(ensure_new): """Ensures the sort order is correct for the output class object.""" num_points = 10 x = np.arange(num_points) ordering = np.arange(num_points) - algorithm = _algorithm_setup._Algorithm(x[::-1], assume_sorted=False) - func, func_module, class_object = algorithm._get_function('asls', [whittaker]) + algorithm = Baseline(x[::-1], assume_sorted=False) + func, func_module, class_object = algorithm._get_function( + 'asls', [whittaker], ensure_new=ensure_new + ) assert_array_equal(class_object.x, x) assert_array_equal(class_object._sort_order, ordering[::-1]) assert_array_equal(class_object._inverted_order, ordering[::-1]) assert_array_equal(class_object._sort_order, algorithm._sort_order) assert_array_equal(class_object._inverted_order, algorithm._inverted_order) + if ensure_new: + assert class_object is not algorithm + else: + assert class_object is algorithm @pytest.mark.parametrize('method_kwargs', (None, {'a': 2})) -def test_setup_optimizer(small_data, algorithm, method_kwargs): +@pytest.mark.parametrize('ensure_new', (True, False)) +def test_setup_optimizer(small_data, method_kwargs, ensure_new): """Ensures output of _setup_optimizer is correct.""" + algorithm = Baseline(np.arange(len(small_data))) y, fit_func, func_module, output_kwargs, class_object = algorithm._setup_optimizer( - small_data, 'asls', [whittaker], method_kwargs + small_data, 'asls', [whittaker], method_kwargs, ensure_new=ensure_new ) assert isinstance(y, np.ndarray) @@ -628,6 +644,10 @@ def test_setup_optimizer(small_data, algorithm, method_kwargs): assert func_module == 'whittaker' assert isinstance(output_kwargs, dict) assert isinstance(class_object, _algorithm_setup._Algorithm) + if ensure_new: + assert class_object is not algorithm + else: + assert class_object is algorithm @pytest.mark.parametrize('copy_kwargs', (True, False)) @@ -978,8 +998,12 @@ def test_override_x(algorithm): assert algorithm._shape == (old_size,) -def test_override_x_polynomial(algorithm): +def test_override_x_polynomial(): """Ensures the polynomial attributes are correctly reset and then returned by override_x.""" + algorithm = _algorithm_setup._Algorithm( + x_data=np.arange(10), assume_sorted=True, check_finite=False + ) + old_len = len(algorithm.x) poly_order = 2 new_poly_order = 3 @@ -1014,8 +1038,12 @@ def test_override_x_polynomial(algorithm): assert algorithm._polynomial.poly_order == poly_order -def test_override_x_whittaker(algorithm): +def test_override_x_whittaker(): """Ensures the whittaker attributes are correctly reset and then returned by override_x.""" + algorithm = _algorithm_setup._Algorithm( + x_data=np.arange(10), assume_sorted=True, check_finite=False + ) + old_len = len(algorithm.x) diff_order = 2 new_diff_order = 3 @@ -1045,8 +1073,12 @@ def test_override_x_whittaker(algorithm): assert new_whittaker_system.penta_solver == banded_solver -def test_override_x_spline(algorithm): +def test_override_x_spline(): """Ensures the spline attributes are correctly reset and then returned by override_x.""" + algorithm = _algorithm_setup._Algorithm( + x_data=np.arange(10), assume_sorted=True, check_finite=False + ) + old_len = len(algorithm.x) spline_degree = 2 new_spline_degree = 3 diff --git a/tests/test_classification.py b/tests/test_classification.py index fc573f3..514d21e 100644 --- a/tests/test_classification.py +++ b/tests/test_classification.py @@ -291,7 +291,6 @@ class ClassificationTester(BaseTester, InputWeightsMixin): """Base testing class for classification functions.""" module = classification - algorithm_base = classification._Classification checked_keys = ('mask',) weight_keys = ('mask',) requires_unique_x = True diff --git a/tests/test_meta.py b/tests/test_meta.py index dc435a1..32a1b04 100644 --- a/tests/test_meta.py +++ b/tests/test_meta.py @@ -7,17 +7,26 @@ """ from contextlib import contextmanager +from functools import wraps +from threading import Lock import numpy as np from numpy.testing import assert_allclose import pytest from .base_tests import ( - BasePolyTester, BaseTester, BaseTester2D, InputWeightsMixin, dummy_wrapper, get_data, - get_data2d + BasePolyTester, BaseTester, BaseTester2D, InputWeightsMixin, get_data, get_data2d ) +def dummy_wrapper(func): + """A dummy wrapper to simulate using the _Algorithm._register wrapper function.""" + @wraps(func) + def inner(*args, **kwargs): + return func(*args, **kwargs) + return inner + + class DummyModule: """A dummy object to serve as a fake module.""" @@ -390,6 +399,11 @@ def non_unique_x_raises(self, data=None): """Will raise an exception if x-values are not unique.""" return DummyModule.non_unique_x_raises(data=data, x_data=self.x) + @dummy_wrapper + def func(self, data=None, *args, **kwargs): + """Dummy function.""" + raise NotImplementedError('need to set func') + class TestBaseTesterWorks(BaseTester): """Ensures a basic subclass of BaseTester works.""" @@ -445,12 +459,8 @@ def test_reverse_array(self): assert_allclose(self.reverse_array(self.y), self.y[..., ::-1]) -class TestBaseTesterFailures(BaseTester): - """Tests the various BaseTester methods for functions with incorrect output.""" - - module = DummyModule - algorithm_base = DummyAlgorithm - func_name = 'no_func' +class SetFuncMixin: + """Mixin to allow temporarily changing the function for the test class.""" @contextmanager def set_func(self, func_name, checked_keys=None, attributes=None): @@ -467,6 +477,8 @@ def set_func(self, func_name, checked_keys=None, attributes=None): A dictionary of other attributes to temporarily set. Should be class attributes. """ + self.lock.acquire() + original_name = self.func_name original_keys = self.param_keys attributes = attributes if attributes is not None else {} @@ -487,6 +499,56 @@ def set_func(self, func_name, checked_keys=None, attributes=None): setattr(self.__class__, key, value) self.__class__.setup_class() + self.lock.release() + + +class SetFuncWeightsMixin: + """Mixin to allow temporarily changing the function for test classes that deal with weights.""" + + @contextmanager + def set_func(self, func_name, checked_keys=None, weight_key=('weights',)): + """Temporarily sets a new function for the class. + + Parameters + ---------- + func_name : str + The string of the function to use. + checked_keys : iterable, optional + An iterable of strings designating the keys to check in the output parameters + dictionary. + weight_key : iterable, optional + An iterable of strings designating the keys corresponding to weights to check + in the output parameters dictionary. Default is ('weights',). + + """ + self.lock.acquire() + + original_name = self.func_name + original_keys = self.param_keys + original_weight_key = self.weight_keys + try: + self.__class__.func_name = func_name + self.__class__.checked_keys = checked_keys + self.__class__.weight_keys = weight_key + self.__class__.setup_class() + yield self + finally: + self.__class__.func_name = original_name + self.__class__.checked_keys = original_keys + self.__class__.weight_keys = original_weight_key + self.__class__.setup_class() + + self.lock.release() + + +class TestBaseTesterFailures(BaseTester, SetFuncMixin): + """Tests the various BaseTester methods for functions with incorrect output.""" + + module = DummyModule + algorithm_base = DummyAlgorithm + func_name = 'no_func' + lock = Lock() + def test_ensure_wrapped(self): """Ensures no wrapper fails.""" with self.set_func('no_wrapper'): @@ -623,6 +685,8 @@ def test_non_unique_x(self): class TestBaseTesterNoFunc(BaseTester): """Ensures the BaseTester fails if not setup correctly.""" + algorithm_base = DummyAlgorithm + @pytest.mark.parametrize('use_class', (True, False)) def test_unchanged_data(self, use_class): """Ensures that input data is unchanged by the function.""" @@ -707,7 +771,7 @@ def test_output_coefs(self): super().test_output_coefs() -class TestInputWeightsMixinWorks(BaseTester, InputWeightsMixin): +class TestInputWeightsMixinWorks(BaseTester, InputWeightsMixin, SetFuncWeightsMixin): """Ensures a basic subclass of InputWeightsMixin works.""" module = DummyModule @@ -715,33 +779,17 @@ class TestInputWeightsMixinWorks(BaseTester, InputWeightsMixin): func_name = 'good_weights_func' checked_keys = ['a', 'weights'] required_kwargs = {'key': 1} - - @contextmanager - def set_func(self, func_name, checked_keys=None, weight_key=None): - """Temporarily sets a new function for the class.""" - original_name = self.func_name - original_keys = self.param_keys - original_weight_key = self.weight_keys - try: - self.__class__.func_name = func_name - self.__class__.checked_keys = checked_keys - self.__class__.weight_keys = weight_key - self.__class__.setup_class() - yield self - finally: - self.__class__.func_name = original_name - self.__class__.checked_keys = original_keys - self.__class__.weight_keys = original_weight_key - self.__class__.setup_class() + lock = Lock() def test_input_weights(self): """Ensures weight testing works for different weight keys in the parameter dictionary.""" - super().test_input_weights() + with self.lock: # have to ensure the lock is maintained here too + super().test_input_weights() with self.set_func('good_mask_func', weight_key=('mask',), checked_keys=('a', 'mask')): super().test_input_weights() -class TestInputWeightsMixinFails(BaseTester, InputWeightsMixin): +class TestInputWeightsMixinFails(BaseTester, InputWeightsMixin, SetFuncWeightsMixin): """Tests the various BasePolyTester methods for functions with incorrect output.""" module = DummyModule @@ -749,24 +797,7 @@ class TestInputWeightsMixinFails(BaseTester, InputWeightsMixin): func_name = 'bad_weights_func' checked_keys = ['a', 'weights'] required_kwargs = {'key': 1} - - @contextmanager - def set_func(self, func_name, checked_keys=None, weight_key=('weights',)): - """Temporarily sets a new function for the class.""" - original_name = self.func_name - original_keys = self.param_keys - original_weight_key = self.weight_keys - try: - self.__class__.func_name = func_name - self.__class__.checked_keys = checked_keys - self.__class__.weight_keys = weight_key - self.__class__.setup_class() - yield self - finally: - self.__class__.func_name = original_name - self.__class__.checked_keys = original_keys - self.__class__.weight_keys = original_weight_key - self.__class__.setup_class() + lock = Lock() def test_input_weights(self): """Ensures weight testing works for different weight keys in the parameter dictionary.""" @@ -835,47 +866,13 @@ def test_reverse_array(self): assert_allclose(self.reverse_array(self.y), self.y[..., ::-1, ::-1]) -class TestBaseTester2DFailures(BaseTester2D): +class TestBaseTester2DFailures(BaseTester2D, SetFuncMixin): """Tests the various BaseTester2D methods for functions with incorrect output.""" module = DummyModule algorithm_base = DummyAlgorithm func_name = 'no_func' - - @contextmanager - def set_func(self, func_name, checked_keys=None, attributes=None): - """Temporarily sets a new function for the class. - - Parameters - ---------- - func_name : str - The string of the function to use. - checked_keys : iterable, optional - An iterable of strings designating the keys to check in the output parameters - dictionary. - attributes : dict, optional - A dictionary of other attributes to temporarily set. Should be class attributes. - - """ - original_name = self.func_name - original_keys = self.param_keys - attributes = attributes if attributes is not None else {} - original_attributes = {} - for key in attributes.keys(): - original_attributes[key] = getattr(self, key) - try: - self.__class__.func_name = func_name - self.__class__.checked_keys = checked_keys - for key, value in attributes.items(): - setattr(self.__class__, key, value) - self.__class__.setup_class() - yield self - finally: - self.__class__.func_name = original_name - self.__class__.checked_keys = original_keys - for key, value in original_attributes.items(): - setattr(self.__class__, key, value) - self.__class__.setup_class() + lock = Lock() def test_ensure_wrapped(self): """Ensures no wrapper fails.""" @@ -996,6 +993,8 @@ def test_non_unique_xz(self): class TestBaseTester2DNoFunc(BaseTester2D): """Ensures the BaseTester2D fails if not setup correctly.""" + algorithm_base = DummyAlgorithm + @pytest.mark.parametrize('new_instance', (True, False)) def test_unchanged_data(self, new_instance): """Ensures that input data is unchanged by the function.""" diff --git a/tests/test_misc.py b/tests/test_misc.py index a106879..bbbcd56 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -23,7 +23,6 @@ class MiscTester(BaseTester): """Base testing class for miscellaneous functions.""" module = misc - algorithm_base = misc._Misc @pytest.mark.filterwarnings('ignore:"interp_pts" is deprecated') diff --git a/tests/test_morphological.py b/tests/test_morphological.py index a6d7a70..7639104 100644 --- a/tests/test_morphological.py +++ b/tests/test_morphological.py @@ -21,7 +21,6 @@ class MorphologicalTester(BaseTester): """Base testing class for morphological functions.""" module = morphological - algorithm_base = morphological._Morphological checked_keys = ('half_window',) @ensure_deprecation(1, 4) diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index fadeb94..66585c3 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -67,7 +67,6 @@ class OptimizersTester(BaseTester): """Base testing class for optimizer functions.""" module = optimizers - algorithm_base = optimizers._Optimizers checked_method_keys = None def test_output(self, additional_keys=None, additional_method_keys=None, **kwargs): diff --git a/tests/test_polynomial.py b/tests/test_polynomial.py index 88cca27..de4fcc7 100644 --- a/tests/test_polynomial.py +++ b/tests/test_polynomial.py @@ -26,7 +26,6 @@ class PolynomialTester(BasePolyTester, InputWeightsMixin): """Base testing class for polynomial functions.""" module = polynomial - algorithm_base = polynomial._Polynomial checked_keys = ('weights',) diff --git a/tests/test_smooth.py b/tests/test_smooth.py index 9f4b3ab..a7bd6be 100644 --- a/tests/test_smooth.py +++ b/tests/test_smooth.py @@ -19,7 +19,6 @@ class SmoothTester(BaseTester): """Base testing class for whittaker functions.""" module = smooth - algorithm_base = smooth._Smooth uses_padding = True # TODO remove after version 1.4 when kwargs are deprecated @ensure_deprecation(1, 4) diff --git a/tests/test_spline.py b/tests/test_spline.py index 5146189..76bac98 100644 --- a/tests/test_spline.py +++ b/tests/test_spline.py @@ -47,7 +47,6 @@ class SplineTester(BaseTester): """Base testing class for spline functions.""" module = spline - algorithm_base = spline._Spline def test_numba_implementation(self): """ diff --git a/tests/test_spline_utils.py b/tests/test_spline_utils.py index 06a6c41..e180461 100644 --- a/tests/test_spline_utils.py +++ b/tests/test_spline_utils.py @@ -181,9 +181,11 @@ def test_bspline_has_extrapolate(): # Also check that the result is cached so that the actual check is only done once. The # cache hits would depend on the test run order, so just check that calling it twice # results in a non-zero hits value and that misses is 1 (the first call counts as a miss) + # Note that the actual check uses misses > 0 to pass if using pytest-run-parallel, where + # the number of misses would equal the number of threads used during testing assert _spline_utils._bspline_has_extrapolate() == has_extrapolate assert _spline_utils._bspline_has_extrapolate.cache_info().hits > 0 - assert _spline_utils._bspline_has_extrapolate.cache_info().misses == 1 + assert _spline_utils._bspline_has_extrapolate.cache_info().misses > 0 @pytest.mark.parametrize('num_knots', (2, 20, 1001)) diff --git a/tests/test_whittaker.py b/tests/test_whittaker.py index edb7457..97c00be 100644 --- a/tests/test_whittaker.py +++ b/tests/test_whittaker.py @@ -92,7 +92,6 @@ class WhittakerTester(BaseTester, InputWeightsMixin, RecreationMixin): """Base testing class for whittaker functions.""" module = whittaker - algorithm_base = whittaker._Whittaker checked_keys = ('weights', 'tol_history') @pytest.mark.parametrize('diff_order', (2, 3)) diff --git a/tests/two_d/test_algorithm_setup.py b/tests/two_d/test_algorithm_setup.py index 6423b35..b25b205 100644 --- a/tests/two_d/test_algorithm_setup.py +++ b/tests/two_d/test_algorithm_setup.py @@ -12,7 +12,7 @@ from scipy.sparse import kron from pybaselines._compat import identity -from pybaselines.two_d import _algorithm_setup, optimizers, polynomial, whittaker +from pybaselines.two_d import Baseline2D, _algorithm_setup, optimizers, polynomial, whittaker from pybaselines.utils import ParameterWarning, SortingWarning, difference_matrix, estimate_window from pybaselines._validation import _check_scalar @@ -241,8 +241,12 @@ def test_setup_polynomial_too_large_polyorder_fails(small_data2d, algorithm): algorithm._setup_polynomial(small_data2d, poly_order=np.array([1, 2, 3])) -def test_setup_polynomial_maxcross(small_data2d, algorithm): +def test_setup_polynomial_maxcross(small_data2d): """Ensures the _max_cross attribute is updated after calling _setup_polynomial.""" + num_x, num_z = small_data2d.shape + algorithm = _algorithm_setup._Algorithm2D( + x_data=np.arange(num_x), z_data=np.arange(num_z), assume_sorted=True, check_finite=False + ) algorithm._setup_polynomial(small_data2d, max_cross=[1], calc_vander=True) assert algorithm._polynomial.max_cross == 1 @@ -1015,16 +1019,25 @@ def test_override_x(algorithm): ('asls', 'asls', 'whittaker') ) ) -def test_get_function(algorithm, method_and_outputs): +@pytest.mark.parametrize('ensure_new', (True, False)) +def test_get_function(method_and_outputs, ensure_new): """Ensures _get_function gets the correct method, regardless of case.""" method, expected_func, expected_module = method_and_outputs tested_modules = [optimizers, polynomial, whittaker] + + algorithm = Baseline2D( + x_data=np.arange(10), z_data=np.arange(20), assume_sorted=True, check_finite=False + ) selected_func, module, class_object = algorithm._get_function( - method, tested_modules + method, tested_modules, ensure_new ) assert selected_func.__name__ == expected_func assert module == expected_module assert isinstance(class_object, _algorithm_setup._Algorithm2D) + if ensure_new: + assert class_object is not algorithm + else: + assert class_object is algorithm def test_get_function_fails_wrong_method(algorithm): @@ -1039,28 +1052,34 @@ def test_get_function_fails_no_module(algorithm): algorithm._get_function('collab_pls', []) -def test_get_function_sorting_x(): +@pytest.mark.parametrize('ensure_new', (True, False)) +def test_get_function_sorting_x(ensure_new): """Ensures the sort order is correct for the output class object when x is reversed.""" num_points = 10 x = np.arange(num_points) ordering = np.arange(num_points) - algorithm = _algorithm_setup._Algorithm2D(x[::-1], assume_sorted=False) - func, func_module, class_object = algorithm._get_function('asls', [whittaker]) + algorithm = Baseline2D(x[::-1], assume_sorted=False) + func, func_module, class_object = algorithm._get_function('asls', [whittaker], ensure_new) assert_array_equal(class_object.x, x) assert_array_equal(class_object._sort_order, ordering[::-1]) assert_array_equal(class_object._inverted_order, ordering[::-1]) assert_array_equal(class_object._sort_order, algorithm._sort_order) assert_array_equal(class_object._inverted_order, algorithm._inverted_order) + if ensure_new: + assert class_object is not algorithm + else: + assert class_object is algorithm -def test_get_function_sorting_z(): +@pytest.mark.parametrize('ensure_new', (True, False)) +def test_get_function_sorting_z(ensure_new): """Ensures the sort order is correct for the output class object when z is reversed.""" num_points = 10 z = np.arange(num_points) ordering = np.arange(num_points) - algorithm = _algorithm_setup._Algorithm2D(None, z[::-1], assume_sorted=False) - func, func_module, class_object = algorithm._get_function('asls', [whittaker]) + algorithm = Baseline2D(None, z[::-1], assume_sorted=False) + func, func_module, class_object = algorithm._get_function('asls', [whittaker], ensure_new) assert_array_equal(class_object.z, z) assert class_object._sort_order[0] is Ellipsis @@ -1071,9 +1090,14 @@ def test_get_function_sorting_z(): assert_array_equal(class_object._inverted_order[1], ordering[::-1]) assert_array_equal(class_object._sort_order[1], algorithm._sort_order[1]) assert_array_equal(class_object._inverted_order[1], algorithm._inverted_order[1]) + if ensure_new: + assert class_object is not algorithm + else: + assert class_object is algorithm -def test_get_function_sorting_xz(): +@pytest.mark.parametrize('ensure_new', (True, False)) +def test_get_function_sorting_xz(ensure_new): """Ensures the sort order is correct for the output class object when x and z are reversed.""" num_x_points = 10 num_z_points = 11 @@ -1082,8 +1106,8 @@ def test_get_function_sorting_xz(): z = np.arange(num_z_points) z_ordering = np.arange(num_z_points) - algorithm = _algorithm_setup._Algorithm2D(x[::-1], z[::-1], assume_sorted=False) - func, func_module, class_object = algorithm._get_function('asls', [whittaker]) + algorithm = Baseline2D(x[::-1], z[::-1], assume_sorted=False) + func, func_module, class_object = algorithm._get_function('asls', [whittaker], ensure_new) assert_array_equal(class_object.x, x) assert_array_equal(class_object.z, z) @@ -1095,13 +1119,22 @@ def test_get_function_sorting_xz(): assert_array_equal(class_object._sort_order[1], algorithm._sort_order[1]) assert_array_equal(class_object._inverted_order[0], algorithm._inverted_order[0]) assert_array_equal(class_object._inverted_order[1], algorithm._inverted_order[1]) + if ensure_new: + assert class_object is not algorithm + else: + assert class_object is algorithm @pytest.mark.parametrize('method_kwargs', (None, {'a': 2})) -def test_setup_optimizer(small_data2d, algorithm, method_kwargs): +@pytest.mark.parametrize('ensure_new', (True, False)) +def test_setup_optimizer(small_data2d, method_kwargs, ensure_new): """Ensures output of _setup_optimizer is correct.""" + num_x, num_z = small_data2d.shape + algorithm = Baseline2D( + x_data=np.arange(num_x), z_data=np.arange(num_z), assume_sorted=True, check_finite=False + ) y, fit_func, func_module, output_kwargs, class_object = algorithm._setup_optimizer( - small_data2d, 'asls', [whittaker], method_kwargs + small_data2d, 'asls', [whittaker], method_kwargs, ensure_new=ensure_new ) assert isinstance(y, np.ndarray) @@ -1110,6 +1143,10 @@ def test_setup_optimizer(small_data2d, algorithm, method_kwargs): assert func_module == 'whittaker' assert isinstance(output_kwargs, dict) assert isinstance(class_object, _algorithm_setup._Algorithm2D) + if ensure_new: + assert class_object is not algorithm + else: + assert class_object is algorithm @pytest.mark.parametrize('copy_kwargs', (True, False)) diff --git a/tests/two_d/test_morphological.py b/tests/two_d/test_morphological.py index 1e4403d..2595ead 100644 --- a/tests/two_d/test_morphological.py +++ b/tests/two_d/test_morphological.py @@ -19,7 +19,6 @@ class MorphologicalTester(BaseTester2D): """Base testing class for morphological functions.""" module = morphological - algorithm_base = morphological._Morphological checked_keys = ('half_window',) @pytest.mark.parametrize('half_window', (None, 10, [10, 12], np.array([12, 10]))) diff --git a/tests/two_d/test_optimizers.py b/tests/two_d/test_optimizers.py index 21ae6d4..e89983c 100644 --- a/tests/two_d/test_optimizers.py +++ b/tests/two_d/test_optimizers.py @@ -63,7 +63,6 @@ class OptimizersTester(BaseTester2D): """Base testing class for optimizer functions.""" module = optimizers - algorithm_base = optimizers._Optimizers checked_method_keys = None def test_output(self, additional_keys=None, additional_method_keys=None, diff --git a/tests/two_d/test_polynomial.py b/tests/two_d/test_polynomial.py index 3d17424..22efb5f 100644 --- a/tests/two_d/test_polynomial.py +++ b/tests/two_d/test_polynomial.py @@ -22,7 +22,6 @@ class PolynomialTester(BasePolyTester2D, InputWeightsMixin): """Base testing class for polynomial functions.""" module = polynomial - algorithm_base = polynomial._Polynomial checked_keys = ('weights',) diff --git a/tests/two_d/test_smooth.py b/tests/two_d/test_smooth.py index 89d5cea..88268ba 100644 --- a/tests/two_d/test_smooth.py +++ b/tests/two_d/test_smooth.py @@ -18,7 +18,6 @@ class SmoothTester(BaseTester2D): """Base testing class for whittaker functions.""" module = smooth - algorithm_base = smooth._Smooth @ensure_deprecation(1, 4) def test_kwargs_deprecation(self): diff --git a/tests/two_d/test_spline.py b/tests/two_d/test_spline.py index 53ae3c9..23fd320 100644 --- a/tests/two_d/test_spline.py +++ b/tests/two_d/test_spline.py @@ -50,7 +50,6 @@ class SplineTester(BaseTester2D): """Base testing class for spline functions.""" module = spline - algorithm_base = spline._Spline class IterativeSplineTester(SplineTester, InputWeightsMixin, RecreationMixin): diff --git a/tests/two_d/test_whittaker.py b/tests/two_d/test_whittaker.py index a6d6aa5..3f6488a 100644 --- a/tests/two_d/test_whittaker.py +++ b/tests/two_d/test_whittaker.py @@ -18,7 +18,6 @@ class WhittakerTester(BaseTester2D, InputWeightsMixin, RecreationMixin): """Base testing class for whittaker functions.""" module = whittaker - algorithm_base = whittaker._Whittaker checked_keys = ('weights', 'tol_history') def test_tol_history(self):