Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from scipy.signal import spectrogram

from nstat.compat.matlab import SignalObj
from nstat.SignalObj import SignalObj


def _fallback_multitaper_psd(signal: np.ndarray, fs_hz: float) -> tuple[np.ndarray, np.ndarray]:
Expand All @@ -31,7 +31,7 @@ def main() -> None:
time = np.arange(0.0, duration_s, dt, dtype=float)

signal = np.sin(2.0 * np.pi * f0_hz * time)
sig_obj = SignalObj(time=time, data=signal, name="sine_signal", units="a.u.")
sig_obj = SignalObj(time=time, data=signal, name="sine_signal", yunits="a.u.")

try:
freq_hz, psd = sig_obj.MTMspectrum()
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
41 changes: 31 additions & 10 deletions nstat/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,22 @@ def _as_neuron_indices(trial: Trial, neuron_selector) -> list[int]:
raise TypeError("neuron selector must be a MATLAB-style one-based index, name, or sequence of either")


def _restore_trial_partition(trial: Trial, original_partition: np.ndarray) -> None:
def _restore_trial_partition(trial: Trial, original_partition: np.ndarray, original_window: np.ndarray | None = None) -> None:
trial.restoreToOriginal()
if original_partition.size:
trial.setTrialPartition(original_partition)
trial.setTrialTimesFor("training")
if original_window is None or original_window.size != 2:
trial.setTrialTimesFor("training")
return
training = original_partition[:2] if original_partition.size >= 2 else None
validation = original_partition[2:4] if original_partition.size >= 4 else None
if training is not None and training.size == 2 and np.allclose(original_window, training, rtol=0.0, atol=1e-12):
trial.setTrialTimesFor("training")
elif validation is not None and validation.size == 2 and np.allclose(original_window, validation, rtol=0.0, atol=1e-12):
trial.setTrialTimesFor("validation")
else:
trial.setMinTime(float(original_window[0]))
trial.setMaxTime(float(original_window[1]))


def _time_rescaled_z(counts: np.ndarray, lam_per_bin: np.ndarray) -> np.ndarray:
Expand Down Expand Up @@ -154,7 +165,7 @@ def GLMFit(
lambdaIndex: int,
Algorithm: str = "GLM",
*,
l2: float = 1e-6,
l2: float = 0.0,
max_iter: int = 120,
):
algorithm = str(Algorithm or "GLM").upper()
Expand Down Expand Up @@ -243,13 +254,14 @@ def run_analysis_for_neuron(
config_collection: ConfigCollection,
*,
algorithm: str = "GLM",
l2: float = 1e-6,
l2: float = 0.0,
max_iter: int = 120,
) -> FitResult:
if neuron_index < 0:
raise IndexError("neuron_index must be >= 0")

original_partition = np.asarray(trial.getTrialPartition(), dtype=float).reshape(-1)
original_window = np.asarray([trial.minTime, trial.maxTime], dtype=float).reshape(-1)
neuron_number = int(neuron_index) + 1
labels: list[list[str]] = []
lambda_parts: list[Covariate] = []
Expand All @@ -272,7 +284,10 @@ def run_analysis_for_neuron(
spike_train.setName(str(neuron_number))

for cfg_index in range(1, config_collection.numConfigs + 1):
_restore_trial_partition(trial, original_partition)
trial.restoreToOriginal()
if original_partition.size:
trial.setTrialPartition(original_partition)
trial.setTrialTimesFor("training")
config_collection.setConfig(trial, cfg_index)

current_labels = trial.getLabelsFromMask(neuron_number)
Expand Down Expand Up @@ -326,7 +341,7 @@ def run_analysis_for_neuron(
for part in lambda_parts[1:]:
merged_lambda = merged_lambda.merge(part)

_restore_trial_partition(trial, original_partition)
_restore_trial_partition(trial, original_partition, original_window)
fit_result = FitResult(
spike_train,
labels,
Expand Down Expand Up @@ -357,7 +372,7 @@ def run_analysis_for_all_neurons(
config_collection: ConfigCollection,
*,
algorithm: str = "GLM",
l2: float = 1e-6,
l2: float = 0.0,
max_iter: int = 120,
) -> list[FitResult]:
out: list[FitResult] = []
Expand Down Expand Up @@ -433,8 +448,15 @@ def computeFitResidual(nspikeObj, lambdaInput: Covariate, windowSize: float = 0.
nCopy.setMinTime(lambdaInput.minTime)
nCopy.setMaxTime(lambdaInput.maxTime)

sumSpikes = nCopy.getSigRep(windowSize)
# MATLAB's static Analysis.computeFitResidual ultimately operates on
# the resampled spike-count grid, even when a finer windowSize is
# requested. Preserve that canonical helper behavior here.
sumSpikes = nCopy.getSigRep(1.0 / float(nCopy.sampleRate), float(nCopy.minTime), float(nCopy.maxTime))
windowTimes = np.linspace(float(nCopy.minTime), float(nCopy.maxTime), sumSpikes.time.size, dtype=float)
if np.isfinite(windowSize) and windowSize > 0:
origin = float(nCopy.minTime)
windowTimes = origin + np.round((windowTimes - origin) / float(windowSize)) * float(windowSize)
windowTimes = np.round(windowTimes, decimals=12)
lambdaInt = lambdaInput.integral()
lambdaIntVals = (
lambdaInt.getValueAt(windowTimes[1:]).reshape(-1, lambdaInt.dimension)
Expand Down Expand Up @@ -465,8 +487,7 @@ def KSPlot(fitResults: FitResult, DTCorrection: int = 1, makePlot: int = 1):

@staticmethod
def plotFitResidual(fitResults: FitResult, windowSize: float = 0.01, makePlot: int = 1):
del windowSize
fitResults.computeFitResidual()
fitResults.computeFitResidual(window_size=windowSize)
return fitResults.plotResidual() if makePlot else []

@staticmethod
Expand Down
35 changes: 35 additions & 0 deletions nstat/class_fidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,41 @@


EXPECTED_RUNTIME_MEMBERS: dict[str, tuple[str, ...]] = {
"nstat.SignalObj": (
"shift",
"shiftMe",
"alignTime",
"power",
"sqrt",
"xcov",
"periodogram",
"MTMspectrum",
"spectrogram",
"plotVariability",
"plotAllVariability",
"plotPropsSet",
"areDataLabelsEmpty",
"isLabelPresent",
"convertNamesToIndices",
"clearPlotProps",
),
"nstat.Trial": (
"findMinSampleRate",
"getAllLabels",
"getDesignMatrix",
"getNumHist",
"getEnsCovMatrix",
"getTrialPartition",
"plotCovariates",
"plotRaster",
"toStructure",
"fromStructure",
),
"nstat.nstColl": (
"psthBars",
"estimateVarianceAcrossTrials",
"ssglm",
),
"nstat.Analysis": (
"GLMFit",
"RunAnalysisForNeuron",
Expand Down
Loading