Skip to content

Commit bd75cdc

Browse files
authored
Merge pull request #372 from fooof-tools/algo
[ENH] - Update Algorithm Management
2 parents d581ae0 + 5af56e0 commit bd75cdc

File tree

19 files changed

+197
-75
lines changed

19 files changed

+197
-75
lines changed

specparam/algorithms/algorithm.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(self, name, description, public_settings, private_settings=None,
6666
self.set_debug(debug)
6767

6868

69-
def _fit_prechecks(self):
69+
def _fit_prechecks(self, verbose):
7070
"""Pre-checks to run before the fit function - if are some, overload this function."""
7171

7272

@@ -195,11 +195,17 @@ def _initialize_bounds(self, mode):
195195
[high_bound_param1, high_bound_param2])
196196
"""
197197

198-
n_params = getattr(self.modes, mode).n_params
198+
# If modes defined, get number of params - otherwise set stores as empty
199+
if self.modes is not None:
200+
n_params = getattr(self.modes, mode).n_params
201+
else:
202+
n_params = 0
203+
199204
bounds = (np.array([-np.inf] * n_params), np.array([np.inf] * n_params))
200205

201206
return bounds
202207

208+
203209
def _initialize_guess(self, mode):
204210
"""Initialize a guess definition.
205211

specparam/algorithms/definitions.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
"""Define collection of fitting algorithms."""
22

3+
from functools import partial
4+
5+
from specparam.utils.checks import check_selection
6+
from specparam.algorithms.algorithm import Algorithm
37
from specparam.algorithms.spectral_fit import SpectralFitAlgorithm
48

59
###################################################################################################
@@ -9,3 +13,15 @@
913
ALGORITHMS = {
1014
'spectral_fit' : SpectralFitAlgorithm,
1115
}
16+
17+
18+
def check_algorithms():
19+
"""Check the set of available fit algorithms."""
20+
21+
print('Available algorithms:')
22+
for algorithm in ALGORITHMS.values():
23+
algorithm = algorithm()
24+
print(' {:12s} : {:s}'.format(algorithm.name, algorithm.description))
25+
26+
27+
check_algorithm_definition = partial(check_selection, definition=Algorithm)

specparam/algorithms/spectral_fit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h
111111

112112
# Initialize base algorithm object with algorithm metadata
113113
super().__init__(
114-
name='spectral fit',
114+
name='spectral_fit',
115115
description='Original parameterizing neural power spectra algorithm.',
116116
public_settings=SPECTRAL_FIT_SETTINGS_DEF,
117117
private_settings=SPECTRAL_FIT_PRIVATE_SETTINGS_DEF,

specparam/metrics/definitions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
"""Collect together library of available built in metrics."""
22

3+
from functools import partial
4+
35
from specparam.metrics.metrics import Metric
46
from specparam.metrics.error import (compute_mean_abs_error, compute_mean_squared_error,
57
compute_root_mean_squared_error, compute_median_abs_error)
68
from specparam.metrics.gof import compute_r_squared, compute_adj_r_squared
9+
from specparam.utils.checks import check_selection
710

811
###################################################################################################
912
## ERROR METRICS
@@ -79,3 +82,6 @@ def check_metrics():
7982
print('Available metrics:')
8083
for metric in METRICS.values():
8184
print(' {:8s} {:12s} : {:s}'.format(metric.category, metric.measure, metric.description))
85+
86+
87+
check_metric_definition = partial(check_selection, definition=Metric)

specparam/metrics/metrics.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66

77
from specparam.metrics.metric import Metric
8+
from specparam.metrics.definitions import METRICS, check_metric_definition
89

910
###################################################################################################
1011
###################################################################################################
@@ -61,6 +62,8 @@ def add_metric(self, metric):
6162
if isinstance(metric, dict):
6263
metric = Metric(**metric)
6364

65+
metric = check_metric_definition(metric, METRICS)
66+
6467
self.metrics.append(deepcopy(metric))
6568

6669

specparam/models/event.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ def __init__(self, *args, **kwargs):
6060
self.data = Data3D()
6161

6262
self.results = Results3D(modes=self.modes,
63-
metrics=kwargs.pop('metrics', None),
64-
bands=kwargs.pop('bands', None))
63+
metrics=kwargs.pop('metrics', None),
64+
bands=kwargs.pop('bands', None))
6565

6666
self.algorithm._reset_subobjects(data=self.data, results=self.results)
6767

specparam/models/group.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ def __init__(self, *args, **kwargs):
6363
self.data = Data2D()
6464

6565
self.results = Results2D(modes=self.modes,
66-
metrics=kwargs.pop('metrics', None),
67-
bands=kwargs.pop('bands', None))
66+
metrics=kwargs.pop('metrics', None),
67+
bands=kwargs.pop('bands', None))
6868

