1111from specparam .data .data import Data
1212from specparam .data .conversions import model_to_dataframe
1313from specparam .results .results import Results
14+
1415from specparam .algorithms .spectral_fit import SpectralFitAlgorithm , SPECTRAL_FIT_SETTINGS_DEF
16+ from specparam .algorithms .definitions import ALGORITHMS , check_algorithm_definition
17+
1518from specparam .reports .save import save_model_report
1619from specparam .reports .strings import gen_model_results_str
1720from specparam .modutils .errors import NoDataError , FitError
@@ -34,10 +37,14 @@ class SpectralModel(BaseModel):
3437 Parameters
3538 ----------
3639 % copied in from Spectral Fit Algorithm Settings
37- aperiodic_mode : {'fixed', 'knee'}
40+ aperiodic_mode : {'fixed', 'knee'} or Mode
3841 Which approach to take for fitting the aperiodic component.
39- periodic_mode : {'gaussian', 'skewed_gaussian', 'cauchy'}
42+ periodic_mode : {'gaussian', 'skewed_gaussian', 'cauchy'} or Mode
4043 Which approach to take for fitting the periodic component.
44+ algorithm : {'spectral_fit'} or Algorithm
45+ The fitting algorithm to use.
46+ algorithm_settings : dict
47+ Setting for the algorithm.
4148 metrics : Metrics or list of Metric or list or str
4249 Metrics definition(s) to use to evaluate the model.
4350 bands : Bands or dict or int or None, optional
@@ -49,6 +56,7 @@ class SpectralModel(BaseModel):
4956 Verbosity mode. If True, prints out warnings and general status updates.
5057 **model_kwargs
5158 Additional model fitting related keyword arguments.
59+ These are passed into the algorithm object.
5260
5361 Attributes
5462 ----------
@@ -71,25 +79,21 @@ class SpectralModel(BaseModel):
7179 as this will give better model fits.
7280 """
7381
74- def __init__ (self , peak_width_limits = ( 0.5 , 12.0 ), max_n_peaks = np . inf , min_peak_height = 0.0 ,
75- peak_threshold = 2.0 , aperiodic_mode = 'fixed ' , periodic_mode = 'gaussian' ,
82+ def __init__ (self , aperiodic_mode = 'fixed' , periodic_mode = 'gaussian' ,
83+ algorithm = 'spectral_fit ' , algorithm_settings = None ,
7684 metrics = None , bands = None , debug = False , verbose = True , ** model_kwargs ):
7785 """Initialize model object."""
7886
79- BaseModel .__init__ (self ,
80- aperiodic_mode = aperiodic_mode ,
81- periodic_mode = periodic_mode ,
82- verbose = verbose )
87+ BaseModel .__init__ (self , aperiodic_mode , periodic_mode , verbose )
8388
8489 self .data = Data ()
8590
8691 self .results = Results (modes = self .modes , metrics = metrics , bands = bands )
8792
88- self .algorithm = SpectralFitAlgorithm (
89- peak_width_limits = peak_width_limits , max_n_peaks = max_n_peaks ,
90- min_peak_height = min_peak_height , peak_threshold = peak_threshold ,
91- modes = self .modes , data = self .data , results = self .results ,
92- debug = debug , ** model_kwargs )
93+ algorithm_settings = {} if algorithm_settings is None else algorithm_settings
94+ self .algorithm = check_algorithm_definition (algorithm , ALGORITHMS )(
95+ ** algorithm_settings , modes = self .modes , data = self .data ,
96+ results = self .results , debug = debug , ** model_kwargs )
9397
9498
9599 @replace_docstring_sections ([docs_get_section (Data .add_data .__doc__ , 'Parameters' ),
0 commit comments