diff --git a/examples/customize/plot_custom_param_conversions.py b/examples/customize/plot_custom_param_conversions.py new file mode 100644 index 00000000..cb78f783 --- /dev/null +++ b/examples/customize/plot_custom_param_conversions.py @@ -0,0 +1,312 @@ +""" +Custom Parameter Conversions +============================ + +This example covers defining and using custom parameter post-fitting conversions. +""" + +from specparam import SpectralModel + +from specparam.utils.download import load_example_data + +# Import the default set of parameter conversions +from specparam.convert.definitions import check_converters, DEFAULT_CONVERTERS + +# Import objects to define parameter conversions +from specparam.convert.converter import PeriodicParamConverter, AperiodicParamConverter + +################################################################################################### +# Parameter Conversions +# --------------------- +# +# After model fitting, a model object includes the parameters for the model as defined by the +# fit modes and as arrived at by the fit algorithm. These fit parameters define the model fit, +# as visualized, for example, by the 'full model' fit, when plotting the model. +# +# However, these 'fit' parameters are not necessarily defined in a way that we actually +# want to analyzed. For this reason, spectral parameterization supports doing post-fitting +# parameter conversions, whereby after the fitting process, conversions can be applied to +# the fit parameters. +# +# Let's first explore this with an example model fit. +# + +################################################################################################### + +# Load example spectra +freqs = load_example_data('freqs.npy', folder='data') +powers = load_example_data('spectrum.npy', folder='data') + +# Define fitting fit range +freq_range = [2, 40] + +# Initialize and fit an example model +fm = SpectralModel() +fm.report(freqs, powers, freq_range) + +################################################################################################### +# +# In the above, we see the model fit, and reported parameter values. +# +# Let's further investigate the different versions of the parameters: 'fit' and 'converted'. +# + +################################################################################################### + +# Check the aperiodic fit & converted parameters +print(fm.results.get_params('aperiodic', version='fit')) +print(fm.results.get_params('aperiodic', version='converted')) + +################################################################################################### +# +# In the above, we can see that there are fit parameters, but there is no defined converted +# version of the parameters, indicating that there are no conversions defined for the +# aperiodic parameters. +# + +################################################################################################### + +# Check the periodic fit & converted parameters, for an example peak +print(fm.results.get_params('periodic', version='fit')[1, :]) +print(fm.results.get_params('periodic', version='converted')[1, :]) + +################################################################################################### +# +# In this case, there are both fit and converted versions of the parameters, +# and they are not the same! +# +# There are defined periodic parameter conversions that are being done. Note also that it is +# the converted versions of the parameters that are printed in the report above. +# + +################################################################################################### +# Default Converters +# ------------------ +# +# To see what the conversions are that are being defined, we can examine the set of +# DEFAULT_CONVERTERS, which we imported from the module. +# + +################################################################################################### + +# Check the default model fit parameters +DEFAULT_CONVERTERS + +################################################################################################### +# Change Default Converters +# ~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Next, we can explore changing which converters we use. +# +# To start with a simple example, let's turn off all parameter conversions. +# +# Note that as a shortcut, we can get a parameter definition from the Modes sub-object that +# is part of the model object, specified to return a dictionary. +# + +################################################################################################### + +# Get a dictionary representation of current parameters +null_converters = fm.modes.get_params('dict') +null_converters + +################################################################################################### + +# Initialize & fit a new model with null converters +fm1 = SpectralModel(converters=null_converters) +fm1.report(freqs, powers, freq_range) + +################################################################################################### +# +# In the above no parameter conversions were applied! +# + +################################################################################################### + +# Check that there are no converted parameters - should all be nan +print(fm1.results.get_params('aperiodic', version='converted')) +print(fm1.results.get_params('periodic', version='converted')) + +################################################################################################### +# +# Next, we can explore specifying to use different built in parameter conversions. +# +# To do so, we can explore the available options with the +# :func:`~specparam.convert.definitions.check_converters` function. +# + +################################################################################################### + +# Check the available aperiodic parameter converters +check_converters('aperiodic') + +################################################################################################### + +# Check the available periodic parameter converters +check_converters('periodic') + +################################################################################################### +# +# Now we can select different conversions from these options. +# + +################################################################################################### + +# Take a copy of the null converters dictionary +selected_converters = null_converters.copy() + +# Specify a different +selected_converters['periodic']['pw'] = 'lin_sub' + +################################################################################################### + +# Initialize & fit a new model with selected converters +fm2 = SpectralModel(converters=selected_converters) +fm2.report(freqs, powers, freq_range) + +################################################################################################### +# +# In the above, the converted and reported parameter outputs used the specified conversions! +# + +################################################################################################### +# Create Custom Converters +# ------------------------ +# +# Finally, let's explore defining some custom parameter conversions. +# +# To do so, for any parameter that we wish to define a conversion for, we can define a +# callable that implements our desired conversion. +# +# In order for specparam to be able to use the callable, they must follow properties: +# +# - for aperiodic component conversions : callable should accept inputs `fit_value` and `model` +# - for periodic component conversions: callable should accept inputs `fit_value`, `model`, and `peak_ind` +# + +################################################################################################### + +# Take a copy of the null converters dictionary +custom_converters = null_converters.copy() + +################################################################################################### +# +# To start with, let's define a simple conversion for the aperiodic exponent to convert the +# fit value into the equivalent spectral slope value (the negative of the exponent value). +# +# To define this simple conversion we can even use a lambda function. +# + +################################################################################################### + +# Create a custom exponent converter as a lambda function +custom_converters['aperiodic']['exponent'] = lambda param, model : -param + +################################################################################################### +# +# Let's also define a conversion for a periodic parameter. As an example, we can define a +# conversion of the fit center frequency value that finds and update to the closest frequency +# value that actually occurs in the frequency definition. For this case, we will implement +# conversion function. +# + +################################################################################################### + +# Import utility function to find nearest index +from specparam.utils.select import nearest_ind + +# Define a function to update the center frequency +def update_cf(fit_value, model, peak_ind): + """Updates center frequency to be closest existing frequency value.""" + + f_ind = nearest_ind(model.data.freqs, fit_value) + new_cf = model.data.freqs[f_ind] + + return new_cf + +################################################################################################### + +# Add the custom cf converter function to function collection +custom_converters['periodic']['cf'] = update_cf + +################################################################################################### +# +# Now we have defined our custom converters, we can use them in the fitting process! +# + +################################################################################################### + +# Initialize & fit a new model with custom converters +fm3 = SpectralModel(converters=custom_converters) +fm3.report(freqs, powers, freq_range) + +################################################################################################### +# +# In the above report, our custom parameter conversions were used. +# + +################################################################################################### +# Parameter Converter Objects +# --------------------------- +# +# In the above, we defined custom parameter converters by directly passing in callables that +# implement our desired conversions. As we've seen above, this works to pass in conversions +# +# However, only passing in the callable is a bit light on details and description. If you +# want to implement parameter conversions using an approach that keeps track of additional +# description of the approach, you can use the +# :class:`~specparam.convert.converter.AperiodicParamConverter` and +# :class:`~specparam.convert.converter.PeriodicParamConverter` objects to +# + +################################################################################################### + +# Define the exponent to slope conversion as a converter object +exp_slope_converter = AperiodicParamConverter( + parameter='exponent', + name='slope', + description='Convert the fit exponent value to the equivalent spectral slope value.', + function=lambda param, model : -param, +) + +# Define the center frequency fixed frequency converter as a converter object +cf_fixed_freq_converter = PeriodicParamConverter( + parameter='cf', + name='fixed_freq', + description='Convert the fit center frequency value to a fixed frequency value.', + function=update_cf, +) + +################################################################################################### + +# Take a new copy of the null converters dictionary & add +custom_converters2 = null_converters.copy() +custom_converters['aperiodic']['exponent'] = exp_slope_converter +custom_converters2['periodic']['cf'] = cf_fixed_freq_converter + +################################################################################################### +# +# Same as before, we can now use our custom converter definitions in the model fitting process. +# + +################################################################################################### + +# Initialize & fit a new model with custom converters +fm4 = SpectralModel(converters=custom_converters2) +fm4.report(freqs, powers, freq_range) + +################################################################################################### +# Adding New Parameter Conversions to the Module +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# As a final note, if you look into the set of 'built-in' parameter conversions that are +# available within the module, you will see that these are defined in the same way as done here, +# using the conversion objects introduced above. The only difference is that they are defined +# within the module and therefore can be accessed via their name, as a shortcut, +# rather than the user having to pass in their own full definitions. +# +# This also means that if you have a custom parameter conversion that you think would be of +# interest to other specparam users, once the ParamConverter object is defined it is quite +# easy to add this to the module as a new default option. If you would be interested in +# suggesting a mode be added to the module, feel free to open an issue and/or pull request. +# diff --git a/examples/customize/plot_sub_objects.py b/examples/customize/plot_sub_objects.py index 1d0360e8..35c18742 100644 --- a/examples/customize/plot_sub_objects.py +++ b/examples/customize/plot_sub_objects.py @@ -259,7 +259,7 @@ def print_public_api(obj): ################################################################################################### # Initialize a base model, passing in empty mode definitions -base = BaseModel(None, None, False) +base = BaseModel(None, None, None, False) # Check the API of the object print_public_api(base) diff --git a/examples/models/plot_model_component_relationships.py b/examples/models/plot_model_component_relationships.py new file mode 100644 index 00000000..b3580b1f --- /dev/null +++ b/examples/models/plot_model_component_relationships.py @@ -0,0 +1,316 @@ +""" +Component Combinations +====================== + +Explore different approaches combining model components. +""" + +# sphinx_gallery_thumbnail_number = 3 + +import numpy as np +import matplotlib.pyplot as plt + +from specparam import SpectralModel + +from specparam.plts import plot_spectra + +from specparam.utils.array import unlog +from specparam.utils.select import nearest_ind +from specparam.utils.download import load_example_data + +# Import function to directly compute peak heights +from specparam.convert.params import compute_peak_height + +# Import the default parameter conversions +from specparam.modes.convert import DEFAULT_CONVERTERS + +# sphinx_gallery_start_ignore +from specparam.plts.utils import check_ax + +def plot_peak_height(model, peak_ind, spacing, operation, ax=None): + """Annotate plot by drawing the computed peak height.""" + + # Get the frequency value of the data closest to the specified peak + f_ind = nearest_ind(model.data.freqs, + model.results.params.periodic.params[peak_ind, 0]) + + # Plot the power spectrum + ax = check_ax(ax) + title = 'Peak Height: {:s}_{:s}'.format(spacing[0:3], operation[0:3]) + plot_spectra(freqs, powers, log_powers=spacing=='log', + color='black', title=title, ax=ax) + + # Add dot marker at the peak frequency index, at the aperiodic component power value + ax.plot([model.data.freqs[f_ind]], + [model.results.model.get_component('aperiodic', spacing)[f_ind]], + '.', ms=12, color='blue') + + # Add dot marker at the peak frequency index, at the peak top (combined) power value + ax.plot([model.data.freqs[f_ind]], + [model.results.model.get_component('full', spacing)[f_ind]], + '.', ms=12, color='red') + + # Draw the line for the computed peak height, based on provided spacing and operation + ax.plot([model.data.freqs[f_ind], model.data.freqs[f_ind]], + [model.results.model.get_component('aperiodic', spacing)[f_ind], + model.results.model.get_component('aperiodic', spacing)[f_ind] + \ + compute_peak_height(fm, peak_ind, spacing, operation)], + color='green', lw=2) +# sphinx_gallery_end_ignore + +################################################################################################### +# Introduction +# ------------ +# +# In general, the approach taken for doing spectral parameterization considers the power +# spectrum to be a combination of multiple components. Notably, however, there is more than +# one possible way to combine the components, for example, components could be added +# together, or multiplied, etc. +# +# An additional complication is that the power values of power spectra are often examined +# in log-power spacing. This is important as whether the implications of how the model +# components are combined also depends on the spacing of the data. To explore this, we +# will first start with some brief notes on logging, and then explore how this all +# relates to model component combinations and related measures, such as peak heights. +# + +################################################################################################### + +# Load example spectra - using real data here +freqs = load_example_data('freqs.npy', folder='data') +powers = load_example_data('spectrum.npy', folder='data') + +# Define frequency range for model fitting +freq_range = [2, 40] + +################################################################################################### +# Some Notes on Logging +# --------------------- +# +# In order to explore the implications of how the different components are combined, we will first +# briefly revisit some rules for how logs work in mathematics. +# +# Specifically, the relationship between adding & subtracting log values, and how this relates +# to equivalent operations in linear space, whereby the rules are: +# +# - log(x) + log(y) = log(x * y) +# - log(x) - log(y) = log(x / y) +# +# When working in log space, the addition or subtraction of two log spaced values is +# equivalent to the log of the multiplication or division of those values. +# +# Relatedly, we could note some properties that don't hold in log space, such as: +# +# - log(a) + log(y) != log(x + y) +# - log(a) - log(y) != log(x - y) +# +# Collectively, what this means is that the addition or subtraction of log values, +# is not equivalent of doing addition of subtraction of the linear values. +# + +################################################################################################### + +# Sum of log values is equivalent to the log of the product +assert np.log10(1.5) + np.log10(1.5) == np.log10(1.5 * 1.5) + +# Sum of log values is not equivalent to the log of sum +assert np.log10(1.5) + np.log10(1.5) != np.log10(1.5 + 1.5) + +################################################################################################### +# So, why do we use logs? +# ~~~~~~~~~~~~~~~~~~~~~~~ +# +# Given this, it is perhaps worth a brief interlude to consider why we so often use log +# transforms when examining power spectra. One reason is simply that power values are +# extremely skewed, with huge differences in the measured power values between, for example, +# low frequencies and high frequencies and/or between the peak of an oscillation peak and the +# power values for surrounding frequencies. +# +# This is why for visualizations and/or statistical analyses, working in log space can be +# useful and convenient. However, when doing so, it's important to keep in mind the implications +# of doing so, since it can otherwise be easy to think about properties and transformations +# in linear space, and end up with incorrect conclusions. For example, when adding or subtracting +# from power spectra in log space and/or when comparing power values, such as between different +# peaks, we need to remember the implications of log spacing. +# + +################################################################################################### + +# Plot a power spectrum in both linear-linear and log-linear space +_, axes = plt.subplots(1, 2, figsize=(12, 6)) +plot_spectra(freqs, powers, log_powers=False, label='Linear Power', ax=axes[0]) +plot_spectra(freqs, powers, log_powers=True, label='Log Power', ax=axes[1]) + +################################################################################################### +# +# In the above linear-linear power spectrum plot, we can see the skewed nature of the power +# values, including the steepness of the decay of the 1/f-like nature of the spectrum, and +# the degree to which peaks of power, such as the alpha peak here, can be many times higher +# power than other frequencies. +# + +################################################################################################### +# Model Component Combinations +# ---------------------------- +# +# Having explored typical representations of neural power spectra, and some notes on logging, +# let's come back to the main topic of model component combinations. +# +# Broadly,when considering how the components relate to each other, in terms of how they are +# combined to create the full model fit, we can start with considering two key aspects: +# +# - the operation, e.g. additive or multiplicative +# - the spacing of the data, e.g. linear or log +# +# Notably, as seen above there is an interaction between these choices that needs to be considered. +# + +################################################################################################### + +# Initialize and fit an example model +fm = SpectralModel(verbose=False) +fm.fit(freqs, powers, [2, 40]) + +# Plot the model fit, with peak annotations +fm.plot(plot_peaks='dot') + +################################################################################################### +# +# To compute different possible versions of the peak height, we can use the +# :func:`~.compute_peak_height` function. Using this function, we can compute measures of +# the peak height, specifying different data representations and difference measures. +# + +################################################################################################### + +# Define which peak ind to compute height for +peak_ind = 0 + +# Compute 4 different measures of the peak height +peak_heights = { + 'log_sub' : compute_peak_height(fm, peak_ind, 'log', 'subtract'), + 'log_div' : compute_peak_height(fm, peak_ind, 'log', 'divide'), + 'lin_sub' : compute_peak_height(fm, peak_ind, 'linear', 'subtract'), + 'lin_div' : compute_peak_height(fm, peak_ind, 'linear', 'divide'), +} + +################################################################################################### + +# Check computing difference / division measures +print('log sub : {:+08.4f}'.format(peak_heights['log_sub'])) +print('log div : {:+08.4f}'.format(peak_heights['log_div'])) +print('lin sub : {:+08.4f}'.format(peak_heights['lin_sub'])) +print('lin div : {:+08.4f}'.format(peak_heights['lin_div'])) + +################################################################################################### +# +# As expected, we can see that the four different combinations of spacing and operation can +# lead to 4 different answers for the peak height. +# +# We can also go one step further, and examine (un)logging the results, to explore if +# changing the spacing of the computed results aligns with any of the original calculations. +# + +################################################################################################### + +# Check logging / unlogging measures: un-logged log sub is same as linear division +print('Unlog log sub : {:+08.4f}'.format(unlog(peak_heights['log_sub']))) + +################################################################################################### + +# Check logging / unlogging measures: logged linear-division is the same as log subtraction +print('Log of lin div : {:+08.4f}'.format(np.log10(peak_heights['lin_div']))) + +################################################################################################### +# +# In the above examples we see that changing the spacing of some results does line up with +# some of the previously computed estimates. As expected based on the log rules, unlogging +# the log-subtraction is equivalent to the linear division, and (vice-versa) logging the +# linear division is equivalent to the log-subtraction. +# +# This also means that you cannot convert directly between spacing keeping the same operation, +# for example, you cannot convert to the linear-subtraction result by unlogging +# the log-subtraction result. +# +# To summarize: +# +# - log / linear and difference / division all give difference values +# - unlogging the log difference is the same as the linear division +# - unlogging the log difference does NOT give the linear difference +# +# - logging the linear division is the same as the log difference +# - logging the linear difference does NOT give the log difference +# +# +# Note that this is all standard log operations, the point here is to evaluate these +# different estimates in the context of spectral parameterization, so that we can next +# discuss when to select and use these different estimates. +# + +################################################################################################### + +# Visualize log vs linear peak height estimates +_, axes = plt.subplots(1, 2, figsize=(12, 6)) +plot_peak_height(fm, peak_ind, 'linear', 'subtract', ax=axes[1]) +plot_peak_height(fm, peak_ind, 'log', 'subtract', ax=axes[0]) + +################################################################################################### +# Additive vs. Multiplicative Component Combinations +# -------------------------------------------------- +# +# Given these different possible measures of the peak height, the natural next question +# is which is 'correct' or 'best'. +# +# The short answer is that there is not a singular definitive answer. Depending on +# one's goals and assumptions about the data, there may be better answers for particular +# use cases. The different measures make different assumptions about the generative model +# of the data under study. If we had a definitive model of the underlying generators of +# the different data components, and a clear understanding of they related to each other, +# then we could use that information to decide exactly how to proceed. +# +# However, for the case of neuro-electrophysiological recordings, there is not a definitively +# established generative model for the data, and as such, no singular or definitive answer +# to how best to model the data. +# +# For any individual project / analysis, one can choose the approach that best fits the +# assumed generative model of the data. For example, if one wishes to examine the data +# based on a linearly-additive model, then the linear-subtraction of components matches this, +# whereas if one wants to specify a linearly multiplicative model (equivalent to subtraction +# in log space, and the kind of model assumed by filtered noise processes), then the +# linear-division approach the the way to go. +# +# Within specparam, you can specify the approach to take for converting parameters post +# model fitting, which can be used to re-compute peak heights based on the desired model. +# For more discussion of this, see other documentation sections on choosing and defining +# parameter conversions. +# + +################################################################################################### + +# Initialize model objects, specifying different peak height parameter conversions +fm_log_sub = SpectralModel(converters={'periodic' : {'pw' : 'log_sub'}}, verbose=False) +fm_lin_sub = SpectralModel(converters={'periodic' : {'pw' : 'lin_sub'}}, verbose=False) + +# Fit the models to the data +fm_log_sub.fit(freqs, powers, freq_range) +fm_lin_sub.fit(freqs, powers, freq_range) + +# Check the resulting parameters, with different peak height values +print(fm_log_sub.results.get_params('periodic')) +print(fm_lin_sub.results.get_params('periodic')) + +################################################################################################### +# Does it matter which form I choose? +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# In the above, we have shown that choosing the peak height estimations does lead to different +# computed values. However, in most analyses, it is not the absolute values or absolute +# differences of these measures that is of interest, but their relative differences. +# +# Broadly speaking, a likely rule of thumb is that within the spectral parameterization +# approach, switching the model combination definition is generally unlikely to change the +# general pattern of things (in terms of which parameters change). However it could well +# change effect size measures (and as such, potentially the results of significant tests), +# such that it is plausible that the results of different model combination forms could +# be at least somewhat different. +# diff --git a/specparam/algorithms/spectral_fit.py b/specparam/algorithms/spectral_fit.py index fb95e68d..328f0f33 100644 --- a/specparam/algorithms/spectral_fit.py +++ b/specparam/algorithms/spectral_fit.py @@ -186,12 +186,6 @@ def _fit(self): self.results.model.modeled_spectrum = \ self.results.model._peak_fit + self.results.model._ap_fit - ## PARAMETER UPDATES - - # Convert fit peak parameters to updated values - self.results.params.periodic.add_params('converted', \ - self._create_peak_params(self.results.params.periodic.get_params('fit'))) - def _get_ap_guess(self, freqs, power_spectrum): """Get the guess parameters for the aperiodic fit. @@ -603,52 +597,3 @@ def _drop_peak_overlap(self, guess): guess = np.array([gu for (gu, keep) in zip(guess, keep_peak) if keep]) return guess - - - def _create_peak_params(self, fit_peak_params): - """Copies over the fit peak parameters output parameters, updating as appropriate. - - Parameters - ---------- - fit_peak_params : 2d array - Parameters that define the peak parameters directly fit to the spectrum. - - Returns - ------- - peak_params : 2d array - Updated parameter values for the peaks. - - Notes - ----- - The center frequency estimate is unchanged as the peak center frequency. - - The peak height is updated to reflect the height of the peak above - the aperiodic fit. This is returned instead of the fit peak height, as - the fit height is harder to interpret, due to peak overlaps. - - The peak bandwidth is updated to be 'both-sided', to reflect the overal width - of the peak, as opposed to the fit parameter, which is 1-sided standard deviation. - - Performing this conversion requires that the model has been run, - with `freqs`, `modeled_spectrum` and `_ap_fit` all required to be available. - """ - - inds = self.modes.periodic.params.indices - - peak_params = np.empty((len(fit_peak_params), self.modes.periodic.n_params)) - - for ii, peak in enumerate(fit_peak_params): - - cpeak = peak.copy() - - # Gets the index of the power_spectrum at the frequency closest to the CF of the peak - cf_ind = np.argmin(np.abs(self.data.freqs - peak[inds['cf']])) - cpeak[inds['pw']] = \ - self.results.model.modeled_spectrum[cf_ind] - self.results.model._ap_fit[cf_ind] - - # Bandwidth is updated to be 'two-sided' (as opposed to one-sided std dev) - cpeak[inds['bw']] = peak[inds['bw']] * 2 - - peak_params[ii] = cpeak - - return peak_params diff --git a/specparam/convert/__init__.py b/specparam/convert/__init__.py new file mode 100644 index 00000000..ac31b664 --- /dev/null +++ b/specparam/convert/__init__.py @@ -0,0 +1 @@ +"""Sub-module for functionality related to parameter conversions.""" diff --git a/specparam/convert/convert.py b/specparam/convert/convert.py new file mode 100644 index 00000000..bb71c151 --- /dev/null +++ b/specparam/convert/convert.py @@ -0,0 +1,73 @@ +"""Parameter converters. + +Notes +----- +Parameter converters should have the following properties, depending on component: +- for 'aperiodic' parameters : callable, takes 'fit_value' & 'model' as inputs +- for 'peak' parameters : callable, takes 'fit_value' & 'model', 'peak_ind' as inputs +""" + +import numpy as np + +from specparam.convert.definitions import get_converter + +################################################################################################### +################################################################################################### + +def convert_aperiodic_params(model, updates): + """Convert aperiodic parameters. + + Parameters + ---------- + model : SpectralModel + Model object, post model fitting. + updates : dict + Dictionary specifying the parameter conversions to do, whereby: + Each key is the name of a parameter. + Each value reflects what conversion to do. + This can be a string label for a built-in conversion, or a custom implementation. + + Returns + ------- + converted_parameters : 1d array + Converted aperiodic parameters. + """ + + converted_params = np.zeros_like(model.results.params.aperiodic._fit) + for param, p_ind in model.modes.aperiodic.params.indices.items(): + converter = get_converter('aperiodic', param, updates[param]) + fit_value = model.results.params.aperiodic._fit[\ + model.modes.aperiodic.params.indices[param]] + converted_params[p_ind] = converter(fit_value, model) + + return converted_params + + +def convert_periodic_params(model, updates): + """Convert periodic parameters. + + Parameters + ---------- + model : SpectralModel + Model object, post model fitting. + updates : dict + Dictionary specifying the parameter conversions to do, whereby: + Each key is the name of a parameter. + Each value reflects what conversion to do. + This can be a string label for a built-in conversion, or a custom implementation. + + Returns + ------- + converted_parameters : array + Converted periodic parameters. + """ + + converted_params = np.zeros_like(model.results.params.periodic._fit) + for peak_ind in range(len(converted_params)): + for param, param_ind in model.modes.periodic.params.indices.items(): + converter = get_converter('periodic', param, updates.get(param, None)) + fit_value = model.results.params.periodic._fit[\ + peak_ind, model.modes.periodic.params.indices[param]] + converted_params[peak_ind, param_ind] = converter(fit_value, model, peak_ind) + + return converted_params diff --git a/specparam/convert/converter.py b/specparam/convert/converter.py new file mode 100644 index 00000000..559ab7bb --- /dev/null +++ b/specparam/convert/converter.py @@ -0,0 +1,78 @@ +"""Parameter converter objects.""" + +################################################################################################### +################################################################################################### + +class BaseParamConverter(): + """General class for parameter converters - to be inherited by component specific converter. + + Parameters + ---------- + component : {'aperiodic', 'periodic'}, + Which component the converter relates to. + parameter : str + Label of the parameter the converter is for. + name : str + Name of the parameter converter. + description : str + Description of the parameter converter. + function : callable + Function that implements the parameter conversion. + """ + + def __init__(self, component, parameter, name, description, function): + """Initialize a parameter converter.""" + + self.component = component + self.parameter = parameter + self.name = name + self.description = description + self.function = function + + +class AperiodicParamConverter(BaseParamConverter): + """Parameter converter for aperiodic parameters.""" + + def __init__(self, parameter, name, description, function): + """Initialize an aperiodic parameter converter.""" + + super().__init__('aperiodic', parameter, name, description, function) + + + def __call__(self, fit_value, model): + """Call the aperiodic parameter converter. + + Parameters + ---------- + fit_value : float + Fit value for the parameter. + model : SpectralModel + Model object. + """ + + return self.function(fit_value, model) + + +class PeriodicParamConverter(BaseParamConverter): + """Parameter converter for periodic parameters.""" + + def __init__(self, parameter, name, description, function): + """Initialize a periodic parameter converter.""" + + super().__init__('periodic', parameter, name, description, function) + + + def __call__(self, fit_value, model, peak_ind): + """Call the peak parameter converter. + + Parameters + ---------- + fit_value : float + Fit value for the parameter. + model : SpectralModel + Model object. + peak_ind : int + Index of the current peak. + """ + + return self.function(fit_value, model, peak_ind) diff --git a/specparam/convert/definitions.py b/specparam/convert/definitions.py new file mode 100644 index 00000000..71d43832 --- /dev/null +++ b/specparam/convert/definitions.py @@ -0,0 +1,194 @@ +"""Define parameter converters.""" + +from copy import deepcopy + +from specparam.convert.converter import AperiodicParamConverter, PeriodicParamConverter +from specparam.convert.params import compute_peak_height + +################################################################################################### +## DEFINE DEFAULT CONVERTERS + +DEFAULT_CONVERTERS = { + 'aperiodic' : {'offset' : None, 'exponent' : None}, + 'periodic' : {'cf' : None, 'pw' : 'log_sub', 'bw' : 'full_width'}, +} + + +def update_converters(defaults, updates): + """Update default converters. + + Parameters + ---------- + defaults : dict + Default converters. + updates : dict + Converter definitions to update. + + Returns + ------- + converters : dict + Updated converters definition. + """ + + out = deepcopy(defaults) + for component, converters in updates.items(): + for param, converter in converters.items(): + out[component][param] = converter + + return out + +################################################################################################### +## APERIODIC PARAMETER CONVERTERS + +## AP - Null converter +ap_null = AperiodicParamConverter( + parameter=None, + name='ap_null', + description='Null converter for aperiodic converter - return fit parameter value.', + function=lambda fit_value, model : fit_value, +) + +################################################################################################### +## PERIODIC PARAMETER CONVERTERS + +## PE - Null converter +pe_null = PeriodicParamConverter( + parameter=None, + name='pe_null', + description='Null converter for aperiodic converter - return fit parameter value.', + function=lambda fit_value, model, peak_ind : fit_value, +) + +## PE - PW + +pw_log_sub = PeriodicParamConverter( + parameter='pw', + name='log_sub', + description='Convert peak height to be the log subtraction '\ + 'of full model and aperiodic component.', + function=lambda fit_value, model, peak_ind : \ + compute_peak_height(model, peak_ind, 'log', 'subtract'), +) + +pw_log_div = PeriodicParamConverter( + parameter='pw', + name='log_div', + description='Convert peak height to be the log division '\ + 'of full model and aperiodic component.', + function=lambda fit_value, model, peak_ind : \ + compute_peak_height(model, peak_ind, 'log', 'divide'), +) + +pw_lin_sub = PeriodicParamConverter( + parameter='pw', + name='lin_sub', + description='Convert peak height to be the linear subtraction '\ + 'of full model and aperiodic component.', + function=lambda fit_value, model, peak_ind : \ + compute_peak_height(model, peak_ind, 'linear', 'subtract'), +) + +pw_lin_div = PeriodicParamConverter( + parameter='pw', + name='lin_div', + description='Convert peak height to be the linear division '\ + 'of full model and aperiodic component.', + function=lambda fit_value, model, peak_ind : \ + compute_peak_height(model, peak_ind, 'linear', 'divide'), +) + +## PE - BW + +bw_full_width = PeriodicParamConverter( + parameter='bw', + name='full_width', + description='Convert peak bandwidth to be the full, '\ + 'two-sided bandwidth of the peak.', + function=lambda fit_value, model, peak_ind : 2 * fit_value, +) + +################################################################################################### +## COLLECT ALL CONVERTERS + +# Null converters: extract the fit parameter, with no conversion applied +NULL_CONVERTERS = { + 'aperiodic' : ap_null, + 'periodic' : pe_null, +} + +# Collect converters by component & by paramter +CONVERTERS = { + + 'aperiodic' : { + 'offset' : {}, + 'exponent' : {}, + }, + + 'periodic' : { + 'cf' : {}, + 'pw' : { + 'log_sub' : pw_log_sub, + 'log_div' : pw_log_div, + 'lin_sub' : pw_lin_sub, + 'lin_div' : pw_lin_div, + }, + 'bw' : { + 'full_width' : bw_full_width, + }, + } +} + +################################################################################################### +## SELECTOR & CHECKER FUNCTIONS + +def get_converter(component, parameter, converter): + """Get a specified parameter converter function. + + Parameters + ---------- + component : {'aperiodic', 'periodic'} + Which component to access a converter for. + parameter : str + The name of the parameter to access a converter for. + converter : str or callable + The converter to access. + If str, should correspond to a built-in converter. + If callable, should be a custom converter definition, following framework. + + Returns + ------- + converter : callable + Function to compute parameter conversion. + + Notes + ----- + This function accesses predefined converters from `CONVERTERS`. + If a callable, as a custom definition, is passed in, the same callable is returned. + If the parameter or converter name is not found, a null converter + (from `NULL_CONVERTERS`) is returned. + """ + + if isinstance(converter, str) and converter in CONVERTERS[component][parameter]: + converter = CONVERTERS[component][parameter][converter] + elif callable(converter): + pass + else: + converter = NULL_CONVERTERS[component] + + return converter + + +def check_converters(component): + """Check the set of parameter converters that are available. + + Parameters + ---------- + component : {'aperiodic', 'periodic'} + Which component to check available parameter converters for. + """ + + print('Available {:s} converters:'.format(component)) + for param, convs in CONVERTERS[component].items(): + print(param) + for label, converter in convs.items(): + print(' {:10s} {:s}'.format(converter.name, converter.description)) diff --git a/specparam/convert/params.py b/specparam/convert/params.py new file mode 100644 index 00000000..3e4ad7f8 --- /dev/null +++ b/specparam/convert/params.py @@ -0,0 +1,43 @@ +"""Conversion functions for specific parameters.""" + +import numpy as np + +from specparam.utils.select import nearest_ind + +################################################################################################### +################################################################################################### + +## PARAMETER CONVERTERS + +PEAK_HEIGHT_OPERATIONS = { + 'subtract' : np.subtract, + 'divide' : np.divide, +} + +def compute_peak_height(model, peak_ind, spacing, operation): + """Compute peak heights, based on specified approach & spacing. + + Parameters + ---------- + model : SpectralModel + Model object, post fitting. + peak_ind : int + Index of which peak to compute height for. + spacing : {'log', 'linear'} + Spacing to extract the data components in. + operation : {'subtract', 'divide'} + Approach to take to compute the peak height measure. + + Returns + ------- + peak_height : float + Computed peak height. + """ + + ind = nearest_ind(model.data.freqs, model.results.params.periodic._fit[\ + peak_ind, model.modes.periodic.params.indices['cf']]) + peak_height = PEAK_HEIGHT_OPERATIONS[operation](\ + model.results.model.get_component('full', spacing)[ind], + model.results.model.get_component('aperiodic', spacing)[ind]) + + return peak_height diff --git a/specparam/models/base.py b/specparam/models/base.py index d81ca4ad..0981e2a2 100644 --- a/specparam/models/base.py +++ b/specparam/models/base.py @@ -31,10 +31,11 @@ class BaseModel(): Verbosity status. """ - def __init__(self, aperiodic_mode, periodic_mode, verbose): + def __init__(self, aperiodic_mode, periodic_mode, converters, verbose): """Initialize object.""" self.add_modes(aperiodic_mode, periodic_mode) + self._converters = converters self.verbose = verbose diff --git a/specparam/models/model.py b/specparam/models/model.py index 78875310..dd8ea8e7 100644 --- a/specparam/models/model.py +++ b/specparam/models/model.py @@ -12,6 +12,9 @@ from specparam.data.conversions import model_to_dataframe from specparam.results.results import Results +from specparam.convert.convert import convert_aperiodic_params, convert_periodic_params +from specparam.convert.definitions import update_converters, DEFAULT_CONVERTERS + from specparam.algorithms.spectral_fit import SpectralFitAlgorithm, SPECTRAL_FIT_SETTINGS_DEF from specparam.algorithms.definitions import ALGORITHMS, check_algorithm_definition @@ -20,6 +23,7 @@ from specparam.modutils.errors import NoDataError, FitError from specparam.modutils.docs import (copy_doc_func_to_method, replace_docstring_sections, docs_get_section) +from specparam.utils.checks import check_all_none from specparam.io.files import load_json from specparam.io.models import save_model from specparam.plts.model import plot_model @@ -47,6 +51,8 @@ class SpectralModel(BaseModel): Setting for the algorithm. metrics : Metrics or list of Metric or list or str Metrics definition(s) to use to evaluate the model. + converters : dict + Definition for parameter conversions to apply post fitting. bands : Bands or dict or int or None, optional Bands object with band definitions, or definition that can be turned into a Bands object. debug : bool, optional, default: False @@ -81,10 +87,13 @@ class SpectralModel(BaseModel): def __init__(self, aperiodic_mode='fixed', periodic_mode='gaussian', algorithm='spectral_fit', algorithm_settings=None, - metrics=None, bands=None, debug=False, verbose=True, **model_kwargs): + metrics=None, converters=None, bands=None, + debug=False, verbose=True, **model_kwargs): """Initialize model object.""" - BaseModel.__init__(self, aperiodic_mode, periodic_mode, verbose) + converters = DEFAULT_CONVERTERS if not converters else \ + update_converters(DEFAULT_CONVERTERS, converters) + BaseModel.__init__(self, aperiodic_mode, periodic_mode, converters, verbose) self.data = Data() @@ -172,6 +181,9 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None, prechecks=True): # Call the fit function from the algorithm object self.algorithm._fit() + # Do any parameter conversions + self._convert_params() + # Compute post-fit metrics self.results.metrics.compute_metrics(self.data, self.results) @@ -333,6 +345,17 @@ def to_df(self, bands=None): return model_to_dataframe(self.results.get_results(), self.modes, bands) + def _convert_params(self): + """Convert fit parameters.""" + + if not check_all_none(self._converters['aperiodic'].values()): + self.results.params.aperiodic.add_params(\ + 'converted', convert_aperiodic_params(self, self._converters['aperiodic'])) + if not check_all_none(self._converters['periodic'].values()): + self.results.params.periodic.add_params(\ + 'converted', convert_periodic_params(self, self._converters['periodic'])) + + def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_results=False): """Set, or reset, data & results attributes to empty. diff --git a/specparam/modes/definitions.py b/specparam/modes/definitions.py index 4f7c540f..7d5e7036 100644 --- a/specparam/modes/definitions.py +++ b/specparam/modes/definitions.py @@ -166,6 +166,8 @@ 'periodic' : PE_MODES, } +################################################################################################### +## CHECKER FUNCTION def check_modes(component, check_params=False): """Check the set of modes that are available. diff --git a/specparam/modes/modes.py b/specparam/modes/modes.py index bb714459..82ef2281 100644 --- a/specparam/modes/modes.py +++ b/specparam/modes/modes.py @@ -52,6 +52,33 @@ def get_modes(self): periodic_mode=self.periodic.name if self.periodic else None) + def get_params(self, param_type='list'): + """Get a description of the parameters, across modes. + + Parameters + ---------- + param_type : {'list', 'dict'} + The output type for the parameters. + + Returns + ------- + params : dict + Parameter definition for the set of modes. + Each key is a component label. + Each set of values if the parameters, with type specified by 'param_type'. + """ + + params = {} + for component in self.components: + params[component] = getattr(self, component).params.labels + + if param_type == 'dict': + params = {component : {param : None for param in params[component]} \ + for component in params.keys()} + + return params + + def print(self, description=False, concise=False): """Print out the current fit modes. diff --git a/specparam/results/params.py b/specparam/results/params.py index 811a1e73..9c9731f5 100644 --- a/specparam/results/params.py +++ b/specparam/results/params.py @@ -134,7 +134,7 @@ def params(self): Notes ----- - If available, this return converted parameters. If not, this returns fit parameters. + If available, this returns converted parameters. If not, this returns fit parameters. """ return self.get_params('converted' if self.has_converted else 'fit') @@ -174,18 +174,6 @@ def add_params(self, version, params): self._converted = params - def convert_params(self, converter): - """Convert fit parameters to converted versions and store in the object. - - Parameters - ---------- - converter : func - Callable that takes in fit parameters and returns converted version. - """ - - self.add_params('converted', converter(self.get_params('fit'))) - - def get_params(self, version, field=None): """Get parameter values from the object. diff --git a/specparam/tests/algorithms/test_spectral_fit.py b/specparam/tests/algorithms/test_spectral_fit.py index 313d2895..ac783526 100644 --- a/specparam/tests/algorithms/test_spectral_fit.py +++ b/specparam/tests/algorithms/test_spectral_fit.py @@ -17,8 +17,8 @@ def test_algorithm_inherit(tfm): class TestAlgo(BaseModel): def __init__(self): - BaseModel.__init__(self, aperiodic_mode='fixed', - periodic_mode='gaussian', verbose=False) + BaseModel.__init__(self, aperiodic_mode='fixed', periodic_mode='gaussian', + converters=None, verbose=False) self.data = Data() self.add_data = self.data.add_data self.results = Results(modes=self.modes) diff --git a/specparam/tests/convert/__init__.py b/specparam/tests/convert/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/specparam/tests/convert/test_convert.py b/specparam/tests/convert/test_convert.py new file mode 100644 index 00000000..35649c5f --- /dev/null +++ b/specparam/tests/convert/test_convert.py @@ -0,0 +1,30 @@ +"""Test functions for specparam.convert.convert.""" + +from specparam.convert.convert import * + +################################################################################################### +################################################################################################### + +def test_convert_aperiodic_params(tfm): + + # Take copy to not change test object + ntfm = tfm.copy() + + converted = convert_aperiodic_params(ntfm, + {'offset' : None, 'exponent' : lambda fit_value, model : 1.}) + assert converted[ntfm.modes.aperiodic.params.indices['offset']] == \ + ntfm.results.get_params('aperiodic', 'offset', 'fit') + assert converted[ntfm.modes.aperiodic.params.indices['exponent']] == 1. + +def test_convert_periodic_params(tfm): + + # Take copy to not change test object + ntfm = tfm.copy() + + converted = convert_periodic_params(ntfm, + {'cf' : None, 'pw' : lambda fit_value, model, peak_ind : 1., 'bw' : None}) + assert np.array_equal(converted[:, ntfm.modes.periodic.params.indices['pw']], + np.array([1.] * ntfm.results.n_peaks)) + for param in ['cf', 'bw']: # test parameters that should not have been changed + assert np.array_equal(converted[:, ntfm.modes.periodic.params.indices[param]], + ntfm.results.get_params('periodic', param, 'fit')) diff --git a/specparam/tests/convert/test_converter.py b/specparam/tests/convert/test_converter.py new file mode 100644 index 00000000..57f3ded3 --- /dev/null +++ b/specparam/tests/convert/test_converter.py @@ -0,0 +1,27 @@ +"""Test functions for specparam.convert.converter.""" + +from specparam.convert.converter import * + +################################################################################################### +################################################################################################### + +def test_base_param_converter(): + + baconv = BaseParamConverter('tcomponent', 'tparameter', 'tname', 'tdescription', lambda a : a) + assert baconv + +def test_aperiodic_param_converter(): + + apconv = AperiodicParamConverter('tparameter', 'tname', 'tdescription', + lambda param, model : param) + assert apconv + assert apconv.component == 'aperiodic' + assert apconv(1, None) == 1 + +def test_periodic_param_converter(): + + peconv = PeriodicParamConverter('tparameter', 'tname', 'tdescription', + lambda param, model, peak_ind : param) + assert peconv + assert peconv.component == 'periodic' + assert peconv(1, None, None) == 1 diff --git a/specparam/tests/convert/test_definitions.py b/specparam/tests/convert/test_definitions.py new file mode 100644 index 00000000..a4fb973a --- /dev/null +++ b/specparam/tests/convert/test_definitions.py @@ -0,0 +1,41 @@ +"""Test functions for specparam.convert.definitions.""" + +from specparam.modes.mode import VALID_COMPONENTS +from specparam.convert.converter import BaseParamConverter + +from specparam.convert.definitions import * + +################################################################################################### +################################################################################################### + +def test_converters_library(): + + for component in VALID_COMPONENTS: + for parameter, converters in CONVERTERS[component].items(): + for label, converter in converters.items(): + assert isinstance(converter, BaseParamConverter) + assert converter.component == component + assert converter.name == label + assert callable(converter.function) + +def test_update_converters(): + + converters1 = {'aperiodic' : {'exponent' : 'custom'}} + out1 = update_converters(DEFAULT_CONVERTERS, converters1) + assert out1['periodic'] == DEFAULT_CONVERTERS['periodic'] + assert out1['aperiodic']['exponent'] == converters1['aperiodic']['exponent'] + + converters2 = {'periodic' : {'cf' : 'custom'}} + out2 = update_converters(DEFAULT_CONVERTERS, converters2) + assert out2['aperiodic'] == DEFAULT_CONVERTERS['aperiodic'] + assert out2['periodic']['cf'] == converters2['periodic']['cf'] + + converters3 = {'aperiodic' : {'knee' : 'custom'}} + out3 = update_converters(DEFAULT_CONVERTERS, converters3) + assert out3['periodic'] == DEFAULT_CONVERTERS['periodic'] + assert out3['aperiodic']['knee'] == converters3['aperiodic']['knee'] + +def test_check_converters(): + + check_converters('aperiodic') + check_converters('periodic') diff --git a/specparam/tests/convert/test_params.py b/specparam/tests/convert/test_params.py new file mode 100644 index 00000000..994d52ee --- /dev/null +++ b/specparam/tests/convert/test_params.py @@ -0,0 +1,13 @@ +"""Test functions for specparam.convert.params.""" + +from specparam.convert.params import * + +################################################################################################### +################################################################################################### + +def test_compute_peak_height(tfm): + + for spacing in ['log', 'linear']: + for op in ['subtract', 'divide']: + out = compute_peak_height(tfm, 0, spacing, op) + assert isinstance(out, float) diff --git a/specparam/tests/models/test_base.py b/specparam/tests/models/test_base.py index cc2b029b..19025600 100644 --- a/specparam/tests/models/test_base.py +++ b/specparam/tests/models/test_base.py @@ -7,12 +7,14 @@ def test_base_model(): - tbase = BaseModel(aperiodic_mode='fixed', periodic_mode='gaussian', verbose=False) + tbase = BaseModel(aperiodic_mode='fixed', periodic_mode='gaussian', + converters=None, verbose=False) assert isinstance(tbase, BaseModel) def test_common_base_copy(): - tbase = BaseModel(aperiodic_mode='fixed', periodic_mode='gaussian', verbose=False) + tbase = BaseModel(aperiodic_mode='fixed', periodic_mode='gaussian', + converters=None, verbose=False) ntbase = tbase.copy() assert ntbase != tbase diff --git a/specparam/tests/models/test_model.py b/specparam/tests/models/test_model.py index 82e39fd4..9639e9b4 100644 --- a/specparam/tests/models/test_model.py +++ b/specparam/tests/models/test_model.py @@ -76,7 +76,8 @@ def test_fit_nk(): # Check model results - gaussian parameters for ii, gauss in enumerate(groupby(gauss_params, 3)): - assert np.allclose(gauss, tfm.results.params.periodic.get_params('fit')[ii], [2.0, 0.5, 1.0]) + assert np.allclose(gauss, \ + tfm.results.params.periodic.get_params('fit')[ii], [2.0, 0.5, 1.0]) def test_fit_nk_noise(): """Test fit on noisy data, to make sure nothing breaks.""" @@ -107,7 +108,8 @@ def test_fit_knee(): # Check model results - gaussian parameters for ii, gauss in enumerate(groupby(gauss_params, 3)): - assert np.allclose(gauss, tfm.results.params.periodic.get_params('fit')[ii], [2.0, 0.5, 1.0]) + assert np.allclose(gauss, \ + tfm.results.params.periodic.get_params('fit')[ii], [2.0, 0.5, 1.0]) def test_fit_default_metrics(): """Test computing metrics, post model fitting.""" @@ -138,6 +140,24 @@ def test_fit_custom_metrics(): assert key in metrics assert isinstance(val, float) +def test_fit_null_conversions(tfm): + + null_converters = tfm.modes.get_params('dict') + ntfm = SpectralModel(converters=null_converters) + + ntfm.fit(tfm.data.freqs, tfm.get_data('full', 'linear')) + assert np.all(np.isnan(ntfm.results.get_params('aperiodic', version='converted'))) + assert np.all(np.isnan(ntfm.results.get_params('periodic', version='converted'))) + +def test_fit_custom_conversions(tfm): + + converters = {'periodic' : {'pw' : 'lin_sub'}} + ntfm = SpectralModel(converters=converters) + + ntfm.fit(tfm.data.freqs, tfm.get_data('full', 'linear')) + assert not np.array_equal( + tfm.results.get_params('periodic', 'pw'), ntfm.results.get_params('periodic', 'pw')) + def test_checks(): """Test various checks, errors and edge cases for model fitting. This tests all the input checking done in `_prepare_data`. diff --git a/specparam/tests/modes/test_modes.py b/specparam/tests/modes/test_modes.py index 40cdbd44..c8753b58 100644 --- a/specparam/tests/modes/test_modes.py +++ b/specparam/tests/modes/test_modes.py @@ -17,7 +17,7 @@ def test_modes(): assert isinstance(modes.periodic, Mode) modes.check_params() -def test_modes_get_modes(): +def test_modes_gets(): ap_mode_name = 'fixed' pe_mode_name = 'gaussian' @@ -27,3 +27,12 @@ def test_modes_get_modes(): assert isinstance(mode_names, ModelModes) assert mode_names.aperiodic_mode == ap_mode_name assert mode_names.periodic_mode == pe_mode_name + + params = modes.get_params('list') + assert isinstance(params, dict) + for comp in modes.components: + assert isinstance(params[comp], list) + params = modes.get_params('dict') + assert isinstance(params, dict) + for comp in modes.components: + assert isinstance(params[comp], dict) diff --git a/specparam/tests/utils/test_checks.py b/specparam/tests/utils/test_checks.py index 0f4cdc75..b91fc7cb 100644 --- a/specparam/tests/utils/test_checks.py +++ b/specparam/tests/utils/test_checks.py @@ -113,3 +113,13 @@ def test_check_inds(): # Check None inputs, including length input assert isinstance(check_inds(None), slice) assert isinstance(check_inds(None, 4), range) + +def test_check_all_none(): + + assert check_all_none([None]) + assert check_all_none([None, None]) + assert check_all_none((None,)) + + assert not check_all_none([]) + assert not check_all_none([1, None]) + assert not check_all_none([1, 2, 3]) diff --git a/specparam/utils/checks.py b/specparam/utils/checks.py index 4e59e278..836577db 100644 --- a/specparam/utils/checks.py +++ b/specparam/utils/checks.py @@ -210,3 +210,27 @@ def check_inds(inds, length=None): length, inds.step if inds.step else 1) return inds + + +def check_all_none(collection): + """Check whether all elements of a collection are None. + + Parameters + ---------- + collection : list or type-castable to list + Collection of elements to check for all None contents. + + Returns + ------- + output : bool + Indicetor for whether all elements of `collection` are None. + """ + + items = set(list(collection)) + + if len(items) == 0: + output = False + else: + output = len(items) == 1 and items == {None} + + return output