6969
self.algorithm._reset_subobjects(data=self.data, results=self.results)
7070

specparam/models/model.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
from specparam.data.data import Data
1212
from specparam.data.conversions import model_to_dataframe
1313
from specparam.results.results import Results
14+
1415
from specparam.algorithms.spectral_fit import SpectralFitAlgorithm, SPECTRAL_FIT_SETTINGS_DEF
16+
from specparam.algorithms.definitions import ALGORITHMS, check_algorithm_definition
17+
1518
from specparam.reports.save import save_model_report
1619
from specparam.reports.strings import gen_model_results_str
1720
from specparam.modutils.errors import NoDataError, FitError
@@ -34,10 +37,14 @@ class SpectralModel(BaseModel):
3437
Parameters
3538
----------
3639
% copied in from Spectral Fit Algorithm Settings
37-
aperiodic_mode : {'fixed', 'knee'}
40+
aperiodic_mode : {'fixed', 'knee'} or Mode
3841
Which approach to take for fitting the aperiodic component.
39-
periodic_mode : {'gaussian', 'skewed_gaussian', 'cauchy'}
42+
periodic_mode : {'gaussian', 'skewed_gaussian', 'cauchy'} or Mode
4043
Which approach to take for fitting the periodic component.
44+
algorithm : {'spectral_fit'} or Algorithm
45+
The fitting algorithm to use.
46+
algorithm_settings : dict
47+
Setting for the algorithm.
4148
metrics : Metrics or list of Metric or list or str
4249
Metrics definition(s) to use to evaluate the model.
4350
bands : Bands or dict or int or None, optional
@@ -49,6 +56,7 @@ class SpectralModel(BaseModel):
4956
Verbosity mode. If True, prints out warnings and general status updates.
5057
**model_kwargs
5158
Additional model fitting related keyword arguments.
59+
These are passed into the algorithm object.
5260
5361
Attributes
5462
----------
@@ -71,25 +79,21 @@ class SpectralModel(BaseModel):
7179
as this will give better model fits.
7280
"""
7381

74-
def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_height=0.0,
75-
peak_threshold=2.0, aperiodic_mode='fixed', periodic_mode='gaussian',
82+
def __init__(self, aperiodic_mode='fixed', periodic_mode='gaussian',
83+
algorithm='spectral_fit', algorithm_settings=None,
7684
metrics=None, bands=None, debug=False, verbose=True, **model_kwargs):
7785
"""Initialize model object."""
7886

79-
BaseModel.__init__(self,
80-
aperiodic_mode=aperiodic_mode,
81-
periodic_mode=periodic_mode,
82-
verbose=verbose)
87+
BaseModel.__init__(self, aperiodic_mode, periodic_mode, verbose)
8388

8489
self.data = Data()
8590

8691
self.results = Results(modes=self.modes, metrics=metrics, bands=bands)
8792

88-
self.algorithm = SpectralFitAlgorithm(
89-
peak_width_limits=peak_width_limits, max_n_peaks=max_n_peaks,
90-
min_peak_height=min_peak_height, peak_threshold=peak_threshold,
91-
modes=self.modes, data=self.data, results=self.results,
92-
debug=debug, **model_kwargs)
93+
algorithm_settings = {} if algorithm_settings is None else algorithm_settings
94+
self.algorithm = check_algorithm_definition(algorithm, ALGORITHMS)(
95+
**algorithm_settings, modes=self.modes, data=self.data,
96+
results=self.results, debug=debug, **model_kwargs)
9397

9498

9599
@replace_docstring_sections([docs_get_section(Data.add_data.__doc__, 'Parameters'),

specparam/models/time.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ def __init__(self, *args, **kwargs):
5656
self.data = Data2DT()
5757

5858
self.results = Results2DT(modes=self.modes,
59-
metrics=kwargs.pop('metrics', None),
60-
bands=kwargs.pop('bands', None))
59+
metrics=kwargs.pop('metrics', None),
60+
bands=kwargs.pop('bands', None))
6161

6262
self.algorithm._reset_subobjects(data=self.data, results=self.results)
6363

specparam/models/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,8 @@ def combine_model_objs(model_objs):
237237
"or meta data, and so cannot be combined.")
238238

239239
# Initialize group model object, with settings derived from input objects
240-
group = SpectralGroupModel(*model_objs[0].algorithm.get_settings(),
240+
group = SpectralGroupModel(**model_objs[0].modes.get_modes()._asdict(),
241+
**model_objs[0].algorithm.get_settings()._asdict(),
241242
verbose=model_objs[0].verbose)
242243

243244
# Use a temporary store to collect spectra, as we'll only add it if it is consistently present

0 commit comments

Comments
 (0)