diff --git a/PORTING_MAP.md b/PORTING_MAP.md new file mode 100644 index 00000000..9ef282df --- /dev/null +++ b/PORTING_MAP.md @@ -0,0 +1,779 @@ +# nSTAT Porting Map: Matlab → Python + +> Auto-generated: 2026-03-11 +> Matlab repo: https://github.com/cajigaslab/nSTAT (commit 3ec94ed) +> Python repo: https://github.com/cajigaslab/nSTAT-python (main branch) + +## Architecture Note + +The Python port groups related classes into shared modules rather than one-class-per-file: + +| Python Module | Matlab Classes Contained | +|---|---| +| `nstat/core.py` | `SignalObj`, `Covariate`, `nspikeTrain` | +| `nstat/trial.py` | `CovariateCollection` (≡CovColl), `SpikeTrainCollection` (≡nstColl), `Trial`, `TrialConfig`, `ConfigCollection` (≡ConfigColl) | +| `nstat/fit.py` | `FitResult`, `FitSummary`/`FitResSummary` | +| `nstat/analysis.py` | `Analysis` | +| `nstat/cif.py` | `CIF` | +| `nstat/decoding_algorithms.py` | `DecodingAlgorithms` | +| `nstat/confidence_interval.py` | `ConfidenceInterval` | +| `nstat/events.py` | `Events` | +| `nstat/history.py` | `History` | + +Thin-wrapper files exist for Matlab-style imports (e.g., `from nstat.SignalObj import SignalObj`): +`SignalObj.py`, `Covariate.py`, `nspikeTrain.py`, `nstColl.py`, `CovColl.py`, +`TrialConfig.py`, `ConfigColl.py`, `FitResult.py`, `FitResSummary.py`, +`DecodingAlgorithms.py`, `ConfidenceInterval.py` + +--- + +## Class Files + +### SignalObj (Matlab: `SignalObj.m` → Python: `nstat/core.py :: SignalObj`) + +| # | Matlab Method | Python Method | Status | +|---|---|---|---| +| 1 | `SignalObj` (constructor) | `SignalObj.__init__` | ✅ Verified | +| 2 | `setName` | `setName` | ✅ Verified | +| 3 | `setXlabel` | `setXlabel` | ✅ Verified | +| 4 | `setYLabel` | `setYLabel` | ✅ Verified | +| 5 | `setUnits` | `setUnits` | ✅ Verified | +| 6 | `setXUnits` | `setXUnits` | ✅ Verified | +| 7 | `setYUnits` | `setYUnits` | ✅ Verified | +| 8 | `setSampleRate` | `setSampleRate` | ✅ Verified | +| 9 | `setDataLabels` | `setDataLabels` | ✅ Verified | +| 10 | `setMinTime` | `setMinTime` | ✅ Verified | +| 11 | `setMaxTime` | `setMaxTime` | ✅ Verified | +| 12 | `setPlotProps` | `setPlotProps` | ✅ Verified | +| 13 | `setMask` | `setMask` | ✅ Verified | +| 14 | `getTime` | `getTime` | ✅ Verified | +| 15 | `getData` | `getData` | ✅ Verified | +| 16 | `getOriginalData` | `getOriginalData` | ✅ Verified | +| 17 | `getOrigDataSig` | `getOrigDataSig` | ✅ Verified | +| 18 | `getValueAt` | `getValueAt` | ✅ Verified | +| 19 | `getPlotProps` | `getPlotProps` | ✅ Verified | +| 20 | `getIndicesFromLabels` | `getIndicesFromLabels` | ✅ Verified | +| 21 | `plus` | `__add__`, `__radd__` | ✅ Verified | +| 22 | `minus` | `__sub__`, `__rsub__` | ✅ Verified | +| 23 | `uplus` | `__pos__` | ✅ Verified | +| 24 | `uminus` | `__neg__` | ✅ Verified | +| 25 | `power` | `power` | ✅ Verified | +| 26 | `sqrt` | `sqrt` | ✅ Verified | +| 27 | `times` | `__mul__` (element-wise) | ✅ Verified | +| 28 | `mtimes` | `__matmul__` | ✅ Verified | +| 29 | `rdivide` | `__truediv__`, `__rtruediv__` | ✅ Verified | +| 30 | `ldivide` | `ldivide` | ✅ Verified | +| 31 | `ctranspose` | `T` (property) | ✅ Verified | +| 32 | `transpose` | `T` (property) | ✅ Verified | +| 33 | `derivative` | `derivative` | ✅ Verified | +| 34 | `derivativeAt` | `derivativeAt` | ✅ Verified | +| 35 | `integral` | `integral` | ✅ Verified | +| 36 | `filter` | `filter` | ✅ Verified | +| 37 | `filtfilt` | `filtfilt` | ✅ Verified | +| 38 | `makeCompatible` | `makeCompatible` | ✅ Verified | +| 39 | `abs` | `abs`, `__abs__` | ✅ Verified | +| 40 | `log` | `log` | ✅ Verified | +| 41 | `median` | `median` | ✅ Verified | +| 42 | `mode` | `mode` | ✅ Verified | +| 43 | `mean` | `mean` | ✅ Verified | +| 44 | `std` | `std` | ✅ Verified | +| 45 | `max` | `max` | ✅ Verified | +| 46 | `min` | `min` | ✅ Verified | +| 47 | `autocorrelation` | `autocorrelation` | ✅ Verified | +| 48 | `crosscorrelation` | `crosscorrelation` | ✅ Verified | +| 49 | `periodogram` | `periodogram` | ✅ Verified | +| 50 | `MTMspectrum` | `MTMspectrum` | ✅ Verified | +| 51 | `spectrogram` | `spectrogram` | ✅ Verified | +| 52 | `xcorr` | `xcorr` | ✅ Verified | +| 53 | `xcov` | `xcov` | ✅ Verified | +| 54 | `merge` | `merge` | ✅ Verified | +| 55 | `copySignal` | `copySignal` | ✅ Verified | +| 56 | `resample` | `resample` | ✅ Verified | +| 57 | `resampleMe` | `resampleMe` | ✅ Verified | +| 58 | `restoreToOriginal` | `restoreToOriginal` | ✅ Verified | +| 59 | `resetMask` | `resetMask` | ✅ Verified | +| 60 | `findIndFromDataMask` | `findIndFromDataMask` | ✅ Verified | +| 61 | `findNearestTimeIndices` | `findNearestTimeIndices` | ✅ Verified | +| 62 | `findNearestTimeIndex` | `findNearestTimeIndex` | ✅ Verified | +| 63 | `shift` | `shift` | ✅ Verified | +| 64 | `shiftMe` | `shiftMe` | ✅ Verified | +| 65 | `alignTime` | `alignTime` | ✅ Verified | +| 66 | `plotPropsSet` | `plotPropsSet` | ✅ Verified | +| 67 | `areDataLabelsEmpty` | `areDataLabelsEmpty` | ✅ Verified | +| 68 | `isLabelPresent` | `isLabelPresent` | ✅ Verified | +| 69 | `isMaskSet` | `isMaskSet` | ✅ Verified | +| 70 | `convertNamesToIndices` | `convertNamesToIndices` | ✅ Verified | +| 71 | `alignToMax` | `alignToMax` | ✅ Verified | +| 72 | `findGlobalPeak` | `findGlobalPeak` | ✅ Verified | +| 73 | `findPeaks` | `findPeaks` | ✅ Verified | +| 74 | `findMaxima` | `findMaxima` | ✅ Verified | +| 75 | `findMinima` | `findMinima` | ✅ Verified | +| 76 | `clearPlotProps` | `clearPlotProps` | ✅ Verified | +| 77 | `dataToStructure` | `dataToStructure` | ✅ Verified | +| 78 | `dataToMatrix` | `dataToMatrix` | ✅ Verified | +| 79 | `getSubSignal` | `getSubSignal` | ✅ Verified | +| 80 | `normWindowedSignal` | `normWindowedSignal` | ✅ Verified | +| 81 | `windowedSignal` | `windowedSignal` | ✅ Verified | +| 82 | `getSigInTimeWindow` | `getSigInTimeWindow` | ✅ Verified | +| 83 | `getSubSignalsWithinNStd` | `getSubSignalsWithinNStd` | ✅ Verified | +| 84 | `plot` | `plot` | ✅ Verified | +| 85 | `setupPlots` | (internal to `plot`) | ✅ N/A-internal | +| 86 | `plotVariability` | `plotVariability` | ✅ Verified | +| 87 | `plotAllVariability` | `plotAllVariability` | ✅ Verified | +| 88 | `getIndexFromLabel` | `getIndexFromLabel` | ✅ Verified | +| 89 | `setDataMask` | `setDataMask` | ✅ Verified | +| 90 | `setMaskByInd` | `setMaskByInd` | ✅ Verified | +| 91 | `setMaskByLabels` | `setMaskByLabels` | ✅ Verified | +| 92 | `getSubSignalFromInd` | `getSubSignalFromInd` | ✅ Verified | +| 93 | `getSubSignalFromNames` | `getSubSignalFromNames` | ✅ Verified | +| 94 | `signalFromStruct` | `signalFromStruct` (staticmethod) | ✅ Verified | +| 95 | `convertSigStructureToStructure` | (internal helper) | ✅ N/A-internal | +| 96 | `convertSimpleStructureToSigStructure` | (internal helper) | ✅ N/A-internal | +| — | _Local helpers:_ `cell2str`, `parsePlotProps`, `getAvailableColor` | (internal) | ✅ N/A-internal | + +**Python-only methods (not in Matlab):** `with_metadata`, `_spawn`, `_binary_operand_matrix`, `_binary_op`, `_selector_to_zero_based`, `_labels_for_indices`, `_plot_props_for_indices`, `dimension` (property), `values` (property), `units` (property), `sample_rate` (property), `setConfInterval` + +--- + +### Covariate (Matlab: `Covariate.m` → Python: `nstat/core.py :: Covariate`) + +| # | Matlab Method | Python Method | Status | +|---|---|---|---| +| 1 | `Covariate` (constructor) | `Covariate.__init__` | ✅ Verified | +| 2 | `computeMeanPlusCI` | `computeMeanPlusCI` | ✅ Verified | +| 3 | `plot` | `plot` | ✅ Verified | +| 4 | `getSubSignal` | `getSubSignal` | ✅ Verified | +| 5 | `getSigRep` | `getSigRep` | ✅ Verified | +| 6 | `get.mu` | `mu` (property) | ✅ Verified | +| 7 | `get.sigma` | `sigma` (property) | ✅ Verified | +| 8 | `filtfilt` | (inherited from SignalObj) | ✅ Verified | +| 9 | `toStructure` | `toStructure` | ✅ Verified | +| 10 | `isConfIntervalSet` | `isConfIntervalSet` | ✅ Verified | +| 11 | `setConfInterval` | `setConfInterval` | ✅ Verified | +| 12 | `copySignal` | `copySignal` | ✅ Verified | +| 13 | `plus` | `__add__` | ✅ Verified | +| 14 | `minus` | `__sub__` | ✅ Verified | +| 15 | `dataToStructure` | (inherited from SignalObj) | ✅ Verified | +| 16 | `fromStructure` | `fromStructure` (staticmethod) | ✅ Verified | + +--- + +### nspikeTrain (Matlab: `nspikeTrain.m` → Python: `nstat/core.py :: nspikeTrain`) + +| # | Matlab Method | Python Method | Status | +|---|---|---|---| +| 1 | `nspikeTrain` (constructor) | `nspikeTrain.__init__` | ✅ Verified | +| 2 | `getLStatistic` | `getLStatistic` | ✅ Verified | +| 3 | `setMER` | `setMER` | ✅ Verified | +| 4 | `setName` | `setName` | ✅ Verified | +| 5 | `computeStatistics` | `computeStatistics` | ✅ Verified | +| 6 | `setSigRep` | `setSigRep` | ✅ Verified | +| 7 | `setMinTime` | `setMinTime` | ✅ Verified | +| 8 | `setMaxTime` | `setMaxTime` | ✅ Verified | +| 9 | `clearSigRep` | `clearSigRep` | ✅ Verified | +| 10 | `resample` | `resample` | ✅ Verified | +| 11 | `getSigRep` | `getSigRep` | ✅ Verified | +| 12 | `getMaxBinSizeBinary` | `getMaxBinSizeBinary` | ✅ Verified | +| 13 | `plotISISpectrumFunction` | `plotISISpectrumFunction` | ✅ Verified | +| 14 | `getSpikeTimes` | `getSpikeTimes` | ✅ Verified | +| 15 | `plotJointISIHistogram` | `plotJointISIHistogram` | ✅ Verified | +| 16 | `getFieldVal` | `getFieldVal` | ✅ Verified | +| 17 | `plotISIHistogram` | `plotISIHistogram` | ✅ Verified | +| 18 | `plotExponentialFit` | `plotExponentialFit` | ✅ Verified | +| 19 | `plotProbPlot` | `plotProbPlot` | ✅ Verified | +| 20 | `getISIs` | `getISIs` | ✅ Verified | +| 21 | `getMinISI` | `getMinISI` | ✅ Verified | +| 22 | `partitionNST` | `partitionNST` | ✅ Verified | +| 23 | `isSigRepBinary` | `isSigRepBinary` | ✅ Verified | +| 24 | `computeRate` | `computeRate` | ✅ Verified | +| 25 | `restoreToOriginal` | `restoreToOriginal` | ✅ Verified | +| 26 | `nstCopy` | `nstCopy` | ✅ Verified | +| 27 | `plot` | `plot` | ✅ Verified | +| 28 | `toStructure` | `toStructure` | ✅ Verified | +| 29 | `fromStructure` | `fromStructure` (staticmethod) | ✅ Verified | + +--- + +### CovColl (Matlab: `CovColl.m` → Python: `nstat/trial.py :: CovariateCollection`) + +| # | Matlab Method | Python Method | Status | +|---|---|---|---| +| 1 | `CovColl` (constructor) | `CovariateCollection.__init__` | ✅ Verified | +| 2 | `setMinTime` | `setMinTime` | ✅ Verified | +| 3 | `setMaxTime` | `setMaxTime` | ✅ Verified | +| 4 | `setSampleRate` | `setSampleRate` | ✅ Verified | +| 5 | `setMask` | `setMask` | ✅ Verified | +| 6 | `getCovDataMask` | `getCovDataMask` | ✅ Verified | +| 7 | `isCovMaskSet` | `isCovMaskSet` | ✅ Verified | +| 8 | `nActCovar` | `nActCovar` | ✅ Verified | +| 9 | `maskAwayCov` | `maskAwayCov` | ✅ Verified | +| 10 | `copy` | `copy` | ✅ Verified | +| 11 | `maskAwayOnlyCov` | `maskAwayOnlyCov` | ✅ Verified | +| 12 | `maskAwayAllExcept` | `maskAwayAllExcept` | ✅ Verified | +| 13 | `getCov` | `getCov` | ✅ Verified | +| 14 | `getCovIndicesFromNames` | `getCovIndicesFromNames` | ✅ Verified | +| 15 | `getCovDimension` | `getCovDimension` | ✅ Verified | +| 16 | `getAllCovLabels` | `getAllCovLabels` | ✅ Verified | +| 17 | `getCovLabelsFromMask` | `getCovLabelsFromMask` | ✅ Verified | +| 18 | `toStructure` | `toStructure` | ✅ Verified | +| 19 | `findMinTime` | `findMinTime` | ✅ Verified | +| 20 | `findMaxTime` | `findMaxTime` | ✅ Verified | +| 21 | `addToColl` | `addToColl` | ✅ Verified | +| 22 | `addCovCollection` | `addCovCollection` | ✅ Verified | +| 23 | `isCovPresent` | `isCovPresent` | ✅ Verified | +| 24 | `resample` | `resample` | ✅ Verified | +| 25 | `restoreToOriginal` | `restoreToOriginal` | ✅ Verified | +| 26 | `restrictToTimeWindow` | `restrictToTimeWindow` | ✅ Verified | +| 27 | `removeCovariate` | `removeCovariate` | ✅ Verified | +| 28 | `resetMask` | `resetMask` | ✅ Verified | +| 29 | `enforceSampleRate` | `enforceSampleRate` | ✅ Verified | +| 30 | `setCovShift` | `setCovShift` | ✅ Verified | +| 31 | `resetCovShift` | `resetCovShift` | ✅ Verified | +| 32 | `flattenCovMask` | `flattenCovMask` | ✅ Verified | +| 33 | `dataToMatrix` | `dataToMatrix` | ✅ Verified | +| 34 | `dataToMatrixFromNames` | (merged into `dataToMatrix`) | ✅ Verified | +| 35 | `dataToMatrixFromSel` | (merged into `dataToMatrix`) | ✅ Verified | +| 36 | `dataToStructure` | `dataToStructure` | ✅ Verified | +| 37 | `plot` | `plot` | ✅ Verified | +| 38 | `setMasksFromSelector` | `setMasksFromSelector` | ✅ Verified | +| 39 | `getCovMaskFromSelector` | (merged into `_selector_to_cov_mask`) | ✅ Verified | +| 40 | `getSelectorFromMasks` | `getSelectorFromMasks` | ✅ Verified | +| 41 | `isaSelectorCell` | (internal) | ✅ Verified | +| 42 | `generateSelectorCell` | `generateSelectorCell` | ✅ Verified | +| 43 | `addCovCellToColl` | (internal) | ✅ N/A-internal | +| 44 | `addSingleCovToColl` | (internal) | ✅ N/A-internal | +| 45 | `updateTimes` | (internal) | ✅ N/A-internal | +| 46 | `getCovIndFromName` | `getCovIndFromName` | ✅ Verified | +| 47 | `removeFromColl` | (internal to `removeCovariate`) | ✅ N/A-internal | +| 48 | `removeFromCollByIndices` | (internal) | ✅ N/A-internal | +| 49 | `generateRemainingIndex` | (internal) | ✅ N/A-internal | +| 50 | `fromStructure` | `fromStructure` (staticmethod) | ✅ Verified | +| — | _Local helpers:_ `covIndFromSelector`, `numActCov`, `sumDimensions`, `parseDataSelectorArray`, `containsChars` | (internal) | ✅ N/A-internal | + +--- + +### nstColl (Matlab: `nstColl.m` → Python: `nstat/trial.py :: SpikeTrainCollection`) + +| # | Matlab Method | Python Method | Status | +|---|---|---|---| +| 1 | `nstColl` (constructor) | `SpikeTrainCollection.__init__` | ✅ Verified | +| 2 | `merge` | `merge` | ✅ Verified | +| 3 | `getFirstSpikeTime` | `getFirstSpikeTime` | ✅ Verified | +| 4 | `getLastSpikeTime` | `getLastSpikeTime` | ✅ Verified | +| 5 | `getMaxBinSizeBinary` | `getMaxBinSizeBinary` | ✅ Verified | +| 6 | `get.uniqueNeuronNames` | `uniqueNeuronNames` (property) | ✅ Verified | +| 7 | `getNeighbors` | `getNeighbors` | ✅ Verified | +| 8 | `getFieldVal` | `getFieldVal` | ✅ Verified | +| 9 | `shiftTime` | `shiftTime` | ✅ Verified | +| 10 | `setMinTime` | `setMinTime` | ✅ Verified | +| 11 | `setMaxTime` | `setMaxTime` | ✅ Verified | +| 12 | `setMask` | `setMask` | ✅ Verified | +| 13 | `setNeuronMaskFromInd` | `setNeuronMaskFromInd` | ✅ Verified | +| 14 | `setNeuronMask` | `setNeuronMask` | ✅ Verified | +| 15 | `setNeighbors` | `setNeighbors` | ✅ Verified | +| 16 | `getIndFromMask` | `getIndFromMask` | ✅ Verified | +| 17 | `getIndFromMaskMinusOne` | `getIndFromMaskMinusOne` | ✅ Verified | +| 18 | `isNeuronMaskSet` | `isNeuronMaskSet` | ✅ Verified | +| 19 | `areNeighborsSet` | `areNeighborsSet` | ✅ Verified | +| 20 | `restoreToOriginal` | `restoreToOriginal` | ✅ Verified | +| 21 | `findMaxSampleRate` | `findMaxSampleRate` | ✅ Verified | +| 22 | `resetMask` | `resetMask` | ✅ Verified | +| 23 | `addToColl` | `addToColl` | ✅ Verified | +| 24 | `getUniqueNSTnames` | `getUniqueNSTnames` | ✅ Verified | +| 25 | `getNSTnames` | `getNSTnames` | ✅ Verified | +| 26 | `getNSTIndicesFromName` | `getNSTIndicesFromName` | ✅ Verified | +| 27 | `getNSTnameFromInd` | `getNSTnameFromInd` | ✅ Verified | +| 28 | `getNSTFromName` | `getNSTFromName` | ✅ Verified | +| 29 | `getNST` | `getNST` | ✅ Verified | +| 30 | `resample` | `resample` | ✅ Verified | +| 31 | `isSigRepBinary` | `isSigRepBinary` | ✅ Verified | +| 32 | `BinarySigRep` | `BinarySigRep` | ✅ Verified | +| 33 | `getEnsembleNeuronCovariates` | `getEnsembleNeuronCovariates` | ✅ Verified | +| 34 | `addNeuronNamesToEnsCovColl` | `addNeuronNamesToEnsCovColl` | ✅ Verified | +| 35 | `dataToMatrix` | `dataToMatrix` | ✅ Verified | +| 36 | `toSpikeTrain` | `toSpikeTrain` | ✅ Verified | +| 37 | `psth` | `psth` | ✅ Verified | +| 38 | `psthBars` | `psthBars` | ✅ Verified | +| 39 | `ssglm` | `ssglm` | ✅ Verified | +| 40 | `psthGLM` | `psthGLM` | ✅ Verified | +| 41 | `plot` | `plot` | ✅ Verified | +| 42 | `getMinISIs` | `getMinISIs` | ✅ Verified | +| 43 | `getISIs` | `getISIs` | ✅ Verified | +| 44 | `plotISIHistogram` | `plotISIHistogram` | ✅ Verified | +| 45 | `plotExponentialFit` | `plotExponentialFit` | ✅ Verified | +| 46 | `estimateVarianceAcrossTrials` | `estimateVarianceAcrossTrials` | ✅ Verified | +| 47 | `getSpikeTimes` | `getSpikeTimes` | ✅ Verified | +| 48 | `toStructure` | `toStructure` | ✅ Verified | +| 49 | `fromStructure` | `fromStructure` (staticmethod) | ✅ Verified | +| 50 | `generateUnitImpulseBasis` | `generateUnitImpulseBasis` (staticmethod) | ✅ Verified | +| 51 | `addSingleSpikeToColl` | `addSingleSpikeToColl` | ✅ Verified | +| 52 | `ensureConsistancy` | `ensureConsistancy` | ✅ Verified | +| 53 | `enforceSampleRate` | `enforceSampleRate` | ✅ Verified | +| 54 | `updateTimes` | `updateTimes` | ✅ Verified | + +--- + +### Events (Matlab: `Events.m` → Python: `nstat/events.py :: Events`) + +| # | Matlab Method | Python Method | Status | +|---|---|---|---| +| 1 | `Events` (constructor) | `Events.__init__` | ✅ Verified | +| 2 | `plot` | `plot` | ✅ Verified | +| 3 | `toStructure` | `toStructure` | ✅ Verified | +| 4 | `fromStructure` | `fromStructure` (staticmethod) | ✅ Verified | +| — | `dsxy2figxy` (local helper) | (not needed in Matplotlib) | ✅ N/A-internal | + +--- + +### History (Matlab: `History.m` → Python: `nstat/history.py :: History`) + +| # | Matlab Method | Python Method | Status | +|---|---|---|---| +| 1 | `History` (constructor) | `History.__init__` | ✅ Verified | +| 2 | `computeHistory` | `computeHistory` | ✅ Verified | +| 3 | `setWindow` | `setWindow` | ✅ Verified | +| 4 | `plot` | `plot` | ✅ Verified | +| 5 | `toFilter` | `toFilter` | ✅ Verified | +| 6 | `toStructure` | `toStructure` | ✅ Verified | +| 7 | `fromStructure` | `fromStructure` (staticmethod) | ✅ Verified | +| 8 | `computeNSTHistoryWindow` | `_compute_single_history` / `compute_history` | ✅ Verified | + +--- + +### Trial (Matlab: `Trial.m` → Python: `nstat/trial.py :: Trial`) + +| # | Matlab Method | Python Method | Status | +|---|---|---|---| +| 1 | `Trial` (constructor) | `Trial.__init__` | ✅ Verified | +| 2 | `setTrialEvents` | `setTrialEvents` | ✅ Verified | +| 3 | `setTrialPartition` | `setTrialPartition` | ✅ Verified | +| 4 | `getTrialPartition` | `getTrialPartition` | ✅ Verified | +| 5 | `setTrialTimesFor` | `setTrialTimesFor` | ✅ Verified | +| 6 | `setMinTime` | `setMinTime` | ✅ Verified | +| 7 | `setMaxTime` | `setMaxTime` | ✅ Verified | +| 8 | `updateTimePartitions` | `updateTimePartitions` | ✅ Verified | +| 9 | `setSampleRate` | `setSampleRate` | ✅ Verified | +| 10 | `setEnsCovMask` | `setEnsCovMask` | ✅ Verified | +| 11 | `setCovMask` | `setCovMask` | ✅ Verified | +| 12 | `setNeuronMask` | `setNeuronMask` | ✅ Verified | +| 13 | `setNeighbors` | `setNeighbors` | ✅ Verified | +| 14 | `setHistory` | `setHistory` | ✅ Verified | +| 15 | `setEnsCovHist` | `setEnsCovHist` | ✅ Verified | +| 16 | `isNeuronMaskSet` | `isNeuronMaskSet` | ✅ Verified | +| 17 | `isCovMaskSet` | `isCovMaskSet` | ✅ Verified | +| 18 | `isMaskSet` | `isMaskSet` | ✅ Verified | +| 19 | `isHistSet` | `isHistSet` | ✅ Verified | +| 20 | `isEnsCovHistSet` | `isEnsCovHistSet` | ✅ Verified | +| 21 | `addCov` | `addCov` | ✅ Verified | +| 22 | `removeCov` | `removeCov` | ✅ Verified | +| 23 | `getSpikeVector` | `getSpikeVector` | ✅ Verified | +| 24 | `getDesignMatrix` | `getDesignMatrix` | ✅ Verified | +| 25 | `getEnsCovMatrix` | `getEnsCovMatrix` | ✅ Verified | +| 26 | `getHistForNeurons` | `getHistForNeurons` | ✅ Verified | +| 27 | `getHistMatrices` | `getHistMatrices` | ✅ Verified | +| 28 | `getEnsembleNeuronCovariates` | `getEnsembleNeuronCovariates` | ✅ Verified | +| 29 | `getNeuronIndFromMask` | `getNeuronIndFromMask` | ✅ Verified | +| 30 | `getNumUniqueNeurons` | `getNumUniqueNeurons` | ✅ Verified | +| 31 | `getNeuronNames` | `getNeuronNames` | ✅ Verified | +| 32 | `getUniqueNeuronNames` | `getUniqueNeuronNames` | ✅ Verified | +| 33 | `getNeuronIndFromName` | `getNeuronIndFromName` | ✅ Verified | +| 34 | `getNeuronNeighbors` | `getNeuronNeighbors` | ✅ Verified | +| 35 | `getCovSelectorFromMask` | `getCovSelectorFromMask` | ✅ Verified | +| 36 | `getCov` | `getCov` | ✅ Verified | +| 37 | `getNeuron` | `getNeuron` | ✅ Verified | +| 38 | `getEvents` | `getEvents` | ✅ Verified | +| 39 | `getAllLabels` | `getAllLabels` | ✅ Verified | +| 40 | `getNumHist` | `getNumHist` | ✅ Verified | +| 41 | `getAllCovLabels` | `getAllCovLabels` | ✅ Verified | +| 42 | `getCovLabelsFromMask` | `getCovLabelsFromMask` | ✅ Verified | +| 43 | `getHistLabels` | `getHistLabels` | ✅ Verified | +| 44 | `getEnsCovLabels` | `getEnsCovLabels` | ✅ Verified | +| 45 | `getEnsCovLabelsFromMask` | `getEnsCovLabelsFromMask` | ✅ Verified | +| 46 | `getLabelsFromMask` | `getLabelsFromMask` | ✅ Verified | +| 47 | `flattenCovMask` | `flattenCovMask` | ✅ Verified | +| 48 | `flattenMask` | `flattenMask` | ✅ Verified | +| 49 | `shiftCovariates` | `shiftCovariates` | ✅ Verified | +| 50 | `resetEnsCovMask` | `resetEnsCovMask` | ✅ Verified | +| 51 | `resetCovMask` | `resetCovMask` | ✅ Verified | +| 52 | `resetNeuronMask` | `resetNeuronMask` | ✅ Verified | +| 53 | `resetHistory` | `resetHistory` | ✅ Verified | +| 54 | `resample` | `resample` | ✅ Verified | +| 55 | `resampleEnsColl` | `resampleEnsColl` | ✅ Verified | +| 56 | `restoreToOriginal` | `restoreToOriginal` | ✅ Verified | +| 57 | `makeConsistentSampleRate` | `makeConsistentSampleRate` | ✅ Verified | +| 58 | `makeConsistentTime` | `makeConsistentTime` | ✅ Verified | +| 59 | `plotRaster` | `plotRaster` | ✅ Verified | +| 60 | `plotCovariates` | `plotCovariates` | ✅ Verified | +| 61 | `plot` | `plot` | ✅ Verified | +| 62 | `toStructure` | `toStructure` | ✅ Verified | +| 63 | `fromStructure` | `fromStructure` (staticmethod) | ✅ Verified | +| 64 | `isSampleRateConsistent` | `isSampleRateConsistent` | ✅ Verified | +| 65 | `findMinTime` | `findMinTime` | ✅ Verified | +| 66 | `findMaxTime` | `findMaxTime` | ✅ Verified | +| 67 | `findMinSampleRate` | `findMinSampleRate` | ✅ Verified | +| 68 | `findMaxSampleRate` | `findMaxSampleRate` | ✅ Verified | + +--- + +### TrialConfig (Matlab: `TrialConfig.m` → Python: `nstat/trial.py :: TrialConfig`) + +| # | Matlab Method | Python Method | Status | +|---|---|---|---| +| 1 | `TrialConfig` (constructor) | `TrialConfig.__init__` | ✅ Verified | +| 2 | `setConfig` | `setConfig` | ✅ Verified | +| 3 | `getName` | `getName` | ✅ Verified | +| 4 | `setName` | `setName` | ✅ Verified | +| 5 | `toStructure` | `toStructure` | ✅ Verified | +| 6 | `fromStructure` | `fromStructure` (staticmethod) | ✅ Verified | + +--- + +### ConfigColl (Matlab: `ConfigColl.m` → Python: `nstat/trial.py :: ConfigCollection`) + +| # | Matlab Method | Python Method | Status | +|---|---|---|---| +| 1 | `ConfigColl` (constructor) | `ConfigCollection.__init__` | ✅ Verified | +| 2 | `addConfig` | `addConfig` | ✅ Verified | +| 3 | `getConfig` | `getConfig` | ✅ Verified | +| 4 | `setConfig` | `setConfig` | ✅ Verified | +| 5 | `getConfigNames` | `getConfigNames` | ✅ Verified | +| 6 | `setConfigNames` | `setConfigNames` | ✅ Verified | +| 7 | `getSubsetConfigs` | `getSubsetConfigs` | ✅ Verified | +| 8 | `toStructure` | `toStructure` | ✅ Verified | +| 9 | `fromStructure` | `fromStructure` (staticmethod) | ✅ Verified | + +--- + +### Analysis (Matlab: `Analysis.m` → Python: `nstat/analysis.py :: Analysis`) + +| # | Matlab Method | Python Method | Status | +|---|---|---|---| +| 1 | `RunAnalysisForNeuron` | `RunAnalysisForNeuron` / `run_analysis_for_neuron` | ✅ Verified | +| 2 | `RunAnalysisForAllNeurons` | `RunAnalysisForAllNeurons` / `run_analysis_for_all_neurons` | ✅ Verified | +| 3 | `GLMFit` | `GLMFit` | ✅ Verified | +| 4 | `plotInvGausTrans` | `plotInvGausTrans` | ✅ Verified | +| 5 | `plotFitResidual` | `plotFitResidual` | ✅ Verified | +| 6 | `KSPlot` | `KSPlot` | ✅ Verified | +| 7 | `plotSeqCorr` | `plotSeqCorr` | ✅ Verified | +| 8 | `plotCoeffs` | `plotCoeffs` | ✅ Verified | +| 9 | `computeInvGausTrans` | `computeInvGausTrans` | ✅ Verified | +| 10 | `computeKSStats` | `computeKSStats` | ✅ Verified | +| 11 | `computeFitResidual` | `computeFitResidual` | ✅ Verified | +| 12 | `compHistEnsCoeffForAll` | `compHistEnsCoeffForAll` | ✅ Verified | +| 13 | `compHistEnsCoeff` | `compHistEnsCoeff` | ✅ Verified | +| 14 | `computeGrangerCausalityMatrix` | `computeGrangerCausalityMatrix` | ✅ Verified | +| 15 | `computeHistLag` | `computeHistLag` | ✅ Verified | +| 16 | `computeHistLagForAll` | `computeHistLagForAll` | ✅ Verified | +| 17 | `computeNeighbors` | `computeNeighbors` | ✅ Verified | +| 18 | `spikeTrigAvg` | `spikeTrigAvg` | ✅ Verified | +| — | _Local helpers:_ `flatMaskCellToMat`, `bnlrCG`, `ksdiscrete`, `fdr_bh` | Module-level helpers | ✅ N/A-internal | + +--- + +### FitResult (Matlab: `FitResult.m` → Python: `nstat/fit.py :: FitResult`) + +| # | Matlab Method | Python Method | Status | +|---|---|---|---| +| 1 | `FitResult` (constructor) | `FitResult.__init__` | ✅ Verified | +| 2 | `setNeuronName` | `setNeuronName` | ✅ Verified | +| 3 | `mergeResults` | `mergeResults` | ✅ Verified | +| 4 | `getSubsetFitResult` | `getSubsetFitResult` | ✅ Verified | +| 5 | `addParamsToFit` | `addParamsToFit` | ✅ Verified | +| 6 | `computeValLambda` | `computeValLambda` | ✅ Verified | +| 7 | `mapCovLabelsToUniqueLabels` | `mapCovLabelsToUniqueLabels` | ✅ Verified | +| 8 | `getPlotParams` | `getPlotParams` | ✅ Verified | +| 9 | `plotValidation` | `plotValidation` | ✅ Verified | +| 10 | `isValDataPresent` | `isValDataPresent` | ✅ Verified | +| 11 | `evalLambda` | `evalLambda` | ✅ Verified | +| 12 | `computePlotParams` | `computePlotParams` | ✅ Verified | +| 13 | `getCoeffIndex` | `getCoeffIndex` | ✅ Verified | +| 14 | `plotCoeffsWithoutHistory` | `plotCoeffsWithoutHistory` | ✅ Verified | +| 15 | `getHistIndex` | `getHistIndex` | ✅ Verified | +| 16 | `getCoeffs` | `getCoeffs` | ✅ Verified | +| 17 | `getHistCoeffs` | `getHistCoeffs` | ✅ Verified | +| 18 | `plotHistCoeffs` | `plotHistCoeffs` | ✅ Verified | +| 19 | `plotCoeffs` | `plotCoeffs` | ✅ Verified | +| 20 | `plotResults` | `plotResults` | ✅ Verified | +| 21 | `KSPlot` | `KSPlot` | ✅ Verified | +| 22 | `toStructure` | `toStructure` | ✅ Verified | +| 23 | `plotSeqCorr` | `plotSeqCorr` | ✅ Verified | +| 24 | `plotInvGausTrans` | `plotInvGausTrans` | ✅ Verified | +| 25 | `plotResidual` | `plotResidual` | ✅ Verified | +| 26 | `setKSStats` | `setKSStats` | ✅ Verified | +| 27 | `setInvGausStats` | `setInvGausStats` | ✅ Verified | +| 28 | `setFitResidual` | `setFitResidual` | ✅ Verified | +| 29 | `getParam` | `getParam` | ✅ Verified | +| 30 | `fromStructure` | `fromStructure` (staticmethod) | ✅ Verified | +| 31 | `CellArrayToStructure` | `CellArrayToStructure` (staticmethod) | ✅ Verified | +| — | _Local helpers:_ `xticklabel_rotate`, `getUniqueLabels` | (internal) | ✅ N/A-internal | + +--- + +### FitResSummary (Matlab: `FitResSummary.m` → Python: `nstat/fit.py :: FitSummary`/`FitResSummary`) + +| # | Matlab Method | Python Method | Status | +|---|---|---|---| +| 1 | `FitResSummary` (constructor) | `FitSummary.__init__` | ✅ Verified | +| 2 | `mapCovLabelsToUniqueLabels` | `mapCovLabelsToUniqueLabels` | ✅ Verified | +| 3 | `getDiffAIC` | `getDiffAIC` | ✅ Verified | +| 4 | `getDiffBIC` | `getDiffBIC` | ✅ Verified | +| 5 | `getDifflogLL` | `getDifflogLL` | ✅ Verified | +| 6 | `binCoeffs` | `binCoeffs` | ✅ Verified | +| 7 | `setCoeffRange` | `setCoeffRange` | ✅ Verified | +| 8 | `getSigCoeffs` | `getSigCoeffs` | ✅ Verified | +| 9 | `plotIC` | `plotIC` | ✅ Verified | +| 10 | `plotAllCoeffs` | `plotAllCoeffs` | ✅ Verified | +| 11 | `plot3dCoeffSummary` | `plot3dCoeffSummary` | ✅ Verified | +| 12 | `plot2dCoeffSummary` | `plot2dCoeffSummary` | ✅ Verified | +| 13 | `plotKSSummary` | `plotKSSummary` | ✅ Verified | +| 14 | `plotAIC` | `plotAIC` | ✅ Verified | +| 15 | `plotBIC` | `plotBIC` | ✅ Verified | +| 16 | `plotlogLL` | `plotlogLL` | ✅ Verified | +| 17 | `plotResidualSummary` | `plotResidualSummary` | ✅ Verified | +| 18 | `plotSummary` | `plotSummary` | ✅ Verified | +| 19 | `boxPlot` | `boxPlot` | ✅ Verified | +| 20 | `toStructure` | `toStructure` | ✅ Verified | +| 21 | `getCoeffIndex` | `getCoeffIndex` | ✅ Verified | +| 22 | `plotCoeffsWithoutHistory` | `plotCoeffsWithoutHistory` | ✅ Verified | +| 23 | `getHistIndex` | `getHistIndex` | ✅ Verified | +| 24 | `getCoeffs` | `getCoeffs` | ✅ Verified | +| 25 | `getHistCoeffs` | `getHistCoeffs` | ✅ Verified | +| 26 | `plotHistCoeffs` | `plotHistCoeffs` | ✅ Verified | +| 27 | `fromStructure` | `fromStructure` (staticmethod) | ✅ Verified | +| — | _Local helpers:_ `computeDiffMat`, `getUniqueLabels`, `xticklabel_rotate` | (internal) | ✅ N/A-internal | + +--- + +### CIF (Matlab: `CIF.m` → Python: `nstat/cif.py :: CIF`) + +| # | Matlab Method | Python Method | Status | +|---|---|---|---| +| 1 | `CIF` (constructor) | `CIF.__init__` | ✅ Verified | +| 2 | `CIFCopy` | `CIFCopy` | ✅ Verified | +| 3 | `setSpikeTrain` | `setSpikeTrain` | ✅ Verified | +| 4 | `setHistory` | `setHistory` | ✅ Verified | +| 5 | `evalLambdaDelta` | `evalLambdaDelta` | ✅ Verified | +| 6 | `evalGradient` | `evalGradient` | ✅ Verified | +| 7 | `evalGradientLog` | `evalGradientLog` | ✅ Verified | +| 8 | `evalJacobian` | `evalJacobian` | ✅ Verified | +| 9 | `evalJacobianLog` | `evalJacobianLog` | ✅ Verified | +| 10 | `evalLDGamma` | `evalLDGamma` | ✅ Verified | +| 11 | `evalLogLDGamma` | `evalLogLDGamma` | ✅ Verified | +| 12 | `evalGradientLDGamma` | `evalGradientLDGamma` | ✅ Verified | +| 13 | `evalGradientLogLDGamma` | `evalGradientLogLDGamma` | ✅ Verified | +| 14 | `evalJacobianLogLDGamma` | `evalJacobianLogLDGamma` | ✅ Verified | +| 15 | `evalJacobianLDGamma` | `evalJacobianLDGamma` | ✅ Verified | +| 16 | `isSymBeta` | `isSymBeta` | ✅ Verified | +| 17 | `simulateCIFByThinningFromLambda` | `simulateCIFByThinningFromLambda` | ✅ Verified | +| 18 | `simulateCIFByThinning` | `simulateCIFByThinning` | ✅ Verified | +| 19 | `simulateCIF` | `simulateCIF` | ✅ Verified | +| 20 | `expandStimToVarIn` | (internal to `_stimulus_values`) | ✅ N/A-internal | +| 21 | `evalFunctionWithVectorArgs` | (not needed — no symbolic CIF) | ⚠️ Nominal gap | +| 22 | `resolveSimulinkModelName` | (no Simulink in Python) | ⚠️ Nominal gap | + +--- + +### ConfidenceInterval (Matlab: `ConfidenceInterval.m` → Python: `nstat/confidence_interval.py`) + +| # | Matlab Method | Python Method | Status | +|---|---|---|---| +| 1 | `ConfidenceInterval` (constructor) | `ConfidenceInterval.__init__` | ✅ Verified | +| 2 | `setColor` | `setColor` | ✅ Verified | +| 3 | `setValue` | `setValue` | ✅ Verified | +| 4 | `plot` | `plot` | ✅ Verified | +| 5 | `fromStructure` | `fromStructure` (staticmethod) | ✅ Verified | + +--- + +### DecodingAlgorithms (Matlab: `DecodingAlgorithms.m` → Python: `nstat/decoding_algorithms.py`) + +| # | Matlab Method | Python Method | Status | +|---|---|---|---| +| 1 | `PPDecodeFilter` | `PPDecodeFilter` | ✅ Verified | +| 2 | `PPDecodeFilterLinear` | `PPDecodeFilterLinear` | ✅ Verified | +| 3 | `PP_fixedIntervalSmoother` | `PP_fixedIntervalSmoother` | ✅ Verified | +| 4 | `PPDecode_predict` | `PPDecode_predict` | ✅ Verified | +| 5 | `PPDecode_update` | `PPDecode_update` | ✅ Verified | +| 6 | `PPDecode_updateLinear` | `PPDecode_updateLinear` | ✅ Verified | +| 7 | `PPHybridFilterLinear` | `PPHybridFilterLinear` | ✅ Verified | +| 8 | `PPHybridFilter` | `PPHybridFilter` | ✅ Verified | +| 9 | `ukf` | `ukf` | ✅ Verified | +| 10 | `ukf_ut` | `ukf_ut` | ✅ Verified | +| 11 | `ukf_sigmas` | `ukf_sigmas` | ✅ Verified | +| 12 | `kalman_filter` | `kalman_filter` | ✅ Verified | +| 13 | `kalman_update` | `kalman_update` | ✅ Verified | +| 14 | `kalman_predict` | `kalman_predict` | ✅ Verified | +| 15 | `kalman_fixedIntervalSmoother` | `kalman_fixedIntervalSmoother` | ✅ Verified | +| 16 | `kalman_smootherFromFiltered` | `kalman_smootherFromFiltered` | ✅ Verified | +| 17 | `kalman_smoother` | `kalman_smoother` | ✅ Verified | +| 18 | `PPSS_EMFB` | `PPSS_EMFB` | ✅ Verified | +| 19 | `PPSS_EM` | `PPSS_EM` | ✅ Verified | +| 20 | `PPSS_EStep` | `PPSS_EStep` | ✅ Verified | +| 21 | `PPSS_MStep` | `PPSS_MStep` | ✅ Verified | +| 22 | `prepareEMResults` | `prepareEMResults` | ✅ Verified | +| 23 | `ComputeStimulusCIs` | `ComputeStimulusCIs` | ✅ Verified | +| 24 | `estimateInfoMat` | `estimateInfoMat` | ✅ Verified | +| 25 | `computeSpikeRateCIs` | `computeSpikeRateCIs` | ✅ Verified | +| 26 | `computeSpikeRateDiffCIs` | `computeSpikeRateDiffCIs` | ✅ Verified | +| 27 | `KF_EMCreateConstraints` | `KF_EMCreateConstraints` | ✅ Verified | +| 28 | `KF_EM` | `KF_EM` | ✅ Verified | +| 29 | `KF_ComputeParamStandardErrors` | `KF_ComputeParamStandardErrors` | ✅ Verified | +| 30 | `KF_EStep` | `KF_EStep` | ✅ Verified | +| 31 | `KF_MStep` | `KF_MStep` | ✅ Verified | +| 32 | `mPPCO_fixedIntervalSmoother` | `mPPCO_fixedIntervalSmoother` | ✅ Verified | +| 33 | `mPPCODecodeLinear` | `mPPCODecodeLinear` | ✅ Verified | +| 34 | `mPPCODecode_predict` | `mPPCODecode_predict` | ✅ Verified | +| 35 | `mPPCODecode_update` | `mPPCODecode_update` | ✅ Verified | +| 36 | `mPPCO_EMCreateConstraints` | `mPPCO_EMCreateConstraints` | ✅ Verified | +| 37 | `mPPCO_ComputeParamStandardErrors` | `mPPCO_ComputeParamStandardErrors` | ✅ Verified | +| 38 | `mPPCO_EM` | `mPPCO_EM` | ✅ Verified | +| 39 | `mPPCO_EStep` | `mPPCO_EStep` | ✅ Verified | +| 40 | `mPPCO_MStep` | `mPPCO_MStep` | ✅ Verified | +| 41 | `PP_EMCreateConstraints` | `PP_EMCreateConstraints` | ✅ Verified | +| 42 | `PP_ComputeParamStandardErrors` | `PP_ComputeParamStandardErrors` | ✅ Verified | +| 43 | `PP_EM` | `PP_EM` | ✅ Verified | +| 44 | `PP_EStep` | `PP_EStep` | ✅ Verified | +| 45 | `PP_MStep` | `PP_MStep` | ✅ Verified | + +--- + +## Standalone Functions + +| Matlab File | Python Equivalent | Status | +|---|---|---| +| `getPaperDataDirs.m` | `nstat/data_manager.py :: getPaperDataDirs` | ✅ Verified | +| `nSTAT_Install.m` | `nstat/install.py` + `nstat/nstat_install.py` | ✅ Verified | +| `nSTAT_ExampleDataInfo.m` | `nstat/data_manager.py :: get_example_data_info` | ✅ Verified | +| `nstatOpenHelpPage.m` | (not applicable — Jupyter-based docs) | ✅ N/A | +| `run_tests.m` | `pytest` (standard Python test runner) | ✅ Verified | +| `Contents.m` | `nstat/__init__.py` | ✅ Verified | + +--- + +## Library Functions + +| Matlab File | Python Equivalent | Status | +|---|---|---| +| `libraries/zernike/zernfun.m` | `nstat/zernike.py :: zernfun` | ✅ Verified | +| `libraries/zernike/zernfun2.m` | (merged into `zernfun`) | ✅ Verified | +| `libraries/zernike/zernpol.m` | `nstat/zernike.py :: _radial_polynomial` | ✅ Verified | +| `libraries/NearestSymmetricPositiveDefinite/nearestSPD.m` | `nstat/decoding_algorithms.py :: _nearestSPD` | ✅ Verified | +| `libraries/NearestSymmetricPositiveDefinite/nearestSPD_demo.m` | (demo only) | ✅ N/A | +| `libraries/fixPSlinestyle.m` | (not needed — Matplotlib) | ✅ N/A | +| `libraries/xticklabel_rotate.m` | (not needed — Matplotlib `tick_params`) | ✅ N/A | +| `libraries/rotateXLabels/rotateXLabels.m` | (not needed — Matplotlib) | ✅ N/A | + +--- + +## +nstat Package (Tools) + +| Matlab File | Python Equivalent | Status | +|---|---|---| +| `tools/+nstat/setPlotStyle.m` | `nstat/plot_style.py :: set_plot_style` | ✅ Verified | +| `tools/+nstat/getPlotStyle.m` | `nstat/plot_style.py :: get_plot_style` | ✅ Verified | +| `tools/+nstat/applyPlotStyle.m` | `nstat/plot_style.py :: apply_plot_style` | ✅ Verified | +| `tools/+nstat/+docs/exportFigure.m` | `nstat/paper_figures.py` | ✅ Verified | +| `tools/+nstat/+docs/getRepoRoot.m` | (internal) | ✅ N/A | +| `tools/+nstat/+docs/writeJson.m` | (internal) | ✅ N/A | +| `tools/+nstat/+baseline/capture_nSTATPaperExamples.m` | (Matlab-specific tooling) | ✅ N/A | + +--- + +## Paper Examples + +| Matlab File | Python File | Status | +|---|---|---| +| `examples/paper/example01_mepsc_poisson.m` | `examples/paper/example01_mepsc_poisson.py` | ✅ Verified | +| `examples/paper/example02_whisker_stimulus_thalamus.m` | `examples/paper/example02_whisker_stimulus_thalamus.py` | ✅ Verified | +| `examples/paper/example03_psth_and_ssglm.m` | `examples/paper/example03_psth_and_ssglm.py` | ✅ Verified | +| `examples/paper/example04_place_cells_continuous_stimulus.m` | `examples/paper/example04_place_cells_continuous_stimulus.py` | ✅ Verified | +| `examples/paper/example05_decoding_ppaf_pphf.m` | `examples/paper/example05_decoding_ppaf_pphf.py` | ✅ Verified | + +--- + +## Helpfile Notebooks + +| Matlab Helpfile (.m / .mlx) | Python Notebook | Status | +|---|---|---| +| `helpfiles/SignalObjExamples.m` | `notebooks/SignalObjExamples.ipynb` | ✅ Ported | +| `helpfiles/CovariateExamples.m` | `notebooks/CovariateExamples.ipynb` | ✅ Ported | +| `helpfiles/nSpikeTrainExamples.m` | `notebooks/nSpikeTrainExamples.ipynb` | ✅ Ported | +| `helpfiles/nstCollExamples.m` | `notebooks/nstCollExamples.ipynb` | ✅ Ported | +| `helpfiles/CovCollExamples.m` | `notebooks/CovCollExamples.ipynb` | ✅ Ported | +| `helpfiles/EventsExamples.m` | `notebooks/EventsExamples.ipynb` | ✅ Ported | +| `helpfiles/HistoryExamples.m` | `notebooks/HistoryExamples.ipynb` | ✅ Ported | +| `helpfiles/TrialExamples.m` | `notebooks/TrialExamples.ipynb` | ✅ Ported | +| `helpfiles/TrialConfigExamples.m` | `notebooks/TrialConfigExamples.ipynb` | ✅ Ported | +| `helpfiles/ConfigCollExamples.m` | `notebooks/ConfigCollExamples.ipynb` | ✅ Ported | +| `helpfiles/AnalysisExamples.m` | `notebooks/AnalysisExamples.ipynb` | ✅ Ported | +| `helpfiles/AnalysisExamples2.m` | `notebooks/AnalysisExamples2.ipynb` | ✅ Ported | +| `helpfiles/FitResultExamples.m` | `notebooks/FitResultExamples.ipynb` | ✅ Ported | +| `helpfiles/FitResultReference.m` | `notebooks/FitResultReference.ipynb` | ✅ Ported | +| `helpfiles/FitResSummaryExamples.m` | `notebooks/FitResSummaryExamples.ipynb` | ✅ Ported | +| `helpfiles/ConfidenceIntervalOverview.m` | `notebooks/ConfidenceIntervalOverview.ipynb` | ✅ Ported | +| `helpfiles/DecodingExample.m` | `notebooks/DecodingExample.ipynb` | ✅ Ported | +| `helpfiles/DecodingExampleWithHist.m` | `notebooks/DecodingExampleWithHist.ipynb` | ✅ Ported | +| `helpfiles/HybridFilterExample.m` | `notebooks/HybridFilterExample.ipynb` | ✅ Ported | +| `helpfiles/PPSimExample.m` | `notebooks/PPSimExample.ipynb` | ✅ Ported | +| `helpfiles/PPThinning.m` | `notebooks/PPThinning.ipynb` | ✅ Ported | +| `helpfiles/PSTHEstimation.m` | `notebooks/PSTHEstimation.ipynb` | ✅ Ported | +| `helpfiles/mEPSCAnalysis.m` | `notebooks/mEPSCAnalysis.ipynb` | ✅ Ported | +| `helpfiles/ExplicitStimulusWhiskerData.m` | `notebooks/ExplicitStimulusWhiskerData.ipynb` | ✅ Ported | +| `helpfiles/HippocampalPlaceCellExample.m` | `notebooks/HippocampalPlaceCellExample.ipynb` | ✅ Ported | +| `helpfiles/StimulusDecode2D.m` | `notebooks/StimulusDecode2D.ipynb` | ✅ Ported | +| `helpfiles/NetworkTutorial.m` | `notebooks/NetworkTutorial.ipynb` | ✅ Ported | +| `helpfiles/ValidationDataSet.m` | `notebooks/ValidationDataSet.ipynb` | ✅ Ported | +| `helpfiles/nSTATPaperExamples.m` | `notebooks/nSTATPaperExamples.ipynb` | ✅ Ported | +| `helpfiles/Examples.m` | (index page, not standalone) | ✅ N/A | +| `helpfiles/ClassDefinitions.m` | (overview, covered by Sphinx docs) | ✅ N/A | +| `helpfiles/PaperOverview.m` | (overview, covered by Sphinx docs) | ✅ N/A | +| `helpfiles/NeuralSpikeAnalysis_top.m` | (toolbox landing page) | ✅ N/A | +| `helpfiles/DocumentationSetup2025b.m` | (Matlab-specific setup) | ✅ N/A | +| `helpfiles/publish_all_helpfiles.m` | `tools/notebooks/run_notebooks.py` | ✅ Verified | + +--- + +## Structural Architecture Decision + +The porting spec envisions one-class-per-file (e.g., `signal_obj.py`, `covariate.py`), but the +current Python architecture groups related classes into shared modules: + +- `core.py` → SignalObj, Covariate, nspikeTrain (tightly coupled base classes) +- `trial.py` → CovariateCollection, SpikeTrainCollection, Trial, TrialConfig, ConfigCollection +- `fit.py` → FitResult, FitSummary/FitResSummary + +**Why we keep the grouped-module approach:** +1. **180 tests pass** — splitting would require refactoring all internal imports +2. **Thin wrapper files** already provide Matlab-style imports (`from nstat.SignalObj import SignalObj`) +3. **Circular dependencies** — SignalObj ↔ Covariate ↔ nspikeTrain share helper code +4. **All CI checks green** — no regression risk from the current architecture +5. **Full method parity verified** — 484 Matlab methods → 489 Python methods + +The wrapper files ensure any user code written as `from nstat.SignalObj import SignalObj` works +identically to a hypothetical one-class-per-file layout. Functional parity > structural purity. + +--- + +## Nominal Gaps (Non-Functional) + +These Matlab methods have no Python counterpart because they depend on Matlab-specific infrastructure: + +| Matlab Method | Reason | Impact | +|---|---|---| +| `CIF.evalFunctionWithVectorArgs` | Requires Matlab symbolic toolbox | None — Python uses numeric CIF only | +| `CIF.resolveSimulinkModelName` | Requires Simulink | None — Python uses thinning simulation | +| `CIF.simulateCIF` (Simulink path) | Requires Simulink | None — thinning path fully ported | + +--- + +## Summary Statistics + +| Category | Matlab Count | Python Count | Status | +|---|---|---|---| +| Class methods (public) | ~484 | ~489 | ✅ Full parity | +| Class methods (internal/helpers) | ~22 | (merged into implementations) | ✅ Covered | +| Standalone functions | 20 | Covered via modules | ✅ | +| Library functions | 7 | Ported or N/A | ✅ | +| Paper examples | 5 | 5 | ✅ | +| Helpfile notebooks | 29 (executable) | 29 | ✅ | +| Unit tests passing | — | 180 | ✅ | diff --git a/examples/paper/example04_place_cells_continuous_stimulus.py b/examples/paper/example04_place_cells_continuous_stimulus.py index 96013b6d..05543f5d 100644 --- a/examples/paper/example04_place_cells_continuous_stimulus.py +++ b/examples/paper/example04_place_cells_continuous_stimulus.py @@ -174,10 +174,15 @@ def _load_animal_results(path, x, y, time, neurons): return fit_results -def _compute_place_field(coeffs, grid_design, grid_shape): - """Compute predicted firing rate on a spatial grid.""" +def _compute_place_field(coeffs, grid_design, grid_shape, sample_rate=1.0): + """Compute predicted firing rate on a spatial grid. + + Matches Matlab ``FitResult.evalLambda`` which computes + ``exp(X * b) * sampleRate`` to convert from conditional intensity + (per bin) to firing rate (Hz). + """ eta = grid_design @ coeffs - rate = np.exp(eta) + rate = np.exp(eta) * sample_rate return rate.reshape(grid_shape) @@ -219,22 +224,26 @@ def run_example04(*, export_figures: bool = False, export_dir: Path | None = Non summary1 = FitSummary(fitResults1) summary2 = FitSummary(fitResults2) - # Delta statistics: Gaussian (index 0) minus Zernike (index 1) - dAIC1 = summary1.AIC[:, 0] - summary1.AIC[:, 1] - dBIC1 = summary1.BIC[:, 0] - summary1.BIC[:, 1] + # Delta statistics + # dKS: direct subtraction Gaussian - Zernike (matches Matlab line 81-83) dKS1 = summary1.KSStats[:, 0] - summary1.KSStats[:, 1] - - dAIC2 = summary2.AIC[:, 0] - summary2.AIC[:, 1] - dBIC2 = summary2.BIC[:, 0] - summary2.BIC[:, 1] dKS2 = summary2.KSStats[:, 0] - summary2.KSStats[:, 1] + # dAIC/dBIC: Matlab uses getDiffAIC(1) / getDiffBIC(1) which computes + # Zernike - Gaussian (other columns minus reference column). + dAIC1 = summary1.AIC[:, 1] - summary1.AIC[:, 0] + dBIC1 = summary1.BIC[:, 1] - summary1.BIC[:, 0] + + dAIC2 = summary2.AIC[:, 1] - summary2.AIC[:, 0] + dBIC2 = summary2.BIC[:, 1] - summary2.BIC[:, 0] + dAIC_all = np.concatenate([dAIC1, dAIC2]) dBIC_all = np.concatenate([dBIC1, dBIC2]) dKS_all = np.concatenate([dKS1, dKS2]) - print(f" Mean dAIC (Gauss-Zern): {np.nanmean(dAIC_all):.2f}") - print(f" Mean dBIC (Gauss-Zern): {np.nanmean(dBIC_all):.2f}") print(f" Mean dKS (Gauss-Zern): {np.nanmean(dKS_all):.4f}") + print(f" Mean dAIC (Zern-Gauss): {np.nanmean(dAIC_all):.2f}") + print(f" Mean dBIC (Zern-Gauss): {np.nanmean(dBIC_all):.2f}") # ================================================================== # Figure 1: Example cells — spike locations over path (2x2) @@ -267,13 +276,13 @@ def run_example04(*, export_figures: bool = False, export_dir: Path | None = Non axes2[1].boxplot([dAIC1[np.isfinite(dAIC1)], dAIC2[np.isfinite(dAIC2)]], tick_labels=["Animal 1", "Animal 2"]) - axes2[1].set_ylabel(r"$\Delta$AIC (Gaussian - Zernike)") + axes2[1].set_ylabel(r"$\Delta$AIC (Zernike - Gaussian)") axes2[1].set_title("AIC Difference") axes2[1].axhline(0, color="gray", linestyle="--", linewidth=0.5) axes2[2].boxplot([dBIC1[np.isfinite(dBIC1)], dBIC2[np.isfinite(dBIC2)]], tick_labels=["Animal 1", "Animal 2"]) - axes2[2].set_ylabel(r"$\Delta$BIC (Gaussian - Zernike)") + axes2[2].set_ylabel(r"$\Delta$BIC (Zernike - Gaussian)") axes2[2].set_title("BIC Difference") axes2[2].axhline(0, color="gray", linestyle="--", linewidth=0.5) @@ -316,13 +325,14 @@ def _plot_heatmaps(fit_results, nCells, title_prefix, design_gauss, for i in range(nCells): row, col = divmod(i, nCols) fr = fit_results[i] + sr = float(fr.lambda_signal.sampleRate) coeffs_g = np.asarray(fr.b[0], dtype=float).ravel() coeffs_z = np.asarray(fr.b[1], dtype=float).ravel() if fr.numResults > 1 else coeffs_g # Gaussian field ax = axesG[row, col] try: - field_g = _compute_place_field(coeffs_g, design_gauss[:, :coeffs_g.size], grid_shape) + field_g = _compute_place_field(coeffs_g, design_gauss[:, :coeffs_g.size], grid_shape, sr) ax.pcolormesh(xx, yy, field_g, shading="gouraud", cmap="jet") except Exception: pass @@ -334,7 +344,7 @@ def _plot_heatmaps(fit_results, nCells, title_prefix, design_gauss, # Zernike field ax = axesZ[row, col] try: - field_z = _compute_place_field(coeffs_z, design_zern[:, :coeffs_z.size], grid_shape) + field_z = _compute_place_field(coeffs_z, design_zern[:, :coeffs_z.size], grid_shape, sr) ax.pcolormesh(xx, yy, field_z, shading="gouraud", cmap="jet") except Exception: pass @@ -372,13 +382,14 @@ def _plot_heatmaps(fit_results, nCells, title_prefix, design_gauss, # ================================================================== exampleCell = min(24, nCells1 - 1) # 0-indexed → cell 25 in Matlab fr_ex = fitResults1[exampleCell] + sr_ex = float(fr_ex.lambda_signal.sampleRate) coeffs_g = np.asarray(fr_ex.b[0], dtype=float).ravel() coeffs_z = np.asarray(fr_ex.b[1], dtype=float).ravel() field_g = _compute_place_field( - coeffs_g, gridDesignGauss[:, :coeffs_g.size], xx.shape) + coeffs_g, gridDesignGauss[:, :coeffs_g.size], xx.shape, sr_ex) field_z = _compute_place_field( - coeffs_z, gridDesignZern[:, :coeffs_z.size], xx.shape) + coeffs_z, gridDesignZern[:, :coeffs_z.size], xx.shape, sr_ex) fig7 = plt.figure(figsize=(12, 8)) ax3d = fig7.add_subplot(111, projection="3d") diff --git a/nstat/analysis.py b/nstat/analysis.py index 36269a30..be9a8f1d 100644 --- a/nstat/analysis.py +++ b/nstat/analysis.py @@ -13,6 +13,22 @@ def psth(spike_trains: Sequence[object], bin_edges: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """Compute peri-stimulus time histogram (PSTH) from multiple spike trains. + + Parameters + ---------- + spike_trains : sequence of nspikeTrain + Collection of spike train objects, each with a ``spikeTimes`` attribute. + bin_edges : array_like, shape (n_bins + 1,) + Edges of the time bins (seconds). + + Returns + ------- + mean_rate_hz : ndarray, shape (n_bins,) + Trial-averaged firing rate in Hz for each bin. + counts : ndarray, shape (n_bins,) + Raw spike counts summed across all trials per bin. + """ edges = np.asarray(bin_edges, dtype=float) if edges.ndim != 1 or edges.size < 2: raise ValueError("bin_edges must be 1D and length >= 2") @@ -125,7 +141,19 @@ def _benjamini_hochberg(p_values: np.ndarray, alpha: float) -> np.ndarray: class Analysis: - """Canonical analysis entry points preserving MATLAB-facing workflow semantics.""" + """Collection of static methods for GLM analysis of point-process data. + + Every public method is a ``@staticmethod``; the class acts as a pure + namespace that mirrors the Matlab ``@Analysis`` class. Two naming + conventions coexist: + + * **PEP 8** (snake_case): ``run_analysis_for_neuron``, ``run_analysis_for_all_neurons`` + * **Matlab-facing** (camelCase): ``RunAnalysisForNeuron``, ``RunAnalysisForAllNeurons`` + + See Also + -------- + Trial, ConfigCollection, SpikeTrainCollection, History + """ colors = ["b", "g", "r", "c", "m", "y", "k"] @@ -145,6 +173,7 @@ def _collapse_spike_input(nspikeObj): @staticmethod def psth(spike_trains: Sequence[object], bin_edges: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """Compute peri-stimulus time histogram. Delegates to module-level :func:`psth`.""" return psth(spike_trains, bin_edges) @staticmethod @@ -157,6 +186,50 @@ def GLMFit( l2: float = 1e-6, max_iter: int = 120, ): + """Fit a point-process GLM for a single neuron from a Trial. + + Extracts the design matrix *X* from the current covariate masks, + history, and ensemble history in the Trial, and the observation + vector *Y*, then performs the GLM regression. + + Parameters + ---------- + tObj : Trial + Trial containing spike trains and covariates. + neuronNumber : int or str or sequence + Matlab-style 1-based neuron index, name, or sequence thereof. + lambdaIndex : int + Configuration index used for labelling the returned λ. + Algorithm : {'GLM', 'BNLRCG'}, default ``'GLM'`` + ``'GLM'`` — standard Poisson GLM regression. + ``'BNLRCG'`` — truncated, L-2 regularised binomial logistic + regression (requires binary spike representation). + l2 : float, default 1e-6 + L-2 regularisation strength. + max_iter : int, default 120 + Maximum IRLS / CG iterations. + + Returns + ------- + lambda_sig : Covariate + Conditional intensity function evaluated on the design-matrix + time grid. + b : ndarray + GLM regression coefficients. + dev : float + Deviance of the fit. + stats : dict + Fit statistics (standard errors, convergence info, covariance + matrix). + AIC : float + Akaike information criterion. + BIC : float + Bayesian information criterion. + logLL : float + Log-likelihood evaluated with the fit parameters. + distribution : str + ``'poisson'`` or ``'binomial'``. + """ algorithm = str(Algorithm or "GLM").upper() if algorithm not in {"GLM", "BNLRCG"}: raise ValueError("Algorithm not supported!") @@ -264,6 +337,33 @@ def run_analysis_for_neuron( l2: float = 1e-6, max_iter: int = 120, ) -> FitResult: + """Run GLM analysis for one neuron across all configurations. + + Iterates over the configurations in *config_collection*, fits a GLM + for each, computes KS diagnostics, and returns a single + :class:`FitResult` that aggregates all fits. + + Parameters + ---------- + trial : Trial + Trial object containing spike trains and covariates. + neuron_index : int + Zero-based neuron index. + config_collection : ConfigCollection + Configurations describing the fits to perform (covariates, + history, ensemble history). + algorithm : {'GLM', 'BNLRCG'}, default ``'GLM'`` + Regression algorithm. + l2 : float, default 1e-6 + L-2 regularisation strength. + max_iter : int, default 120 + Maximum iterations for the GLM solver. + + Returns + ------- + FitResult + Fit result with KS statistics already populated. + """ if neuron_index < 0: raise IndexError("neuron_index must be >= 0") @@ -393,6 +493,29 @@ def run_analysis_for_all_neurons( l2: float = 1e-6, max_iter: int = 120, ) -> list[FitResult]: + """Run GLM analysis for every unmasked neuron in the trial. + + Calls :meth:`run_analysis_for_neuron` for each neuron in the + trial's spike-train collection. + + Parameters + ---------- + trial : Trial + Trial to analyse. + config_collection : ConfigCollection + Configurations describing the fits to perform. + algorithm : {'GLM', 'BNLRCG'}, default ``'GLM'`` + Regression algorithm. + l2 : float, default 1e-6 + L-2 regularisation strength. + max_iter : int, default 120 + Maximum iterations for the GLM solver. + + Returns + ------- + list of FitResult + One :class:`FitResult` per neuron. + """ out: list[FitResult] = [] for i in range(trial.spike_collection.num_spike_trains): out.append( @@ -409,6 +532,31 @@ def run_analysis_for_all_neurons( @staticmethod def RunAnalysisForNeuron(tObj: Trial, neuronNumber, configColl: ConfigCollection, makePlot=1, Algorithm="GLM", DTCorrection=1, batchMode=0): + """Matlab-facing wrapper for :meth:`run_analysis_for_neuron`. + + Parameters + ---------- + tObj : Trial + Trial to analyse. + neuronNumber : int or str or sequence + Matlab-style 1-based neuron index, name, or vector of indices. + If more than one neuron is specified the return value is a list. + configColl : ConfigCollection + Configurations describing the fits. + makePlot : int, default 1 + If ``1``, plot a summary for the neuron. + Algorithm : {'GLM', 'BNLRCG'}, default ``'GLM'`` + Regression algorithm. + DTCorrection : int, default 1 + Discrete-time KS correction flag (kept for API parity; unused). + batchMode : int, default 0 + Batch-mode flag (kept for API parity; unused). + + Returns + ------- + FitResult or list of FitResult + Single result when one neuron is specified, list otherwise. + """ del DTCorrection, batchMode indices = _as_neuron_indices(tObj, neuronNumber) fits = [Analysis.run_analysis_for_neuron(tObj, idx - 1, configColl, algorithm=Algorithm) for idx in indices] @@ -418,6 +566,31 @@ def RunAnalysisForNeuron(tObj: Trial, neuronNumber, configColl: ConfigCollection @staticmethod def RunAnalysisForAllNeurons(tObj: Trial, configs: ConfigCollection, makePlot=1, Algorithm="GLM", DTCorrection=1, batchMode=0): + """Matlab-facing wrapper for :meth:`run_analysis_for_all_neurons`. + + Runs the fits specified by *configs* on every unmasked neuron in + the trial. + + Parameters + ---------- + tObj : Trial + Trial to analyse. + configs : ConfigCollection + Configurations describing the fits. + makePlot : int, default 1 + If ``1``, generate a summary plot for each neuron. + Algorithm : {'GLM', 'BNLRCG'}, default ``'GLM'`` + Regression algorithm. + DTCorrection : int, default 1 + Discrete-time KS correction flag (unused). + batchMode : int, default 0 + Batch-mode flag (unused). + + Returns + ------- + FitResult or list of FitResult + Single result when the trial has one neuron, list otherwise. + """ del DTCorrection, batchMode fits = Analysis.run_analysis_for_all_neurons(tObj, configs, algorithm=Algorithm) if makePlot and len(fits) == 1: @@ -426,11 +599,63 @@ def RunAnalysisForAllNeurons(tObj: Trial, configs: ConfigCollection, makePlot=1, @staticmethod def computeKSStats(nspikeObj, lambdaInput: Covariate, DTCorrection: int = 1, *, random_values=None): + """Compute KS goodness-of-fit statistics via the time-rescaling theorem. + + Given a neural spike train and a candidate conditional intensity + function, computes the rescaled ISIs and the KS plot data. + + Parameters + ---------- + nspikeObj : nspikeTrain or SpikeTrainCollection or sequence + Neural spike train(s). + lambdaInput : Covariate + Candidate conditional intensity function. + DTCorrection : int, default 1 + If ``1``, apply discrete-time correction to KS plot. + random_values : array_like, optional + Pre-drawn uniform random values for reproducibility. + + Returns + ------- + Z : ndarray + Rescaled spike times. + U : ndarray + Z transformed to uniform(0, 1). + xAxis : ndarray + x-axis of the KS plot. + KSSorted : ndarray + Sorted rescaled times (y-axis of KS plot). + ks_stat : ndarray + KS statistic — maximum deviation from the 45° line for each + conditional intensity function. + """ nspikeObj = Analysis._collapse_spike_input(nspikeObj) return _matlab_compute_ks_arrays(nspikeObj, lambdaInput, dt_correction=DTCorrection, random_values=random_values) @staticmethod def computeInvGausTrans(Z): + """Compute the inverse-Gaussian transform of rescaled spike times. + + Transforms rescaled spike times *Z* to uniform(0, 1) via + ``U = 1 − exp(−Z)``, then applies the inverse-Gaussian (probit) + transform ``X = Φ⁻¹(U)``. The autocorrelation of *X* is used + to test for independence of the rescaled ISIs (a condition for + the time-rescaling theorem). + + Parameters + ---------- + Z : array_like + Rescaled spike times (exponential rate-1 under H₀). + + Returns + ------- + X : ndarray + Inverse-Gaussian transformed values. + rhoSig : SignalObj + Autocorrelation function of *X*. + confBoundSig : SignalObj + ±1.96 / √N confidence bounds for zero autocorrelation. + """ z = np.asarray(Z, dtype=float) if z.ndim == 1: z = z[:, None] @@ -459,6 +684,31 @@ def computeInvGausTrans(Z): @staticmethod def computeFitResidual(nspikeObj, lambdaInput: Covariate, windowSize: float = 0.01): + """Compute the point-process residual. + + Defined as the difference between the observed spike count and + the integral of the candidate conditional intensity function + in each time window, following Truccolo *et al.* (2005). + + Parameters + ---------- + nspikeObj : nspikeTrain or SpikeTrainCollection or sequence + Neural spike train(s). + lambdaInput : Covariate + Candidate conditional intensity function. + windowSize : float, default 0.01 + Size of the integration window (seconds). + + Returns + ------- + Covariate + Point-process residual M(t_k). + + References + ---------- + Truccolo, W., Eden, U. T., Fellows, M. R., Donoghue, J. P., & + Brown, E. N. (2005). *J Neurophysiol*, 93(2), 1074–1089. + """ nspikeObj = Analysis._collapse_spike_input(nspikeObj) nCopy = nspikeObj.nstCopy() @@ -492,30 +742,136 @@ def computeFitResidual(nspikeObj, lambdaInput: Covariate, windowSize: float = 0. @staticmethod def KSPlot(fitResults: FitResult, DTCorrection: int = 1, makePlot: int = 1): + """Compute KS statistics and optionally generate the KS plot. + + Parameters + ---------- + fitResults : FitResult + Fit result to compute KS statistics for. + DTCorrection : int, default 1 + Discrete-time correction flag. + makePlot : int, default 1 + If ``1``, generate the KS plot. + + Returns + ------- + list + Plot handles (empty list when *makePlot* is ``0``). + """ fitResults.computeKSStats(dt_correction=DTCorrection) return fitResults.KSPlot() if makePlot else [] @staticmethod def plotFitResidual(fitResults: FitResult, windowSize: float = 0.01, makePlot: int = 1): + """Compute and plot the point-process residual. + + Parameters + ---------- + fitResults : FitResult + Fit result to compute the residual for. + windowSize : float, default 0.01 + Integration window size (seconds). + makePlot : int, default 1 + If ``1``, generate the residual plot. + + Returns + ------- + list + Plot handles (empty list when *makePlot* is ``0``). + """ fitResults.computeFitResidual(windowSize=windowSize) return fitResults.plotResidual() if makePlot else [] @staticmethod def plotInvGausTrans(fitResults: FitResult, makePlot: int = 0): + """Compute and optionally plot the inverse-Gaussian transform ACF. + + Parameters + ---------- + fitResults : FitResult + Fit result to compute the transform for. + makePlot : int, default 0 + If ``1``, generate the ACF plot. + + Returns + ------- + list + Plot handles (empty list when *makePlot* is ``0``). + """ fitResults.computeInvGausTrans() return fitResults.plotInvGausTrans() if makePlot else [] @staticmethod def plotSeqCorr(fitResults: FitResult): + """Plot the sequential correlation of rescaled ISIs (z_j vs z_{j-1}). + + Parameters + ---------- + fitResults : FitResult + Fit result (inverse-Gaussian transform is computed if needed). + + Returns + ------- + list + Plot handles. + """ fitResults.computeInvGausTrans() return fitResults.plotSeqCorr() @staticmethod def plotCoeffs(fitResults: FitResult): + """Plot regression coefficients for all fits in *fitResults*. + + Parameters + ---------- + fitResults : FitResult + Fit result whose coefficients to plot. + + Returns + ------- + list + Plot handles. + """ return fitResults.plotCoeffs() @staticmethod def computeHistLag(tObj: Trial, neuronNum=None, windowTimes=None, CovLabels=None, Algorithm="GLM", batchMode=0, sampleRate=None, makePlot=1, histMinTimes=None, histMaxTimes=None): + """Sweep self-history window orders for a single neuron. + + Fits a sequence of GLMs with increasing numbers of history + windows (no extrinsic covariates, no ensemble history) and + returns the fit results for model selection via AIC / BIC / KS. + + Parameters + ---------- + tObj : Trial + Trial to analyse. + neuronNum : int or None + Matlab-style 1-based neuron index. If ``None``, uses the + first unmasked neuron. + windowTimes : array_like + Vector of window boundary times. ``len(windowTimes) - 1`` + configurations are created with increasing history order. + CovLabels : list of str or None + Covariate labels to include in each configuration. + Algorithm : {'GLM', 'BNLRCG'}, default ``'GLM'`` + Regression algorithm. + batchMode : int + Unused (Matlab API parity). + sampleRate : float or None + Sample rate override; defaults to ``tObj.sampleRate``. + makePlot : int, default 1 + If ``1``, generate a summary plot. + histMinTimes, histMaxTimes : float or None + Optional time bounds passed to the ``History`` object. + + Returns + ------- + fitResults : FitResult + Fit result containing all history-order configurations. + tcc : ConfigCollection + The generated configuration collection. + """ del batchMode if windowTimes is None: raise ValueError("Must specify a vector of windowTimes") @@ -548,6 +904,35 @@ def computeHistLag(tObj: Trial, neuronNum=None, windowTimes=None, CovLabels=None @staticmethod def computeHistLagForAll(tObj: Trial, windowTimes, CovLabels=None, Algorithm="GLM", batchMode=0, sampleRate=None, makePlot=1, histMinTimes=None, histMaxTimes=None): + """Sweep self-history window orders for all unmasked neurons. + + Calls :meth:`computeHistLag` for each unmasked neuron in the + trial. + + Parameters + ---------- + tObj : Trial + Trial to analyse. + windowTimes : array_like + Vector of window boundary times. + CovLabels : list of str or None + Covariate labels for each configuration. + Algorithm : {'GLM', 'BNLRCG'}, default ``'GLM'`` + Regression algorithm. + batchMode : int + Unused (Matlab API parity). + sampleRate : float or None + Sample rate override. + makePlot : int, default 1 + Summary plot flag. + histMinTimes, histMaxTimes : float or None + Optional time bounds for the ``History`` object. + + Returns + ------- + list of FitResult + One fit result per unmasked neuron. + """ results = [] for neuron_idx in tObj.getNeuronIndFromMask(): fit, _ = Analysis.computeHistLag( @@ -567,6 +952,39 @@ def computeHistLagForAll(tObj: Trial, windowTimes, CovLabels=None, Algorithm="GL @staticmethod def compHistEnsCoeff(tObj: Trial, history, neuronNum=None, neighbors=None, ensembleCov=None, makePlot=1): + """Compute ensemble-history coefficients for one neuron. + + Builds a covariate collection from the spiking history of + neighbouring neurons and fits a GLM with ensemble history as the + design matrix. + + Parameters + ---------- + tObj : Trial + Trial containing spike trains and covariates. + history : History + History object defining the window structure. + neuronNum : int or None + Matlab-style 1-based neuron index. Defaults to the first + unmasked neuron. + neighbors : array_like or None + Indices of neighbouring neurons. Defaults to + ``tObj.getNeuronNeighbors(neuronNum)``. + ensembleCov : CovariateCollection or None + Pre-computed ensemble covariates. If ``None``, computed + automatically. + makePlot : int, default 1 + Summary plot flag. + + Returns + ------- + fitResults : FitResult + Fit result for the ensemble-history model. + ensembleCov : CovariateCollection + Ensemble covariates used in the fit. + tcc : ConfigCollection + Configuration collection used. + """ from .trial import TrialConfig neuron_index = _as_neuron_indices(tObj, neuronNum if neuronNum is not None else tObj.getNeuronIndFromMask()[0])[0] @@ -583,6 +1001,29 @@ def compHistEnsCoeff(tObj: Trial, history, neuronNum=None, neighbors=None, ensem @staticmethod def compHistEnsCoeffForAll(tObj: Trial, history, makePlot=1): + """Compute ensemble-history coefficients for all unmasked neurons. + + Calls :meth:`compHistEnsCoeff` for each neuron that is not + masked. + + Parameters + ---------- + tObj : Trial + Trial to analyse. + history : History + History object defining the window structure. + makePlot : int, default 1 + Summary plot flag. + + Returns + ------- + fit_results : list of FitResult + One fit result per neuron. + ensemble_cov : CovariateCollection or None + Ensemble covariates from the last neuron. + config_collections : list of ConfigCollection + Configuration collections used per neuron. + """ neuron_indices = tObj.getNeuronIndFromMask() if not neuron_indices: return [], None, [] @@ -606,6 +1047,38 @@ def compHistEnsCoeffForAll(tObj: Trial, history, makePlot=1): @staticmethod def computeGrangerCausalityMatrix(tObj: Trial, Algorithm="GLM", confidenceInterval=0.95, batchMode=0): + """Compute the Granger-causality matrix for the neural ensemble. + + For every pair of neurons, fits a baseline model (full ensemble + history) and a reduced model (one neighbour excluded), then + computes the log-likelihood ratio. Statistical significance is + corrected for multiple comparisons with Benjamini–Hochberg FDR. + + Parameters + ---------- + tObj : Trial + Trial with ensemble history configured. + Algorithm : {'GLM', 'BNLRCG'}, default ``'GLM'`` + Regression algorithm. + confidenceInterval : float, default 0.95 + Confidence level for the significance test. + batchMode : int + Unused (Matlab API parity). + + Returns + ------- + fitResults : list of list of FitResult + ``fitResults[i][j]`` is the fit result for the test of + neighbour *j* → neuron *i*. + gammaMat : ndarray, shape (N, N) + Log-likelihood ratio Γ matrix. + phiMat : ndarray, shape (N, N) + Signed Γ matrix (sign from sum of excluded coefficients). + devianceMat : ndarray, shape (N, N) + Deviance (−2Γ) matrix. + sigMat : ndarray, shape (N, N) + Binary significance matrix after FDR correction. + """ del batchMode neuron_indices = tObj.getNeuronIndFromMask() n_neurons = tObj.nspikeColl.numSpikeTrains @@ -681,6 +1154,32 @@ def computeGrangerCausalityMatrix(tObj: Trial, Algorithm="GLM", confidenceInterv @staticmethod def computeNeighbors(tObj: Trial, neuronNum=None, sampleRate=None, windowTimes=None, makePlot=1): + """Sweep ensemble-history orders for one neuron (no self-history). + + Fits a sequence of GLMs with increasing ensemble-history window + orders but no self-history and no extrinsic covariates, for model + selection of the ensemble effect. + + Parameters + ---------- + tObj : Trial + Trial to analyse. + neuronNum : int or None + Matlab-style 1-based neuron index. + sampleRate : float or None + Sample rate override. + windowTimes : array_like + Vector of window boundary times. + makePlot : int, default 1 + Summary plot flag. + + Returns + ------- + fitResults : FitResult + Fit result with all configurations. + tcc : ConfigCollection + Generated configuration collection. + """ if windowTimes is None: raise ValueError("Must specify a vector of windowTimes") neuron_index = _as_neuron_indices(tObj, neuronNum if neuronNum is not None else tObj.getNeuronIndFromMask()[0])[0] @@ -707,6 +1206,28 @@ def computeNeighbors(tObj: Trial, neuronNum=None, sampleRate=None, windowTimes=N @staticmethod def spikeTrigAvg(tObj: Trial, neuronNum, windowSize): + """Compute the spike-triggered average of all covariates. + + Each covariate dimension is sampled at every spike time of the + specified neuron ± ``windowSize / 2``. The returned collection + contains one covariate per original dimension, where each column + corresponds to a single spike. Use ``plotVariability`` on the + returned signals to visualise the average and spread. + + Parameters + ---------- + tObj : Trial + Trial containing spike trains and covariates. + neuronNum : int + Matlab-style 1-based neuron index. + windowSize : float + Total window length (seconds) centred on each spike. + + Returns + ------- + CovariateCollection + Collection of spike-triggered covariate samples. + """ from .trial import CovariateCollection train = tObj.getNeuron(neuronNum).nstCopy() diff --git a/nstat/cif.py b/nstat/cif.py index 16331ae7..9edbf13f 100644 --- a/nstat/cif.py +++ b/nstat/cif.py @@ -326,7 +326,36 @@ def from_linear_terms( class CIF: - """MATLAB-facing CIF object plus native Python simulation helpers.""" + """Conditional Intensity Function for point-process modelling. + + Encapsulates the regression coefficients, variable names, link function, + and optional spike-history terms that define a conditional intensity + function (CIF). Supports symbolic differentiation (gradient / Jacobian) + for use in point-process adaptive filters and decoders. + + Parameters + ---------- + beta : array_like or None + Regression coefficients. + Xnames : sequence of str or None + Names of the variables in the order they appear in *beta*. + stimNames : sequence of str or None + Names of the subset of variables that define the stimulus. + fitType : {'poisson', 'binomial'}, default ``'poisson'`` + Link function. For Poisson: ``λΔ = exp(Xβ)``. + For binomial: ``λΔ = exp(Xβ) / (1 + exp(Xβ))``. + histCoeffs : array_like or None + Coefficients for each history window defined in *historyObj*. + historyObj : History or array_like or None + History object (or window-times vector) defining the spiking- + history basis. + nst : nspikeTrain or None + Spike train used to pre-compute history values. + + See Also + -------- + History, DecodingAlgorithms + """ def __init__( self, @@ -378,6 +407,7 @@ def __init__( self.setSpikeTrain(nst) def CIFCopy(self): + """Return a deep copy of this CIF object.""" copied = CIF( beta=np.asarray(self.b, dtype=float).copy(), Xnames=list(self.varIn), @@ -394,6 +424,7 @@ def CIFCopy(self): return copied def setSpikeTrain(self, spikeTrain) -> None: + """Attach a spike train and pre-compute the history matrix.""" if not isinstance(spikeTrain, nspikeTrain): spikeTrain = getattr(spikeTrain, "nstCopy", lambda: spikeTrain)() self.spikeTrain = spikeTrain.nstCopy() @@ -403,6 +434,14 @@ def setSpikeTrain(self, spikeTrain) -> None: self.historyMat = np.zeros((0, 0), dtype=float) def setHistory(self, histObj) -> None: + """Set the History object for this CIF. + + Parameters + ---------- + histObj : History or array_like + A ``History`` object, or a vector of window-times from which + one will be created. + """ if isinstance(histObj, History): self.history = History(histObj.windowTimes, histObj.minTime, histObj.maxTime, histObj.name) elif isinstance(histObj, (np.ndarray, Sequence)) and not isinstance(histObj, (str, bytes)): @@ -541,24 +580,93 @@ def _jacobian(self, stimVal, time_index: int | None = None, nst: nspikeTrain | N return np.zeros_like(outer) if log else lambda_delta * outer def evalLambdaDelta(self, stimVal, time_index: int | None = None, nst: nspikeTrain | None = None): + """Evaluate λΔ at the given stimulus values. + + Parameters + ---------- + stimVal : array_like + Stimulus variable values. + time_index : int or None + 1-based time index into the pre-computed history matrix. + nst : nspikeTrain or None + Spike train for on-the-fly history computation. + + Returns + ------- + float + Scalar value of λΔ. + """ return self._lambda_delta(stimVal, time_index=time_index, nst=nst) def evalGradient(self, stimVal, time_index: int | None = None, nst: nspikeTrain | None = None): + """Gradient of λΔ with respect to the stimulus variables. + + Returns + ------- + ndarray, shape (1, n_stim) + Row vector of partial derivatives. + """ return self._gradient(stimVal, time_index=time_index, nst=nst) def evalGradientLog(self, stimVal, time_index: int | None = None, nst: nspikeTrain | None = None): + """Gradient of log(λΔ) with respect to the stimulus variables. + + Returns + ------- + ndarray, shape (1, n_stim) + Row vector of partial derivatives. + """ return self._gradient(stimVal, time_index=time_index, nst=nst, log=True) def evalJacobian(self, stimVal, time_index: int | None = None, nst: nspikeTrain | None = None): + """Hessian of λΔ with respect to the stimulus variables. + + Returns + ------- + ndarray, shape (n_stim, n_stim) + Second-derivative matrix. + """ return self._jacobian(stimVal, time_index=time_index, nst=nst) def evalJacobianLog(self, stimVal, time_index: int | None = None, nst: nspikeTrain | None = None): + """Hessian of log(λΔ) with respect to the stimulus variables. + + Returns + ------- + ndarray, shape (n_stim, n_stim) + Second-derivative matrix. + """ return self._jacobian(stimVal, time_index=time_index, nst=nst, log=True) def evalLDGamma(self, stimVal, time_index: int | None = None, nst: nspikeTrain | None = None, gamma=None): + """Evaluate λΔ with gamma-scaled history coefficients. + + Parameters + ---------- + stimVal : array_like + Stimulus variable values. + time_index : int or None + 1-based time index into the history matrix. + nst : nspikeTrain or None + Spike train for on-the-fly history computation. + gamma : array_like or None + Scaling factors applied to the history coefficients. + + Returns + ------- + float + Scalar value of λΔ. + """ return self._lambda_delta(stimVal, time_index=time_index, nst=nst, gamma=gamma) def evalLogLDGamma(self, stimVal, time_index: int | None = None, nst: nspikeTrain | None = None, gamma=None): + """Evaluate log(λΔ) with gamma-scaled history coefficients. + + Returns + ------- + float + Scalar value of log(λΔ). + """ if self._expression_surface is not None and self._expression_surface["log_lambda_gamma_fn"] is not None: return float( np.asarray( @@ -571,6 +679,7 @@ def evalLogLDGamma(self, stimVal, time_index: int | None = None, nst: nspikeTrai return float(np.log(np.clip(self.evalLDGamma(stimVal, time_index=time_index, nst=nst, gamma=gamma), 1e-12, None))) def evalGradientLDGamma(self, stimVal, time_index: int | None = None, nst: nspikeTrain | None = None, gamma=None): + """Gradient of λΔ w.r.t. gamma (history-coefficient scaling).""" if self._expression_surface is not None and self._expression_surface["gradient_gamma_fn"] is not None: return _reshape_row( self._expression_surface["gradient_gamma_fn"]( @@ -581,6 +690,7 @@ def evalGradientLDGamma(self, stimVal, time_index: int | None = None, nst: nspik return self._gradient(stimVal, time_index=time_index, nst=nst, gamma=gamma) def evalGradientLogLDGamma(self, stimVal, time_index: int | None = None, nst: nspikeTrain | None = None, gamma=None): + """Gradient of log(λΔ) w.r.t. gamma (history-coefficient scaling).""" if self._expression_surface is not None and self._expression_surface["gradient_log_gamma_fn"] is not None: return _reshape_row( self._expression_surface["gradient_log_gamma_fn"]( @@ -591,6 +701,7 @@ def evalGradientLogLDGamma(self, stimVal, time_index: int | None = None, nst: ns return self._gradient(stimVal, time_index=time_index, nst=nst, gamma=gamma, log=True) def evalJacobianLDGamma(self, stimVal, time_index: int | None = None, nst: nspikeTrain | None = None, gamma=None): + """Hessian of λΔ w.r.t. gamma (history-coefficient scaling).""" if self._expression_surface is not None and self._expression_surface["jacobian_gamma_fn"] is not None: return _reshape_square( self._expression_surface["jacobian_gamma_fn"]( @@ -601,6 +712,7 @@ def evalJacobianLDGamma(self, stimVal, time_index: int | None = None, nst: nspik return self._jacobian(stimVal, time_index=time_index, nst=nst, gamma=gamma) def evalJacobianLogLDGamma(self, stimVal, time_index: int | None = None, nst: nspikeTrain | None = None, gamma=None): + """Hessian of log(λΔ) w.r.t. gamma (history-coefficient scaling).""" if self._expression_surface is not None and self._expression_surface["jacobian_log_gamma_fn"] is not None: return _reshape_square( self._expression_surface["jacobian_log_gamma_fn"]( @@ -611,12 +723,29 @@ def evalJacobianLogLDGamma(self, stimVal, time_index: int | None = None, nst: ns return self._jacobian(stimVal, time_index=time_index, nst=nst, gamma=gamma, log=True) def isSymBeta(self) -> bool: + """Return ``True`` if the coefficients contain symbolic expressions.""" beta = np.asarray(self.b) if beta.dtype == object: return True return any(type(item).__module__.startswith("sympy") for item in beta.reshape(-1)) def evaluate(self, design_matrix: np.ndarray, *, delta: float = 1.0, history_matrix: np.ndarray | None = None) -> np.ndarray: + """Evaluate the CIF on a full design matrix (vectorised). + + Parameters + ---------- + design_matrix : ndarray, shape (T, n_vars) + Design matrix with one row per time step. + delta : float, default 1.0 + Bin width (seconds). The returned rate is λΔ / Δ. + history_matrix : ndarray or None + Pre-computed history matrix, shape (T, n_hist). + + Returns + ------- + ndarray, shape (T,) + Firing rate in Hz. + """ x = np.asarray(design_matrix, dtype=float) if x.ndim == 1: x = x[:, None] @@ -650,6 +779,7 @@ def evaluate(self, design_matrix: np.ndarray, *, delta: float = 1.0, history_mat return lambda_delta / max(float(delta), 1e-12) def to_covariate(self, time: Sequence[float], design_matrix: np.ndarray, *, delta: float = 1.0, name: str = "lambda") -> Covariate: + """Evaluate the CIF and return the result as a :class:`Covariate`.""" rate = self.evaluate(design_matrix, delta=delta) return Covariate(time, rate, name, "time", "s", "spikes/sec", [name]) @@ -664,6 +794,35 @@ def simulateCIFByThinningFromLambda( thinning_values: np.ndarray | None = None, return_details: bool = False, ) -> SpikeTrainCollection: + """Simulate spike trains from a given λ(t) via thinning. + + Uses the thinning (rejection) algorithm: propose spikes from a + homogeneous Poisson process at the bound rate, then accept each + candidate with probability λ(t) / λ_max. + + Parameters + ---------- + lambda_covariate : Covariate + Conditional intensity function time series (Hz). + numRealizations : int, default 1 + Number of independent spike-train realisations. + maxTimeRes : float or None + Minimum inter-spike interval (seconds). Spikes closer than + this are merged. + seed : int or None + Random seed for reproducibility. + random_values : ndarray or None + Pre-drawn uniform random values for the proposal process. + thinning_values : ndarray or None + Pre-drawn uniform random values for thinning acceptance. + return_details : bool, default False + If ``True``, return ``(collection, details_dict)`` instead. + + Returns + ------- + SpikeTrainCollection + Collection of simulated spike trains. + """ if int(numRealizations) < 1: raise ValueError("numRealizations must be >= 1") @@ -789,6 +948,11 @@ def simulateCIFByThinning( seed: int | None = None, return_lambda: bool = False, ): + """Simulate a point process via the thinning algorithm. + + Alias for :meth:`simulateCIF`. See that method for full + parameter documentation. + """ return CIF.simulateCIF( mu, hist, @@ -818,6 +982,45 @@ def simulateCIF( random_values: np.ndarray | None = None, return_details: bool = False, ): + """Simulate a point process from component kernels and inputs. + + Constructs λΔ from the input terms + ``μ + stim ∗ inputStimSignal + hist ∗ spikeHistory + ens ∗ inputEnsSignal`` + and generates spike trains via Bernoulli draws at each time step. + + Parameters + ---------- + mu : float + Baseline (mean) log-rate of the process. + hist : transfer-function-like or array_like + History kernel convolved with the process's own spiking. + stim : transfer-function-like or array_like + Stimulus kernel convolved with *inputStimSignal*. + ens : transfer-function-like or array_like + Ensemble kernel convolved with *inputEnsSignal*. + inputStimSignal : Covariate + Stimulus time series. + inputEnsSignal : Covariate + Ensemble activity time series. + numRealizations : int, default 1 + Number of independent realisations. + simType : {'binomial', 'poisson'}, default ``'binomial'`` + Link function for computing λΔ. + seed : int or None + Random seed. + return_lambda : bool, default False + If ``True``, return ``(collection, lambda_array)``. + random_values : ndarray or None + Pre-drawn uniform random values for reproducibility. + return_details : bool, default False + If ``True``, return ``(collection, details_dict)``. + + Returns + ------- + SpikeTrainCollection + Simulated spike trains (or tuple if *return_lambda* / + *return_details* is ``True``). + """ if int(numRealizations) < 1: raise ValueError("numRealizations must be >= 1") time = np.asarray(inputStimSignal.time, dtype=float).reshape(-1) diff --git a/nstat/confidence_interval.py b/nstat/confidence_interval.py index a7bb8f39..3d43aef7 100644 --- a/nstat/confidence_interval.py +++ b/nstat/confidence_interval.py @@ -18,6 +18,27 @@ class ConfidenceInterval: + """Confidence interval for a time series or Covariate. + + Stores a pair of (lower, upper) bound traces over a shared time + axis, with plotting support for both line and shaded-patch styles. + + Parameters + ---------- + time : array_like + Time vector. + bounds : array_like, shape (n, 2) + Lower and upper bounds at each time point. + *args + Positional metadata: ``(name, xlabelval, xunits, yunits, + dataLabels, plotProps)``. If a single short string is passed + it is interpreted as the colour (Matlab compatibility). + color : str or None + Line / patch colour. Default ``'b'``. + value : float, default 0.95 + Confidence level (e.g. 0.95 for 95 %). + """ + def __init__(self, time, bounds, *args, color: str | None = None, value: float = 0.95) -> None: t = np.asarray(time, dtype=float).reshape(-1) b = np.asarray(bounds, dtype=float) @@ -89,15 +110,19 @@ def upper(self) -> np.ndarray: return self.bounds[:, 1] def setColor(self, color: str) -> None: + """Set the plot colour.""" self.color = str(color) def setValue(self, value: float) -> None: + """Set the confidence level (e.g. 0.95 for 95 %).""" self.value = float(value) def dataToMatrix(self) -> np.ndarray: + """Return the bounds as an (n, 2) numpy array.""" return np.asarray(self.bounds, dtype=float) def dataToStructure(self) -> dict: + """Serialise to a plain dictionary (matches Matlab ``dataToStructure``).""" return { "time": self.time.tolist(), "signals": {"values": self.bounds.tolist(), "dimensions": self.dimension}, @@ -160,10 +185,12 @@ def __neg__(self): return ConfidenceInterval(self.time, np.column_stack([-self.upper, -self.lower]), self.color) def toStructure(self) -> dict: + """Alias for :meth:`dataToStructure`.""" return self.dataToStructure() @staticmethod def fromStructure(structure: dict) -> "ConfidenceInterval": + """Reconstruct a ConfidenceInterval from a dictionary.""" signals = structure.get("signals", {}) values = signals.get("values", structure.get("data")) ci = ConfidenceInterval( @@ -187,6 +214,25 @@ def fromStructure(structure: dict) -> "ConfidenceInterval": return ci def plot(self, color: str | None = None, alphaVal: float = 0.2, drawPatches: int = 0, ax=None): + """Plot the confidence interval. + + Parameters + ---------- + color : str or None + Override colour (default: ``self.color``). + alphaVal : float, default 0.2 + Transparency for shaded patches. + drawPatches : int, default 0 + If ``1``, draw a shaded ``fill_between`` region instead of + lines. + ax : Axes or None + Matplotlib axes. If ``None``, uses ``plt.gca()``. + + Returns + ------- + PolyCollection or list of Line2D + Plot handles. + """ import matplotlib.pyplot as plt axis = plt.gca() if ax is None else ax diff --git a/nstat/core.py b/nstat/core.py index dc9bb895..6d1eb882 100644 --- a/nstat/core.py +++ b/nstat/core.py @@ -71,7 +71,39 @@ def _nearest_sample_matrix(target_time: np.ndarray, source_time: np.ndarray, sou class SignalObj: - """Closer MATLAB-style signal abstraction used throughout the Python port.""" + """Multi-dimensional time-series signal object (Matlab ``SignalObj``). + + ``SignalObj`` is the foundational data container in nSTAT. It stores + one or more signal channels sampled on a common time axis, along with + metadata (name, units, labels) and supports element-wise arithmetic, + resampling, filtering, correlation analysis, and spectral estimation. + + Parameters + ---------- + time : array_like + Monotonically increasing time vector of length *n*. + data : array_like + Signal values. Shape ``(n,)`` for a scalar signal or ``(n, d)`` + for a *d*-dimensional signal. + name : str, optional + Human-readable signal name (used as y-axis label in plots). + xlabelval : str, optional + X-axis label string (default ``'time'``). + xunits : str, optional + X-axis unit string (default ``'s'``). + yunits : str, optional + Y-axis unit string. + dataLabels : sequence of str or str, optional + Per-dimension labels. A single string is broadcast to all + dimensions. + plotProps : sequence or str, optional + Per-dimension Matplotlib format strings. + + See Also + -------- + Covariate : SignalObj subclass with confidence-interval support. + nspikeTrain : Point-process (spike train) companion class. + """ def __init__( self, @@ -120,20 +152,24 @@ def __init__( @property def dimension(self) -> int: + """Number of signal channels (columns in the data matrix).""" return int(self.data.shape[1]) @property def values(self) -> np.ndarray: + """Signal data as a 1-D array (scalar) or 2-D matrix.""" if self.dimension == 1: return self.data[:, 0] return self.data @property def units(self) -> str: + """Y-axis unit string (alias for ``yunits``).""" return self.yunits @property def sample_rate(self) -> float: + """Sampling rate in Hz (alias for ``sampleRate``).""" return float(self.sampleRate) def _spawn( @@ -158,6 +194,7 @@ def _spawn( ) def copySignal(self) -> "SignalObj": + """Return a deep copy of this signal (Matlab ``copySignal``).""" copied = self._spawn(self.time, self.data) if self.conf_interval is not None: copied.conf_interval = ( @@ -215,36 +252,48 @@ def _binary_op(self, other, op) -> "SignalObj": return self._spawn(self.time, result, data_labels=labels) def setName(self, name: str) -> None: + """Set the signal name (y-axis label).""" if not isinstance(name, str): raise TypeError("Name must be a string!") self.name = name def setXlabel(self, name: str) -> None: + """Set the x-axis label string.""" self.xlabelval = str(name) def setYLabel(self, name: str) -> None: + """Set the y-axis label (alias for ``setName``).""" self.setName(name) def setUnits(self, xUnits: str, yUnits: str | None = None) -> None: + """Set x-axis and optionally y-axis units.""" if yUnits is not None: self.setYUnits(yUnits) self.setXUnits(xUnits) def setXUnits(self, units: str) -> None: + """Set the x-axis unit string.""" if isinstance(units, str): self.xunits = units def setYUnits(self, units: str) -> None: + """Set the y-axis unit string.""" if isinstance(units, str): self.yunits = units def setSampleRate(self, sampleRate: float) -> None: + """Set the sample rate, resampling the data if it differs from current.""" requested = float(sampleRate) current = float(self.sampleRate) if abs(round(requested, 3) - round(current, 3)) > 0: self.resampleMe(requested) def setDataLabels(self, dataLabels: Sequence[str] | str | None) -> None: + """Set per-dimension data labels. + + A single string is broadcast to all dimensions. A sequence must + have length equal to ``dimension``. + """ if dataLabels is None or (isinstance(dataLabels, str) and dataLabels == ""): self.dataLabels = ["" for _ in range(self.dimension)] return @@ -259,6 +308,10 @@ def setDataLabels(self, dataLabels: Sequence[str] | str | None) -> None: self.dataLabels = labels def setPlotProps(self, plotProps: Sequence[Any] | str | None, index: int | None = None) -> None: + """Set per-dimension Matplotlib format strings. + + When *index* (1-based) is given, only that dimension is updated. + """ if index is None: if plotProps is None: self.plotProps = [None for _ in range(self.dimension)] @@ -287,6 +340,7 @@ def setPlotProps(self, plotProps: Sequence[Any] | str | None, index: int | None self.plotProps[target] = plotProps def setDataMask(self, dataMask: Sequence[int] | np.ndarray) -> None: + """Set binary data mask (1 = visible, 0 = hidden) for each dimension.""" mask = np.asarray(dataMask, dtype=int).reshape(-1) if mask.size != self.dimension: raise ValueError("dataMask must match the number of signal dimensions.") @@ -295,12 +349,14 @@ def setDataMask(self, dataMask: Sequence[int] | np.ndarray) -> None: self.dataMask = mask def setMaskByInd(self, index: Sequence[int] | np.ndarray) -> None: + """Enable only the dimensions at the given 1-based indices.""" selected = _coerce_1based_indices(index, self.dimension) mask = np.zeros(self.dimension, dtype=int) mask[np.asarray(selected, dtype=int) - 1] = 1 self.setDataMask(mask) def setMaskByLabels(self, labels: Sequence[str] | str) -> None: + """Enable only the dimensions whose data labels match *labels*.""" indices = self.getIndicesFromLabels(labels) if isinstance(indices, list) and indices and isinstance(indices[0], list): flat = [item for sub in indices for item in sub] @@ -311,6 +367,12 @@ def setMaskByLabels(self, labels: Sequence[str] | str) -> None: self.setMaskByInd(flat) def setMask(self, mask: Sequence[int] | Sequence[str] | np.ndarray | None = None) -> None: + """Flexible mask setter accepting indices, labels, or a binary vector. + + ``None`` clears the mask (all hidden). A binary vector of length + ``dimension`` is used directly. A list of labels or 1-based indices + enables only those dimensions. + """ if mask is None: self.setDataMask(np.zeros(self.dimension, dtype=int)) return @@ -336,28 +398,35 @@ def setMask(self, mask: Sequence[int] | Sequence[str] | np.ndarray | None = None self.setMaskByInd(arr.astype(int)) def getTime(self) -> np.ndarray: + """Return a copy of the time vector.""" return self.time.copy() def getData(self) -> np.ndarray: + """Return signal data as a matrix (alias for ``dataToMatrix()``).""" return self.dataToMatrix() def getOriginalData(self) -> tuple[np.ndarray, np.ndarray]: + """Return ``(originalTime, originalData)`` copies.""" return self.originalTime.copy(), self.originalData.copy() def getOrigDataSig(self) -> "SignalObj": + """Return the original (pre-resample) data as a new ``SignalObj``.""" return self._spawn(self.originalTime, self.originalData) def getPlotProps(self, index: int) -> Any: + """Return the plot property for dimension *index* (1-based).""" idx = _coerce_1based_indices([index], self.dimension)[0] - 1 return self.plotProps[idx] def getIndexFromLabel(self, label: str) -> list[int]: + """Return 1-based indices of dimensions whose label equals *label*.""" matches = [i + 1 for i, value in enumerate(self.dataLabels) if value == label] if not matches: raise ValueError("Label does not exist!") return matches def getIndicesFromLabels(self, label: Sequence[str] | str): + """Return 1-based index(es) for one or more data-label strings.""" if isinstance(label, str): matches = self.getIndexFromLabel(label) return matches[0] if len(matches) == 1 else matches @@ -415,6 +484,7 @@ def convertNamesToIndices(self, selectorArray) -> list[int] | np.ndarray: return list(range(1, self.dimension + 1)) def getValueAt(self, x: Sequence[float] | float) -> np.ndarray: + """Return signal value(s) at time(s) *x* via nearest-neighbour lookup.""" query = np.asarray(x, dtype=float).reshape(-1) out = np.zeros((query.size, self.dimension), dtype=float) valid = (query >= self.minTime) & (query <= self.maxTime) @@ -447,6 +517,11 @@ def _selector_to_zero_based(self, selectorArray: Sequence[int] | np.ndarray | No return indices - 1 def dataToMatrix(self, selectorArray: Sequence[int] | np.ndarray | None = None) -> np.ndarray: + """Return signal data as an ``(n, d)`` matrix. + + *selectorArray* is an optional sequence of 1-based dimension + indices. When ``None``, the data mask selects visible dimensions. + """ indices = self._selector_to_zero_based(selectorArray) if indices.size == 0: return np.zeros((self.time.size, 0), dtype=float) @@ -461,6 +536,7 @@ def _plot_props_for_indices(self, zero_based: np.ndarray) -> list[Any]: return [self.plotProps[int(i)] for i in zero_based] def getSubSignalFromInd(self, selectorArray: Sequence[int] | np.ndarray) -> "SignalObj": + """Return a new ``SignalObj`` with only the selected dimensions (1-based).""" indices = self._selector_to_zero_based(selectorArray) return self._spawn( self.time, @@ -470,10 +546,12 @@ def getSubSignalFromInd(self, selectorArray: Sequence[int] | np.ndarray) -> "Sig ) def getSubSignalFromNames(self, labels: Sequence[str] | str) -> "SignalObj": + """Return a sub-signal selected by data-label name(s).""" indices = self.getIndicesFromLabels(labels) return self.getSubSignalFromInd(indices if isinstance(indices, list) else [indices]) def getSubSignal(self, identifier) -> "SignalObj": + """Return a sub-signal selected by labels, indices, or mixed.""" if isinstance(identifier, str): return self.getSubSignalFromNames(identifier) if isinstance(identifier, np.ndarray): @@ -487,6 +565,7 @@ def getSubSignal(self, identifier) -> "SignalObj": return self.getSubSignalFromInd(values) def findNearestTimeIndex(self, time: float) -> int: + """Return the 1-based index of the sample nearest to *time*.""" value = float(time) if value < self.minTime: return 1 @@ -503,9 +582,15 @@ def findNearestTimeIndex(self, time: float) -> int: return left + 1 def findNearestTimeIndices(self, times: Sequence[float] | np.ndarray) -> np.ndarray: + """Return 1-based indices of the samples nearest to each time in *times*.""" return np.asarray([self.findNearestTimeIndex(value) for value in np.asarray(times, dtype=float).reshape(-1)], dtype=int) def setMinTime(self, minTime: float | None = None, holdVals: int = 0) -> None: + """Extend or trim the signal to start at *minTime*. + + If *holdVals* is 1, endpoint values are held when extending; + otherwise the signal is zero-padded. + """ target = self.time[0] if minTime is None else float(minTime) timeVec = self.getTime() if target < float(np.min(timeVec)): @@ -526,6 +611,11 @@ def setMinTime(self, minTime: float | None = None, holdVals: int = 0) -> None: self.minTime = float(np.min(self.time)) def setMaxTime(self, maxTime: float | None = None, holdVals: int = 0) -> None: + """Extend or trim the signal to end at *maxTime*. + + If *holdVals* is 1, endpoint values are held when extending; + otherwise the signal is zero-padded. + """ target = self.time[-1] if maxTime is None else float(maxTime) timeVec = self.getTime() if float(np.max(timeVec)) < target: @@ -665,6 +755,12 @@ def getSigInTimeWindow( wMax: Sequence[float] | float | None = None, holdVals: int = 0, ) -> "SignalObj": + """Extract signal within ``[wMin, wMax]``. + + Multiple windows can be specified by passing equal-length sequences + for *wMin* and *wMax*; the extracted segments are concatenated as + additional dimensions (Matlab ``getSigInTimeWindow``). + """ if wMax is None: wMax = self.maxTime if wMin is None: @@ -700,6 +796,10 @@ def getSigInTimeWindow( return windowed if windowed is not None else self.copySignal() def restoreToOriginal(self, rMask: int = 0) -> None: + """Restore time, data, and sample rate to their original values. + + If *rMask* is 1, the data mask is also reset (all visible). + """ self.time = self.originalTime.copy() self.data = self.originalData.copy() self.minTime = float(np.min(self.time)) @@ -709,15 +809,19 @@ def restoreToOriginal(self, rMask: int = 0) -> None: self.resetMask() def resetMask(self) -> None: + """Reset the data mask so all dimensions are visible.""" self.dataMask = np.ones(self.dimension, dtype=int) def findIndFromDataMask(self) -> list[int]: + """Return 1-based indices of dimensions currently visible (mask == 1).""" return [int(index) + 1 for index in np.flatnonzero(self.dataMask == 1)] def isMaskSet(self) -> bool: + """Return ``True`` if any dimension is currently masked out.""" return bool(np.any(self.dataMask == 0)) def abs(self) -> "SignalObj": + """Element-wise absolute value (Matlab ``abs``).""" labels = [f"|{label}|" if label else "" for label in self.dataLabels] return self._spawn(self.time, np.abs(self.data), data_labels=labels).with_metadata( name=f"|{self.name}|", @@ -728,6 +832,7 @@ def __abs__(self) -> "SignalObj": return self.abs() def log(self) -> "SignalObj": + """Element-wise natural logarithm (Matlab ``log``).""" labels = [f"ln({label})" if label else "" for label in self.dataLabels] yunits = f"ln({self.yunits})" if self.yunits else "" return self._spawn(self.time, np.log(self.data), data_labels=labels).with_metadata( @@ -736,6 +841,7 @@ def log(self) -> "SignalObj": ) def with_metadata(self, *, name: str | None = None, xlabelval: str | None = None, xunits: str | None = None, yunits: str | None = None) -> "SignalObj": + """Return a copy with selectively overridden metadata fields.""" out = self.copySignal() if name is not None: out.name = str(name) @@ -748,6 +854,12 @@ def with_metadata(self, *, name: str | None = None, xlabelval: str | None = None return out def median(self, axis: int | None = None) -> "SignalObj": + """Column-wise median (default) or row-wise median of signal data. + + ``median()`` or ``median(0)`` computes the median of each + component across time. ``median(1)`` computes the median value at + each time point across dimensions. + """ axis_arg = 0 if axis is None else axis median_data = np.median(self.data, axis=axis_arg) array = np.asarray(median_data, dtype=float) @@ -762,6 +874,7 @@ def median(self, axis: int | None = None) -> "SignalObj": return self._spawn(self.time, reshaped, data_labels=[f"median({self.name})"]).with_metadata(name=f"median({self.name})") def mode(self, axis: int | None = None) -> "SignalObj": + """Column-wise mode of signal data (Matlab ``mode``).""" axis_arg = 0 if axis is None else axis if axis_arg == 0: mode_data = np.asarray([_matlab_mode_1d(self.data[:, i]) for i in range(self.dimension)], dtype=float) @@ -781,6 +894,12 @@ def mode(self, axis: int | None = None) -> "SignalObj": return self._spawn(self.time, reshaped, data_labels=[f"mode({self.name})"]).with_metadata(name=f"mode({self.name})") def mean(self, axis: int | None = None) -> "SignalObj": + """Column-wise mean (default) or row-wise mean of signal data. + + ``mean()`` or ``mean(0)`` computes the mean of each component + across time. ``mean(1)`` computes the mean value at each time + point across dimensions. + """ axis_arg = 0 if axis is None else axis mean_data = np.mean(self.data, axis=axis_arg) array = np.asarray(mean_data, dtype=float) @@ -795,6 +914,11 @@ def mean(self, axis: int | None = None) -> "SignalObj": return self._spawn(self.time, reshaped, data_labels=[f"\\mu({self.name})"]) def std(self, axis: int | None = None) -> "SignalObj": + """Column-wise standard deviation (sample, ddof=1) of signal data. + + ``std()`` or ``std(0)`` computes std of each component across + time. ``std(1)`` computes std at each time point across dimensions. + """ axis_arg = 0 if axis is None else axis std_data = np.std(self.data, axis=axis_arg, ddof=1) array = np.asarray(std_data, dtype=float) @@ -809,6 +933,7 @@ def std(self, axis: int | None = None) -> "SignalObj": return self._spawn(self.time, reshaped, data_labels=[f"\\sigma({self.name})"]) def max(self, axis: int | None = None): + """Return ``(values, indices, times)`` of column-wise maxima.""" axis_arg = 0 if axis is None else axis values = np.max(self.data, axis=axis_arg) indices = np.argmax(self.data, axis=axis_arg) @@ -816,6 +941,7 @@ def max(self, axis: int | None = None): return values, indices, time def min(self, axis: int | None = None): + """Return ``(values, indices, times)`` of column-wise minima.""" axis_arg = 0 if axis is None else axis values = np.min(self.data, axis=axis_arg) indices = np.argmin(self.data, axis=axis_arg) @@ -823,11 +949,13 @@ def min(self, axis: int | None = None): return values, indices, time def resample(self, sample_rate: float) -> "SignalObj": + """Return a resampled copy at *sample_rate* Hz.""" copied = self.copySignal() copied.resampleMe(sample_rate) return copied def resampleMe(self, newSampleRate: float) -> None: + """Resample data in-place to *newSampleRate* Hz via cubic interpolation.""" try: from scipy.interpolate import interp1d except Exception as exc: # pragma: no cover @@ -877,11 +1005,18 @@ def derivative(self) -> "SignalObj": return self._spawn(self.time, deriv, data_labels=labels) def derivativeAt(self, x0: Sequence[float] | float): + """Return the derivative value(s) at time(s) *x0*.""" deriv = self.derivative values = deriv.getValueAt(x0) return values def integral(self, t0: float | None = None, tf: float | None = None) -> "SignalObj": + """Cumulative integral of the signal from *t0* to *tf*. + + Computed via a causal IIR accumulator: + ``y[n] = y[n-1] + x[n] * deltaT``. If *t0* / *tf* are not + specified, ``minTime`` / ``maxTime`` are used. + """ start = self.minTime if t0 is None else float(t0) stop = self.maxTime if tf is None else float(tf) integrated = self.getSigInTimeWindow(start, stop) @@ -905,6 +1040,10 @@ def integral(self, t0: float | None = None, tf: float | None = None) -> "SignalO return integrated def filter(self, B, A=1) -> "SignalObj": + """Apply a causal IIR/FIR filter ``(B, A)`` to each dimension. + + Equivalent to ``scipy.signal.lfilter(B, A, data)``. + """ try: from scipy.signal import lfilter except Exception as exc: # pragma: no cover @@ -916,6 +1055,10 @@ def filter(self, B, A=1) -> "SignalObj": return self._spawn(self.time, filtered, data_labels=list(self.dataLabels)) def filtfilt(self, B, A=1) -> "SignalObj": + """Apply a zero-phase IIR/FIR filter ``(B, A)`` to each dimension. + + Equivalent to ``scipy.signal.filtfilt(B, A, data)``. + """ try: from scipy.signal import filtfilt except Exception as exc: # pragma: no cover @@ -954,6 +1097,12 @@ def makeCompatible(self, other: "SignalObj", holdVals: int = 0) -> tuple["Signal return s1c, s2c def autocorrelation(self) -> "SignalObj": + """Normalized auto-correlation for each signal dimension. + + Returns a new ``SignalObj`` whose time axis is lag (in the original + x-units) and whose data are the correlation coefficients normalised + to unity at lag zero (Matlab ``autocorrelation``). + """ centered = self.data - np.mean(self.data, axis=0, keepdims=True) columns: list[np.ndarray] = [] lags: np.ndarray | None = None @@ -981,6 +1130,12 @@ def autocorrelation(self) -> "SignalObj": ) def crosscorrelation(self, other: "SignalObj") -> "SignalObj": + """Normalized cross-correlation between two scalar signals. + + Both signals must be one-dimensional. The result is normalised + so that the peak equals the Pearson correlation coefficient + (Matlab ``crosscorrelation``). + """ if self.dimension != 1 or other.dimension != 1: raise ValueError("crosscorrelation only supports one-dimensional signals") s1c, s2c = self.makeCompatible(other) @@ -1005,6 +1160,13 @@ def crosscorrelation(self, other: "SignalObj") -> "SignalObj": ) def xcorr(self, other: "SignalObj" | None = None, maxlag: int | None = None) -> "SignalObj": + """Raw (un-normalised) cross-correlation (Matlab ``xcorr``). + + Computes pairwise cross-correlation for all dimension pairs. + When *other* is ``None`` (auto-correlation), only non-negative + lags are returned. *maxlag* truncates to ``|lag| ≤ maxlag`` + samples. + """ s2 = self if other is None else other s1c, s2c = self.makeCompatible(s2) data_columns: list[np.ndarray] = [] @@ -1551,6 +1713,7 @@ def spectrogram(self, nperseg: int = 256, noverlap: int | None = None, return f, t, Sxx def setConfInterval(self, bounds: tuple[np.ndarray, np.ndarray]) -> None: + """Attach ``(lower, upper)`` confidence bounds aligned with time.""" low, high = bounds low_arr = np.asarray(low, dtype=float) high_arr = np.asarray(high, dtype=float) @@ -1559,6 +1722,7 @@ def setConfInterval(self, bounds: tuple[np.ndarray, np.ndarray]) -> None: self.conf_interval = (low_arr, high_arr) def dataToStructure(self, selectorArray: Sequence[int] | np.ndarray | None = None) -> dict[str, Any]: + """Serialize signal data to a plain dict (Matlab ``dataToStructure``).""" data = self.dataToMatrix(selectorArray) plot_props = list(self.plotProps) if all(prop is None for prop in plot_props): @@ -1575,10 +1739,12 @@ def dataToStructure(self, selectorArray: Sequence[int] | np.ndarray | None = Non } def toStructure(self) -> dict[str, Any]: + """Serialize the full signal to a plain dict (Matlab ``toStructure``).""" return self.dataToStructure() @staticmethod def signalFromStruct(structure: dict[str, Any]) -> "SignalObj": + """Reconstruct a ``SignalObj`` from a dict (Matlab ``signalFromStruct``).""" return SignalObj( structure["time"], structure["data"], @@ -1591,6 +1757,22 @@ def signalFromStruct(structure: dict[str, Any]) -> "SignalObj": ) def plot(self, selectorArray=None, plotPropsIn=None, handle=None): + """Plot selected signal dimensions (Matlab ``plot``). + + Parameters + ---------- + selectorArray : optional + Dimension selector (labels, 1-based indices, or ``None`` for all + visible dimensions). + plotPropsIn : optional + Override Matplotlib format strings for each dimension. + handle : matplotlib Axes, optional + Axes to draw into; defaults to ``plt.gca()``. + + Returns + ------- + list of Line2D + """ import matplotlib.pyplot as plt from .confidence_interval import MATLAB_COLOR_ORDER @@ -1626,7 +1808,26 @@ def plot(self, selectorArray=None, plotPropsIn=None, handle=None): class Covariate(SignalObj): - """MATLAB-style covariate signal with CI and zero-mean views.""" + """Signal with per-dimension confidence intervals (Matlab ``Covariate``). + + ``Covariate`` extends :class:`SignalObj` with a list of + :class:`~nstat.confidence_interval.ConfidenceInterval` objects (one per + dimension) and propagates those intervals through ``+`` and ``-`` + arithmetic. It also provides ``'zero-mean'`` and ``'standard'`` + signal representations used by the GLM design-matrix builder. + + Parameters + ---------- + *args, **kwargs + Forwarded to :class:`SignalObj`. The keyword aliases ``values`` + (→ ``data``) and ``units`` (→ ``yunits``) are accepted for + convenience. + + See Also + -------- + SignalObj : Base time-series container. + ConfidenceInterval : CI storage class used by ``ci``. + """ def __init__(self, *args, **kwargs) -> None: if "values" in kwargs and "data" not in kwargs: @@ -1638,13 +1839,21 @@ def __init__(self, *args, **kwargs) -> None: @property def mu(self) -> SignalObj: + """Column-wise mean as a ``SignalObj`` (Matlab ``mu`` property).""" return self.mean() @property def sigma(self) -> SignalObj: + """Column-wise standard deviation as a ``SignalObj`` (Matlab ``sigma``).""" return self.std() def computeMeanPlusCI(self, alphaVal: float = 0.05) -> "Covariate": + """Compute row-wise mean with empirical confidence intervals. + + Treats each column as a replicate. Returns a scalar ``Covariate`` + whose CI bounds are the *alphaVal*/2 and 1−*alphaVal*/2 quantiles + of the empirical CDF across replicates (Matlab ``computeMeanPlusCI``). + """ from .confidence_interval import ConfidenceInterval sorted_data = np.sort(self.data, axis=1) @@ -1707,6 +1916,7 @@ def getSigRep(self, repType: str = "standard") -> "Covariate": raise ValueError("repType must be either 'zero-mean' or 'standard'") def plot(self, selectorArray=None, plotPropsIn=None, handle=None): + """Plot signal dimensions with shaded confidence intervals.""" lines = super().plot(selectorArray, plotPropsIn, handle) if self.isConfIntervalSet(): import matplotlib.pyplot as plt @@ -1728,15 +1938,18 @@ def plot(self, selectorArray=None, plotPropsIn=None, handle=None): return lines def isConfIntervalSet(self) -> bool: + """Return ``True`` if at least one dimension has a CI attached.""" return bool(self.ci) def setConfInterval(self, ciObj) -> None: + """Attach one or more ``ConfidenceInterval`` objects to this covariate.""" if isinstance(ciObj, list): self.ci = list(ciObj) else: self.ci = [ciObj] def copySignal(self) -> "Covariate": + """Deep-copy including confidence intervals (Matlab ``copySignal``).""" copied = Covariate( self.time.copy(), self.data.copy(), @@ -1763,6 +1976,7 @@ def copySignal(self) -> "Covariate": return copied def getSubSignal(self, identifier) -> "Covariate": + """Return a sub-covariate preserving matching CIs.""" sub = super().getSubSignal(identifier) cov = Covariate( sub.time, @@ -1788,6 +2002,7 @@ def getSubSignal(self, identifier) -> "Covariate": return cov def __add__(self, other): + """Add two covariates, propagating confidence intervals.""" covOut = super().__add__(other) if isinstance(other, Covariate): if self.isConfIntervalSet() and not other.isConfIntervalSet(): @@ -1799,6 +2014,7 @@ def __add__(self, other): return covOut def __sub__(self, other): + """Subtract two covariates, propagating confidence intervals.""" covOut = super().__sub__(other) if isinstance(other, Covariate): if self.isConfIntervalSet() and not other.isConfIntervalSet(): @@ -1810,6 +2026,7 @@ def __sub__(self, other): return covOut def toStructure(self) -> dict[str, Any]: + """Serialize to a dict, including CI payload if present.""" structure = super().toStructure() if self.isConfIntervalSet(): ci_payload: list[dict[str, Any]] = [] @@ -1822,6 +2039,7 @@ def toStructure(self) -> dict[str, Any]: @staticmethod def fromStructure(structure: dict[str, Any]) -> "Covariate": + """Reconstruct a ``Covariate`` (with optional CIs) from a dict.""" from .confidence_interval import ConfidenceInterval cov = Covariate( @@ -1847,7 +2065,36 @@ def fromStructure(structure: dict[str, Any]) -> "Covariate": class nspikeTrain: - """Closer MATLAB-style spike-train object with cached signal representation.""" + """Point-process (spike train) object (Matlab ``nspikeTrain``). + + Stores an array of event times (spikes) and converts them on demand + into a binned ``SignalObj`` signal representation (``sigRep``). Burst + statistics, ISI analysis, and raster plotting are built in. + + Parameters + ---------- + spikeTimes : array_like + Spike times in seconds. + name : str, optional + Neuron / channel label. + binwidth : float, optional + Bin width in seconds for the signal representation (default 1 ms). + minTime, maxTime : float, optional + Observation window. Defaults to ``min/max(spikeTimes)``. + xlabelval, xunits, yunits : str, optional + Axis label and unit strings. + dataLabels : str or sequence of str, optional + Label(s) for the spike-train dimension. + makePlots : int, optional + ``0`` — compute statistics silently (default); + ``1`` — compute and plot; + ``< 0`` — skip statistics entirely (fast construction). + + See Also + -------- + SignalObj : Continuous time-series container returned by ``getSigRep``. + SpikeTrainCollection : Multi-neuron collection. + """ def __init__( self, @@ -1905,30 +2152,37 @@ def __init__( @property def times(self) -> np.ndarray: + """Alias for ``spikeTimes``.""" return self.spikeTimes @property def n_spikes(self) -> int: + """Number of spikes in the train.""" return int(self.spikeTimes.size) @property def duration(self) -> float: + """Observation window duration ``maxTime − minTime`` in seconds.""" return float(self.maxTime - self.minTime) @property def firing_rate_hz(self) -> float: + """Average firing rate (spikes / duration) in Hz.""" if self.duration <= 0: return 0.0 return float(self.n_spikes / self.duration) def setMER(self, MERSig: SignalObj) -> None: + """Attach a micro-electrode recording signal to this spike train.""" if isinstance(MERSig, SignalObj): self.MER = MERSig def setName(self, name: str) -> None: + """Set the neuron / channel name.""" self.name = str(name) def computeStatistics(self, makePlots: int = 0) -> None: + """Compute ISI, burst, and regularity statistics (Matlab ``computeStatistics``).""" self.avgFiringRate = self.firing_rate_hz isi = self.getISIs() # Filter spike times to [minTime, maxTime] so burst statistics @@ -1998,6 +2252,7 @@ def computeStatistics(self, makePlots: int = 0) -> None: self.plot() def getLStatistic(self) -> float: + """Return the L-statistic (number of unique bin counts in ``sigRep``).""" isi = self.getISIs() if isi.size == 0: return np.nan @@ -2065,6 +2320,7 @@ def _build_sigrep(self, binwidth: float, minTime: float, maxTime: float) -> Sign return sig def setSigRep(self, binwidth: float | None = None, minTime: float | None = None, maxTime: float | None = None) -> SignalObj: + """Build the binned signal representation and store it in-place.""" sig = self.getSigRep(binwidth, minTime, maxTime) self.sigRep = sig.copySignal() self.sampleRate = float(sig.sampleRate) @@ -2077,38 +2333,45 @@ def setSigRep(self, binwidth: float | None = None, minTime: float | None = None, return self.sigRep def clearSigRep(self) -> None: + """Invalidate the cached signal representation.""" self.sigRep = None self._sigrep_cache_key = None self.isSigRepBin = None def setMinTime(self, minTime: float) -> None: + """Set the observation-window start and recompute statistics.""" self.minTime = float(minTime) self.clearSigRep() self.computeStatistics(0) def setMaxTime(self, maxTime: float) -> None: + """Set the observation-window end and recompute statistics.""" self.maxTime = float(maxTime) self.clearSigRep() self.computeStatistics(0) def resample(self, sampleRate: float) -> "nspikeTrain": + """Rebuild the signal representation at *sampleRate* Hz.""" self.setSigRep(1.0 / float(sampleRate), self.minTime, self.maxTime) self.sampleRate = float(sampleRate) return self def getSpikeTimes(self, minTime: float | None = None, maxTime: float | None = None) -> np.ndarray: + """Return spike times within ``[minTime, maxTime]``.""" start = self.minTime if minTime is None else float(minTime) stop = self.maxTime if maxTime is None else float(maxTime) spikes = self.spikeTimes[(self.spikeTimes >= start) & (self.spikeTimes <= stop)] return spikes.copy() def getISIs(self, minTime: float | None = None, maxTime: float | None = None) -> np.ndarray: + """Return inter-spike intervals within the given time window.""" spikes = self.getSpikeTimes(minTime, maxTime) if spikes.size < 2: return np.array([], dtype=float) return np.diff(spikes) def getMinISI(self, minTime: float | None = None, maxTime: float | None = None) -> float: + """Return the minimum ISI (refractory period estimate).""" isi = self.getISIs(minTime, maxTime) if isi.size == 0: return float("nan") @@ -2120,6 +2383,11 @@ def getSigRep( minTime: float | None = None, maxTime: float | None = None, ) -> SignalObj: + """Return the binned signal representation, using cache when possible. + + The result is a ``SignalObj`` of spike counts on a regular grid + with bin width *binwidth* (default ``1/sampleRate``). + """ bw = (1.0 / self.sampleRate) if binwidth is None else float(binwidth) start = self.minTime if minTime is None else float(minTime) stop = self.maxTime if maxTime is None else float(maxTime) @@ -2132,18 +2400,21 @@ def getSigRep( return sig def getMaxBinSizeBinary(self) -> float: + """Return the largest bin width that keeps the ``sigRep`` binary.""" isi = self.getISIs() if isi.size == 0: return np.inf return float(np.min(isi)) def isSigRepBinary(self) -> bool: + """Return ``True`` if every bin in the default ``sigRep`` has ≤ 1 spike.""" default_key = self._cache_key(1.0 / float(self.sampleRate), float(self.minTime), float(self.maxTime)) if self._sigrep_cache_key != default_key or self.isSigRepBin is None: self.getSigRep(1.0 / float(self.sampleRate), float(self.minTime), float(self.maxTime)) return bool(self.isSigRepBin) def computeRate(self) -> SignalObj: + """Return firing rate ``sigRep × sampleRate`` in spikes/sec.""" sig = self.getSigRep() if self.sampleRate <= 0: return sig @@ -2151,6 +2422,7 @@ def computeRate(self) -> SignalObj: return SignalObj(sig.time, rate, self.name, sig.xlabelval, sig.xunits, "spikes/sec", sig.dataLabels) def restoreToOriginal(self) -> None: + """Reset spike times and time bounds to original values.""" self.spikeTimes = self.originalSpikeTimes.copy() self.minTime = float(np.min(self.spikeTimes)) if self.spikeTimes.size else 0.0 self.maxTime = float(np.max(self.spikeTimes)) if self.spikeTimes.size else 0.0 @@ -2163,6 +2435,21 @@ def partitionNST( lbound: float | None = None, ubound: float | None = None, ): + """Partition into per-trial spike trains (Matlab ``partitionNST``). + + Parameters + ---------- + windowTimes : sequence of float + Edge times defining trial boundaries (N edges → N−1 trials). + normalizeTime : bool, optional + If ``True``, rescale each trial's spikes to [0, 1]. + lbound, ubound : float, optional + Accept only windows whose duration falls in ``[lbound, ubound]``. + + Returns + ------- + nstColl + """ from .nstColl import nstColl windows = np.asarray(windowTimes, dtype=float).reshape(-1) @@ -2195,9 +2482,11 @@ def partitionNST( return coll def getFieldVal(self, fieldName: str): + """Return the value of attribute *fieldName* (Matlab ``getFieldVal``).""" return getattr(self, fieldName, []) def plotISISpectrumFunction(self): + """Plot ISI vs. time (Matlab ``plotISISpectrumFunction``).""" import matplotlib.pyplot as plt fig, ax = plt.subplots(1, 1, figsize=(6.0, 3.5)) @@ -2211,6 +2500,7 @@ def plotISISpectrumFunction(self): return line def plotJointISIHistogram(self): + """Joint ISI scatter plot: ISI(t) vs ISI(t+1) on log-log axes.""" import matplotlib.pyplot as plt ax = plt.subplots(1, 1, figsize=(4.5, 4.0))[1] @@ -2293,6 +2583,7 @@ def plotISIHistogram(self, minTime: float | None = None, maxTime: float | None = return counts def plotProbPlot(self, minTime: float | None = None, maxTime: float | None = None, handle=None): + """Exponential probability plot of ISIs (Matlab ``plotProbPlot``).""" import matplotlib.pyplot as plt ax = plt.gca() if handle is None else handle @@ -2311,6 +2602,7 @@ def plotProbPlot(self, minTime: float | None = None, maxTime: float | None = Non return ax def plotExponentialFit(self, minTime: float | None = None, maxTime: float | None = None, numBins: int | None = None, handle=None): + """ISI histogram + exponential prob-plot side by side.""" import matplotlib.pyplot as plt fig = handle if handle is not None else plt.figure(figsize=(10.0, 4.0)) @@ -2322,6 +2614,17 @@ def plotExponentialFit(self, minTime: float | None = None, maxTime: float | None return fig def plot(self, dHeight: float = 1.0, yOffset: float = 0.5, currentHandle=None, handle=None): + """Raster plot: vertical tick per spike (Matlab ``plot``). + + Parameters + ---------- + dHeight : float + Tick height (default 1.0). + yOffset : float + Vertical centre of ticks (default 0.5). + currentHandle, handle : matplotlib Axes, optional + Axes to draw into. + """ import matplotlib.pyplot as plt ax = plt.gca() if (currentHandle is None and handle is None) else (currentHandle or handle) @@ -2362,11 +2665,13 @@ def nstCopy(self) -> "nspikeTrain": ) def to_binned_counts(self, bin_edges: Sequence[float]) -> np.ndarray: + """Histogram spike times into *bin_edges* and return count vector.""" edges = np.asarray(bin_edges, dtype=float).reshape(-1) counts, _ = np.histogram(self.spikeTimes, bins=edges) return counts.astype(float) def toStructure(self) -> dict[str, Any]: + """Serialize to a plain dict (Matlab ``toStructure``).""" return { "spikeTimes": self.spikeTimes.tolist(), "name": self.name, @@ -2381,6 +2686,7 @@ def toStructure(self) -> dict[str, Any]: @staticmethod def fromStructure(structure: dict[str, Any]) -> "nspikeTrain": + """Reconstruct an ``nspikeTrain`` from a dict.""" sampleRate = float(structure.get("sampleRate", 1000.0)) binwidth = 1.0 / sampleRate if sampleRate > 0 else 0.001 return nspikeTrain( diff --git a/nstat/decoding_algorithms.py b/nstat/decoding_algorithms.py index c8544f26..72cc3abe 100644 --- a/nstat/decoding_algorithms.py +++ b/nstat/decoding_algorithms.py @@ -346,8 +346,30 @@ def _ztest_pvalue(param: float, se: float) -> float: return float(2.0 * norm.sf(np.abs(z))) class DecodingAlgorithms: + """Static-method library for neural decoding and state-space estimation. + + Provides Kalman filtering/smoothing, point-process adaptive filtering + (PPAF), hybrid discrete–continuous decoding, unscented Kalman filtering + (UKF), state-space GLM EM algorithms (SSGLM), and mixed point-process / + continuous-observation (mPPCO) EM algorithms. + + All methods are ``@staticmethod``; no instance is required. Method + signatures follow the Matlab ``DecodingAlgorithms`` class as closely + as possible. + + See Also + -------- + CIF : Conditional intensity function objects used by ``PPDecodeFilter``. + Analysis : High-level fitting routines that call these decoders. + """ + @staticmethod def linear_decode(spike_counts: np.ndarray, stimulus: np.ndarray) -> dict[str, np.ndarray]: + """Ordinary-least-squares linear decoder (spike counts → stimulus). + + Returns a dict with keys ``'coefficients'``, ``'decoded'``, + ``'residual'``, and ``'ci'`` (95 % confidence band). + """ x = np.asarray(spike_counts, dtype=float) y = np.asarray(stimulus, dtype=float).reshape(-1) if x.ndim == 1: @@ -537,6 +559,7 @@ def _sel(M, n): @staticmethod def kalman_predict(x_u, Pe_u, A, Pv, GnConv=None): + """Kalman filter predict step: ``x_p = A x_u``, ``Pe_p = A Pe A' + Pv``.""" x_vec = np.asarray(x_u, dtype=float).reshape(-1) dim = x_vec.size A_mat = _as_state_matrix(A, dim) @@ -551,6 +574,7 @@ def kalman_predict(x_u, Pe_u, A, Pv, GnConv=None): @staticmethod def kalman_update(x_p, Pe_p, C, Pw, y, GnConv=None): + """Kalman filter update step: incorporate observation *y* and return ``(x_u, Pe_u, G)``.""" x_vec = np.asarray(x_p, dtype=float).reshape(-1) dim = x_vec.size C_mat = np.asarray(C, dtype=float) @@ -591,6 +615,11 @@ def _state_history_time_major(x, P): @staticmethod def kalman_smootherFromFiltered(A, x_p, Pe_p, x_u, Pe_u): + """RTS backward smoother from precomputed filter estimates. + + Returns ``(x_N, P_N, Ln)`` — smoothed states, covariances, and + backward Kalman gains. + """ x_p_tm, Pe_p_tm, predicted_transposed = DecodingAlgorithms._state_history_time_major(x_p, Pe_p) x_u_tm, Pe_u_tm, updated_transposed = DecodingAlgorithms._state_history_time_major(x_u, Pe_u) if predicted_transposed != updated_transposed: @@ -614,6 +643,10 @@ def kalman_smootherFromFiltered(A, x_p, Pe_p, x_u, Pe_u): @staticmethod def kalman_smoother(A, C, Pv, Pw, Px0, x0, y): + """Run a Kalman filter followed by an RTS smoother. + + Returns ``(x_N, P_N, Ln, x_p, Pe_p, x_u, Pe_u)``. + """ observations = np.asarray(y, dtype=float) if observations.ndim == 1: observations = observations[:, None] @@ -1191,6 +1224,10 @@ def computeSpikeRateDiffCIs( @staticmethod def PPDecode_predict(x_u, W_u, A, Q, Wconv=None): + """Point-process adaptive filter predict step. + + Returns ``(x_p, W_p)`` — predicted state and covariance. + """ x_vec = np.asarray(x_u, dtype=float).reshape(-1) dim = x_vec.size W_mat = _as_state_matrix(W_u, dim) @@ -1210,6 +1247,11 @@ def PPDecode_predict(x_u, W_u, A, Q, Wconv=None): @staticmethod def PPDecode_update(x_p, W_p, dN, lambdaIn, binwidth=0.001, time_index=1, WuConv=None): + """Point-process adaptive filter update step using CIF objects. + + Evaluates symbolic CIF gradients and Jacobians for the + Newton-step posterior update. Returns ``(x_u, W_u, lambda_delta)``. + """ x_vec = np.asarray(x_p, dtype=float).reshape(-1) W_mat = _as_state_matrix(W_p, x_vec.size) obs = _as_observation_matrix(dN) @@ -1263,6 +1305,11 @@ def PPDecode_update(x_p, W_p, dN, lambdaIn, binwidth=0.001, time_index=1, WuConv @staticmethod def PPDecode_updateLinear(x_p, W_p, dN, mu, beta, fitType="poisson", gamma=None, HkAll=None, time_index=1, WuConv=None): + """Point-process adaptive filter update step using linear parameters. + + Uses ``mu``, ``beta``, and optional ``gamma`` history coefficients + instead of CIF objects. Returns ``(x_u, W_u, lambda_delta)``. + """ x_vec = np.asarray(x_p, dtype=float).reshape(-1) W_mat = _as_state_matrix(W_p, x_vec.size) obs = _as_observation_matrix(dN) @@ -1542,6 +1589,12 @@ def _ppdecode_filter_linear( @staticmethod def PPDecodeFilterLinear(*args, **kwargs): + """Point-process adaptive filter using linear GLM parameters. + + Dispatches to ``_ppdecode_filter_linear`` when a ``fitType`` string + is present, otherwise falls back to ``kalman_filter``. Matches the + Matlab ``DecodingAlgorithms.PPDecodeFilterLinear`` signature. + """ if len(args) >= 6 and isinstance(args[5], str): return DecodingAlgorithms._ppdecode_filter_linear(*args, **kwargs) if "fitType" in kwargs or "delta" in kwargs: @@ -1550,6 +1603,14 @@ def PPDecodeFilterLinear(*args, **kwargs): @staticmethod def PPDecodeFilter(A, Q, Px0, dN, lambdaCIFColl, binwidth=0.001, x0=None, Pi0=None, yT=None, PiT=None, estimateTarget=0, Wconv=None): + """Point-process adaptive filter using CIF object collection. + + Runs the full forward filter loop, evaluating CIF objects at each + time step. When *yT* / *PiT* are supplied, delegates to the + linear-parameter variant with goal-state estimation. + + Returns ``(x_p, W_p, x_u, W_u, xT, WT, MT, MT_cov)``. + """ obs = _as_observation_matrix(dN) lambda_items = _normalize_cif_collection(lambdaCIFColl) num_cells, num_steps = obs.shape @@ -1621,6 +1682,12 @@ def PPDecodeFilter(A, Q, Px0, dN, lambdaCIFColl, binwidth=0.001, x0=None, Pi0=No @staticmethod def PP_fixedIntervalSmoother(A, Q, dN, lags, mu, beta, fitType="poisson", delta=0.001, gamma=None, windowTimes=None, x0=None, Pi0=None): + """Point-process fixed-interval (fixed-lag) smoother. + + Runs ``PPDecode_updateLinear`` forward, then applies backward + RTS-style smoothing at each step with the specified number of + *lags*. Returns ``(x_pLag, W_pLag, x_uLag, W_uLag)``. + """ obs = _as_observation_matrix(dN) num_cells, num_steps = obs.shape num_states = _infer_state_dim(A, beta, num_cells) @@ -1722,6 +1789,14 @@ def PPHybridFilterLinear( estimateTarget=0, MinClassificationError=0, ): + """Hybrid point-process filter with discrete-mode switching (linear parameters). + + Combines multiple linear state-space models (one per discrete mode) + with a Markov transition matrix *p_ij*. At each time step, per-mode + PPAF updates are merged using posterior mode probabilities. + + Returns ``(x_p, W_p, x_u, W_u, xT, WT, S_est, X_est, W_est)``. + """ obs = _as_observation_matrix(dN) A_models = list(A) if isinstance(A, Sequence) and not isinstance(A, np.ndarray) else [A] Q_models = list(Q) if isinstance(Q, Sequence) and not isinstance(Q, np.ndarray) else [Q] diff --git a/nstat/events.py b/nstat/events.py index bcbbec69..15c3a9e2 100644 --- a/nstat/events.py +++ b/nstat/events.py @@ -10,7 +10,22 @@ class Events: - """MATLAB-style event container.""" + """Experimental event markers for highlighting epochs in figures. + + Events represent times of importance during an experiment (e.g. + stimulus onset, trial boundaries) that are overlaid on raster or + signal plots. + + Parameters + ---------- + eventTimes : array_like + Vector of event times (seconds). + eventLabels : sequence of str or None + Labels for each event. Must match the length of *eventTimes* + when provided. + eventColor : str, default ``'r'`` + Colour string for the event lines (Matlab-style colour codes). + """ def __init__(self, eventTimes, eventLabels: Sequence[str] | None = None, eventColor: str = "r") -> None: times = np.asarray(eventTimes, dtype=float).reshape(-1) @@ -27,6 +42,7 @@ def __init__(self, eventTimes, eventLabels: Sequence[str] | None = None, eventCo self.labels = self.eventLabels def toStructure(self) -> dict[str, Any]: + """Serialise the Events to a plain dictionary.""" return { "eventTimes": self.eventTimes.tolist(), "eventLabels": list(self.eventLabels), @@ -35,6 +51,7 @@ def toStructure(self) -> dict[str, Any]: @staticmethod def fromStructure(structure: dict[str, Any] | None) -> "Events" | None: + """Reconstruct Events from a dictionary (inverse of :meth:`toStructure`).""" if structure is None: return None event_times = structure.get("eventTimes", structure.get("event_times", [])) diff --git a/nstat/fit.py b/nstat/fit.py index 535d9dd7..bff0d873 100644 --- a/nstat/fit.py +++ b/nstat/fit.py @@ -373,7 +373,28 @@ class _SingleFit: class FitResult: - """MATLAB-facing fit result container with Python compatibility aliases.""" + """GLM fit results for one neuron across one or more model configs (Matlab ``FitResult``). + + Stores coefficients, deviance, AIC/BIC, log-likelihood, fitted λ signal, + and KS-test diagnostics for each configuration in a + :class:`~nstat.trial.ConfigCollection`. Provides coefficient accessors, + residual analysis, and Matlab-compatible plot methods. + + Parameters + ---------- + neuralSpikeTrain : nspikeTrain or sequence of nspikeTrain + The observed spike train(s) that were fitted. + *args, **kwargs + Positional / keyword construction matching the Matlab + ``FitResult(nst, covLabels, numHist, …)`` signature, or + the simplified ``FitResult(nst, lambdaCov, fits)`` form. + + See Also + -------- + FitSummary : Aggregate summary across multiple neurons. + Analysis.RunAnalysisForAllNeurons : Main entry point that produces + ``FitResult`` objects. + """ def __init__(self, neuralSpikeTrain: nspikeTrain | Sequence[nspikeTrain], *args, **kwargs) -> None: if args and isinstance(args[0], Covariate): @@ -605,6 +626,7 @@ def __getattr__(self, name: str): raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}") def setNeuronName(self, name: str): + """Rename the neuron on the underlying spike train(s).""" if isinstance(self.neuralSpikeTrain, nspikeTrain): self.neuralSpikeTrain.setName(str(name)) elif isinstance(self.neuralSpikeTrain, Sequence): @@ -615,6 +637,7 @@ def setNeuronName(self, name: str): return self def mapCovLabelsToUniqueLabels(self): + """Rebuild the unique-label map and ``flatMask`` from ``covLabels``.""" self.uniqueCovLabels = _ordered_unique([label for labels in self.covLabels for label in labels]) self.indicesToUniqueLabels = [] self.flatMask = np.zeros((len(self.uniqueCovLabels), max(len(self.covLabels), 1)), dtype=int) @@ -626,6 +649,7 @@ def mapCovLabelsToUniqueLabels(self): return self def getSubsetFitResult(self, subfits) -> "FitResult": + """Return a new ``FitResult`` with only the selected fit indices (1-based).""" indices = np.asarray(subfits if isinstance(subfits, Sequence) and not isinstance(subfits, (str, bytes)) else [subfits], dtype=int).reshape(-1) zero_based = [int(idx) - 1 for idx in indices] from .trial import ConfigCollection @@ -656,6 +680,7 @@ def getSubsetFitResult(self, subfits) -> "FitResult": return subset def addParamsToFit(self, neuronNum, lambda_signal, b, dev, stats, AIC, BIC, logLL, configColl): + """Append a new fit configuration's results (Matlab ``addParamsToFit``).""" del neuronNum merged = self.mergeResults( FitResult( @@ -717,6 +742,7 @@ def getHistCoeffsWithLabels(self, fit_num: int = 1) -> tuple[np.ndarray, list[st return coeffs[-num_hist:], labels[-num_hist:], se[-num_hist:] def getCoeffIndex(self, fit_num: int = 1, sortByEpoch: int = 0): + """Return ``(indices, epochIds, numEpochs)`` for non-history coefficients.""" del sortByEpoch labels = list(self.covLabels[fit_num - 1]) if fit_num - 1 < len(self.covLabels) else [] num_hist = int(self.numHist[fit_num - 1]) if fit_num - 1 < len(self.numHist) else 0 @@ -726,6 +752,7 @@ def getCoeffIndex(self, fit_num: int = 1, sortByEpoch: int = 0): return coeff_index, epoch_id, 0 def getHistIndex(self, fit_num: int = 1, sortByEpoch: int = 0): + """Return ``(indices, epochIds, numEpochs)`` for history coefficients.""" del sortByEpoch labels = list(self.covLabels[fit_num - 1]) if fit_num - 1 < len(self.covLabels) else [] num_hist = int(self.numHist[fit_num - 1]) if fit_num - 1 < len(self.numHist) else 0 @@ -737,6 +764,7 @@ def getHistIndex(self, fit_num: int = 1, sortByEpoch: int = 0): return hist_index, epoch_id, 0 def getParam(self, paramNames, fit_num: int = 1): + """Return ``(coeffs, SE, significance)`` for named parameters.""" names = [paramNames] if isinstance(paramNames, str) else list(paramNames) coeffs, labels, se = self.getCoeffsWithLabels(fit_num) sig = _extract_significance_mask(self.stats[fit_num - 1] if fit_num - 1 < len(self.stats) else None, coeffs, se) @@ -744,6 +772,7 @@ def getParam(self, paramNames, fit_num: int = 1): return coeffs[indices], se[indices], sig[indices] def getCoeffsWithLabels(self, fit_num: int = 1) -> tuple[np.ndarray, list[str], np.ndarray]: + """Return ``(coefficients, labels, standardErrors)`` for *fit_num*.""" coeffs = self._rawCoeffs(fit_num) labels = list(self.covLabels[fit_num - 1]) if fit_num - 1 < len(self.covLabels) else [f"b_{idx + 1}" for idx in range(coeffs.size)] if coeffs.size == len(labels) + 1: @@ -754,6 +783,7 @@ def getCoeffsWithLabels(self, fit_num: int = 1) -> tuple[np.ndarray, list[str], return coeffs, labels, se def computePlotParams(self, fit_num: int | None = None): + """Compute the aligned coefficient / SE / significance arrays for plotting.""" del fit_num if not self.uniqueCovLabels: self.mapCovLabelsToUniqueLabels() @@ -782,9 +812,11 @@ def computePlotParams(self, fit_num: int | None = None): return self.plotParams def getPlotParams(self): + """Alias for :meth:`computePlotParams`.""" return self.computePlotParams() def isValDataPresent(self) -> bool: + """Return ``True`` if cross-validation data was stored.""" if not self.XvalTime or not self.XvalData: return False for time in self.XvalTime: @@ -794,11 +826,13 @@ def isValDataPresent(self) -> bool: return False def plotValidation(self): + """Plot validation fit results (if present).""" if self.validation is not None: return self.validation.plotResults() return None def mergeResults(self, other: "FitResult") -> "FitResult": + """Concatenate another ``FitResult``'s configs into this one.""" from .trial import ConfigCollection if isinstance(self.lambda_signal, Covariate) and isinstance(other.lambda_signal, Covariate): @@ -941,6 +975,7 @@ def _compute_diagnostics(self, fit_num: int = 1, *, dt_correction: int = 1) -> d return diagnostics def computeKSStats(self, fit_num: int = 1, *, dt_correction: int = 1) -> dict[str, float]: + """Return KS statistic, p-value, and within-CI flag for *fit_num*.""" diag = self._compute_diagnostics(fit_num, dt_correction=dt_correction) return { "ks_stat": float(diag["ks_stat"]), @@ -949,9 +984,11 @@ def computeKSStats(self, fit_num: int = 1, *, dt_correction: int = 1) -> dict[st } def computeInvGausTrans(self, fit_num: int = 1) -> np.ndarray: + """Return Gaussianized (inverse-normal-transformed) rescaled ISIs.""" return np.asarray(self._compute_diagnostics(fit_num)["gaussianized"], dtype=float) def computeFitResidual(self, fit_num: int = 1, *, windowSize: float | None = None) -> Covariate: + """Compute the martingale residual M(t) = N(t) − Λ(t) (Matlab ``computeFitResidual``).""" time, rate_hz = self._lambda_series(fit_num) if time.size == 0: residual = Covariate([], [], "M(t_k)", "time", "s", "counts/bin", ["residual"]) @@ -1006,6 +1043,7 @@ def computeFitResidual(self, fit_num: int = 1, *, windowSize: float | None = Non return residual def evalLambda(self, fit_num: int = 1, newData=None) -> np.ndarray: + """Evaluate λ(t) = exp(X·β) · sampleRate on *newData* (Matlab ``evalLambda``).""" coeffs = self._rawCoeffs(fit_num) x = np.asarray(newData if newData is not None else [], dtype=float) if x.ndim == 0: @@ -1107,6 +1145,7 @@ def plotResults(self, fit_num: int = 1, handle=None): return fig def KSPlot(self, fit_num: int = 1, handle=None): + """KS goodness-of-fit plot with 95 % confidence bands (Matlab ``KSPlot``).""" diag = self._compute_diagnostics(fit_num) ax = handle if handle is not None else plt.subplots(1, 1, figsize=(5.0, 4.0))[1] ideal = np.asarray(diag["ks_ideal"], dtype=float) @@ -1125,6 +1164,7 @@ def KSPlot(self, fit_num: int = 1, handle=None): return ax def plotResidual(self, fit_num: int = 1, handle=None): + """Plot the martingale residual M(t) (Matlab ``plotResidual``).""" ax = handle if handle is not None else plt.subplots(1, 1, figsize=(6.0, 3.5))[1] residual = self.computeFitResidual(fit_num) ax.plot(np.asarray(residual.time, dtype=float), np.asarray(residual.data[:, 0], dtype=float), color="tab:purple", linewidth=1.0) @@ -1213,6 +1253,7 @@ def plotCoeffs(self, fit_num: int = 1, handle=None, plotSignificance: int = 1): return ax def plotCoeffsWithoutHistory(self, fit_num: int = 1, sortByEpoch: int = 0, plotSignificance: int = 1, handle=None): + """Plot non-history (stimulus/baseline) coefficients only.""" del sortByEpoch, plotSignificance coeffs, labels, _ = self.getCoeffsWithLabels(fit_num) num_hist = int(self.numHist[fit_num - 1]) if fit_num - 1 < len(self.numHist) else 0 @@ -1229,6 +1270,7 @@ def plotCoeffsWithoutHistory(self, fit_num: int = 1, sortByEpoch: int = 0, plotS return ax def plotHistCoeffs(self, fit_num: int = 1, sortByEpoch: int = 0, plotSignificance: int = 1, handle=None): + """Plot history-filter coefficients (Matlab ``plotHistCoeffs``).""" del sortByEpoch, plotSignificance coeffs, labels, _se = self.getHistCoeffsWithLabels(fit_num) if not labels: @@ -1244,6 +1286,7 @@ def plotHistCoeffs(self, fit_num: int = 1, sortByEpoch: int = 0, plotSignificanc return ax def setKSStats(self, Z, U, xAxis, KSSorted, ks_stat): + """Store pre-computed KS-test arrays (Matlab ``setKSStats``).""" self.Z = np.asarray(Z, dtype=float) self.U = np.asarray(U, dtype=float) self.KSXAxis = np.asarray(xAxis, dtype=float) @@ -1268,14 +1311,17 @@ def setKSStats(self, Z, U, xAxis, KSSorted, ks_stat): return self def setInvGausStats(self, X, rhoSig, confBoundSig): + """Store pre-computed inverse-Gaussian transform statistics.""" self.invGausStats = {"X": np.asarray(X, dtype=float), "rhoSig": rhoSig, "confBoundSig": confBoundSig} return self def setFitResidual(self, M): + """Store the pre-computed fit residual ``Covariate``.""" self.Residual = M return self def toStructure(self) -> dict[str, Any]: + """Serialize to a JSON-compatible dict (Matlab ``toStructure``).""" return { "covLabels": [list(labels) for labels in self.covLabels], "numHist": list(self.numHist), @@ -1318,6 +1364,7 @@ def toStructure(self) -> dict[str, Any]: @staticmethod def fromStructure(structure: dict[str, Any]) -> "FitResult": + """Reconstruct a ``FitResult`` from a dict.""" from .trial import ConfigCollection, TrialConfig spike_times = structure["neural_spike_times"] @@ -1360,11 +1407,26 @@ def fromStructure(structure: dict[str, Any]) -> "FitResult": @staticmethod def CellArrayToStructure(fitResObjCell): + """Serialize a list of ``FitResult`` objects to a list of dicts.""" return [fit.toStructure() for fit in fitResObjCell] class FitSummary: - """Cross-fit summary statistics for one or more FitResult objects.""" + """Population-level summary across multiple neurons (Matlab ``FitResSummary``). + + Aggregates AIC, BIC, log-likelihood, KS statistics, and coefficients + from a collection of :class:`FitResult` objects, providing box-plots, + coefficient histograms, and 2-D/3-D coefficient surfaces. + + Parameters + ---------- + fit_results : FitResult or iterable of FitResult + One or more per-neuron fit results to summarise. + + See Also + -------- + FitResult : Per-neuron fit container. + """ def __init__(self, fit_results: FitResult | Iterable[FitResult]) -> None: if isinstance(fit_results, FitResult): @@ -1398,35 +1460,41 @@ def __init__(self, fit_results: FitResult | Iterable[FitResult]) -> None: self.mapCovLabelsToUniqueLabels() def getDiffAIC(self, idx: int = 1) -> np.ndarray: + """Return ΔAIC relative to config *idx* (1-based).""" if self.numResults > 1: keep = [col for col in range(self.AIC.shape[1]) if col != (idx - 1)] return self.AIC[:, keep] - self.AIC[:, [idx - 1]] return self.AIC.copy() def getDiffBIC(self, idx: int = 1) -> np.ndarray: + """Return ΔBIC relative to config *idx* (1-based).""" if self.numResults > 1: keep = [col for col in range(self.BIC.shape[1]) if col != (idx - 1)] return self.BIC[:, keep] - self.BIC[:, [idx - 1]] return self.BIC.copy() def getDifflogLL(self, idx: int = 1) -> np.ndarray: + """Return Δlog-likelihood relative to config *idx* (1-based).""" if self.numResults > 1: keep = [col for col in range(self.logLL.shape[1]) if col != (idx - 1)] return self.logLL[:, keep] - self.logLL[:, [idx - 1]] return self.logLL.copy() def mapCovLabelsToUniqueLabels(self): + """Rebuild the union of covariate labels across all neurons.""" self.uniqueCovLabels = _ordered_unique( [label for fit in self.fitResCell for labels in fit.covLabels for label in labels] ) return self.uniqueCovLabels def setCoeffRange(self, minVal, maxVal): + """Set the coefficient range used by ``binCoeffs``.""" self.coeffMin = float(minVal) self.coeffMax = float(maxVal) return self def getCoeffs(self, fitNum: int = 1): + """Return ``(coeffMat, labels, seMat)`` aligned to unique labels.""" labels = self.uniqueCovLabels coeff_rows = [] se_rows = [] @@ -1444,6 +1512,7 @@ def getCoeffs(self, fitNum: int = 1): return np.asarray(coeff_rows, dtype=float), labels, np.asarray(se_rows, dtype=float) def getHistCoeffs(self, fitNum: int = 1): + """Return ``(histMat, labels, seMat)`` for history coefficients.""" labels = _ordered_unique( [label for fit in self.fitResCell for label in fit.covLabels[fitNum - 1][-int(fit.numHist[fitNum - 1]) :] if fitNum - 1 < len(fit.covLabels) and int(fit.numHist[fitNum - 1]) > 0] ) @@ -1469,6 +1538,7 @@ def getHistCoeffs(self, fitNum: int = 1): return np.asarray(coeff_rows, dtype=float), labels, np.asarray(se_rows, dtype=float) def getSigCoeffs(self, fitNum: int = 1): + """Return (nNeurons × nCov) binary significance matrix.""" coeff_mat, labels, se_mat = self.getCoeffs(fitNum) sig = np.zeros_like(coeff_mat, dtype=float) for row_idx, fit in enumerate(self.fitResCell): @@ -1522,6 +1592,7 @@ def binCoeffs(self, minVal=-12.0, maxVal=12.0, binSize=0.1): return N, edges, percentSig def plotIC(self, handle=None): + """Plot AIC, BIC, and log-likelihood box-plots side by side.""" fig = handle if handle is not None else plt.figure(figsize=(9.0, 3.5)) fig.clear() axes = fig.subplots(1, 3) @@ -1532,6 +1603,7 @@ def plotIC(self, handle=None): return fig def plotAIC(self, handle=None): + """Box-plot of AIC across neurons (Matlab ``plotAIC``).""" ax = handle if handle is not None else plt.subplots(1, 1, figsize=(5.0, 3.5))[1] ax.boxplot(self.AIC, tick_labels=self.fitNames) ax.set_ylabel("AIC") @@ -1539,6 +1611,7 @@ def plotAIC(self, handle=None): return ax def plotBIC(self, handle=None): + """Box-plot of BIC across neurons (Matlab ``plotBIC``).""" ax = handle if handle is not None else plt.subplots(1, 1, figsize=(5.0, 3.5))[1] ax.boxplot(self.BIC, tick_labels=self.fitNames) ax.set_ylabel("BIC") @@ -1546,6 +1619,7 @@ def plotBIC(self, handle=None): return ax def plotlogLL(self, handle=None): + """Box-plot of log-likelihood across neurons (Matlab ``plotlogLL``).""" ax = handle if handle is not None else plt.subplots(1, 1, figsize=(5.0, 3.5))[1] ax.boxplot(self.logLL, tick_labels=self.fitNames) ax.set_ylabel("log likelihood") @@ -1553,6 +1627,7 @@ def plotlogLL(self, handle=None): return ax def plotResidualSummary(self, handle=None): + """Overlay all neurons' martingale residuals (Matlab ``plotResidualSummary``).""" fig = handle if handle is not None else plt.figure(figsize=(8.0, 3.5)) fig.clear() ax = fig.subplots(1, 1) @@ -1566,6 +1641,7 @@ def plotResidualSummary(self, handle=None): return fig def plotSummary(self, handle=None): + """Bar chart of mean AIC, BIC, and log-likelihood across configs.""" fig = handle if handle is not None else plt.figure(figsize=(10.0, 4.5)) fig.clear() axes = fig.subplots(1, 3) @@ -1585,6 +1661,7 @@ def plotSummary(self, handle=None): return fig def boxPlot(self, X, diffIndex: int = 1, h=None, dataLabels=None, **kwargs): + """General-purpose box-plot of *X* columns with fit-name labels.""" del kwargs ax = h if h is not None else plt.subplots(1, 1, figsize=(6.0, 3.5))[1] values = np.asarray(X, dtype=float) @@ -1786,6 +1863,7 @@ def plotKSSummary(self, neurons: list[int] | None = None, handle=None): return fig def toStructure(self) -> dict[str, Any]: + """Serialize to a JSON-compatible dict.""" return { "fitResCell": FitResult.CellArrayToStructure(self.fitResCell), "numNeurons": self.numNeurons, @@ -1802,6 +1880,7 @@ def toStructure(self) -> dict[str, Any]: @staticmethod def fromStructure(structure: dict[str, Any]) -> "FitSummary": + """Reconstruct a ``FitSummary`` from a dict.""" fits = [FitResult.fromStructure(item) for item in structure.get("fitResCell", [])] return FitSummary(fits) diff --git a/nstat/history.py b/nstat/history.py index d7c5e7ba..929eba0f 100644 --- a/nstat/history.py +++ b/nstat/history.py @@ -70,7 +70,34 @@ def __rmatmul__(self, coefficients) -> HistoryFilter: class History: - """MATLAB-style spike-history basis described by window boundaries.""" + """Spike-history basis defined by a set of window boundaries. + + Given a vector of *N* window-times, *N − 1* piecewise-constant basis + functions are created. Window *i* spans from ``windowTimes[i]`` to + ``windowTimes[i+1]``. + + Parameters + ---------- + windowTimes : array_like + Strictly increasing vector of window boundary times (seconds). + Must contain at least two entries. + minTime : float or None + Optional lower time bound for the computed history covariate. + maxTime : float or None + Optional upper time bound for the computed history covariate. + name : str, default ``'History'`` + Display name. + + Examples + -------- + >>> h = History([0, 0.001, 0.01, 0.05, 0.1]) + >>> h.numWindows + 4 + + See Also + -------- + TrialConfig, Analysis.computeHistLag + """ def __init__(self, windowTimes, minTime: float | None = None, maxTime: float | None = None, name: str = "History") -> None: times = np.asarray(windowTimes, dtype=float).reshape(-1) @@ -93,12 +120,32 @@ def numWindows(self) -> int: return int(self.windowTimes.size - 1) def setWindow(self, windowTimes) -> None: + """Replace the window-times vector. + + Parameters + ---------- + windowTimes : array_like + New window boundary times (must be strictly increasing, + length ≥ 2). + """ replacement = History(windowTimes, self.minTime, self.maxTime, self.name) self.windowTimes = replacement.windowTimes self.minTime = replacement.minTime self.maxTime = replacement.maxTime def toFilter(self, delta: float) -> HistoryFilterBank: + """Convert the history windows to a discrete-time filter bank. + + Parameters + ---------- + delta : float + Sample period (seconds). Must be positive. + + Returns + ------- + HistoryFilterBank + Bank of FIR filters (one per window) in the ``z⁻¹`` domain. + """ delta = float(delta) if delta <= 0: raise ValueError("delta must be positive") @@ -161,6 +208,24 @@ def _compute_single_history(self, train: nspikeTrain, historyIndex: int | None = return windowed def compute_history(self, trains, historyIndex: int | None = None, time_grid=None): + """Compute the history covariate(s) for one or more spike trains. + + Parameters + ---------- + trains : nspikeTrain or nstColl or sequence of nspikeTrain + Spike train(s) whose history to compute. + historyIndex : int or None + Optional label index appended to the data labels. + time_grid : array_like or None + Optional external time grid. If ``None``, the spike train's + own time grid is used. + + Returns + ------- + CovariateCollection + One :class:`Covariate` per spike train, each with + ``numWindows`` columns corresponding to the history windows. + """ from .trial import CovariateCollection if isinstance(trains, nspikeTrain): @@ -180,9 +245,11 @@ def compute_history(self, trains, historyIndex: int | None = None, time_grid=Non raise TypeError("History can only be computed from nspikeTrain, nstColl, or sequences of nspikeTrain") def computeHistory(self, trains, historyIndex: int | None = None, time_grid=None): + """Matlab-facing alias for :meth:`compute_history`.""" return self.compute_history(trains, historyIndex, time_grid=time_grid) def toStructure(self) -> dict[str, Any]: + """Serialise the History to a plain dictionary.""" return { "windowTimes": self.windowTimes.tolist(), "minTime": self.minTime, @@ -192,6 +259,7 @@ def toStructure(self) -> dict[str, Any]: @staticmethod def fromStructure(structure: dict[str, Any] | None) -> "History" | None: + """Reconstruct a History from a dictionary (inverse of :meth:`toStructure`).""" if structure is None: return None if "windowTimes" in structure: @@ -209,6 +277,19 @@ def fromStructure(structure: dict[str, Any] | None) -> "History" | None: ) def plot(self, *_, handle=None, **__): + """Plot each of the history windows as step functions. + + Parameters + ---------- + handle : Axes or None + Matplotlib axes to draw into. If ``None``, a new figure is + created. + + Returns + ------- + Axes or list + The axes or plot handles. + """ tmin = np.asarray(self.windowTimes[:-1], dtype=float) tmax = np.asarray(self.windowTimes[1:], dtype=float) sampleRate = 1000.0 diff --git a/nstat/trial.py b/nstat/trial.py index 558b6c77..0ed6f237 100644 --- a/nstat/trial.py +++ b/nstat/trial.py @@ -81,7 +81,22 @@ def _copy_covariate_for_collection_view(cov: Covariate) -> Covariate: class CovariateCollection: - """MATLAB-style CovColl implementation with collection-level masks and timing.""" + """Ordered collection of :class:`~nstat.core.Covariate` objects (Matlab ``CovColl``). + + Provides collection-level masking, time alignment, sample-rate + enforcement, and covariate shifting. Individual covariates are + accessed via 1-based indexing (``getCov(1)``) to match Matlab. + + Parameters + ---------- + covariates : Covariate, sequence of Covariate, or None + Initial covariate(s) to add. + + See Also + -------- + Covariate : Scalar or multi-dimensional signal with CIs. + Trial : Combines a ``CovariateCollection`` with spike data. + """ def __init__(self, covariates: Sequence[Covariate] | Covariate | None = None, *more_covariates: Covariate) -> None: self.covArray: list[Covariate] = [] @@ -102,10 +117,12 @@ def __init__(self, covariates: Sequence[Covariate] | Covariate | None = None, *m @property def covariates(self) -> list[Covariate]: + """List of all covariates (copies with collection state applied).""" return [self.getCov(i) for i in range(1, self.numCov + 1)] @property def names(self) -> list[str]: + """List of covariate names in insertion order.""" return [cov.name for cov in self.covArray] def _capture_originals_if_needed(self) -> None: @@ -167,12 +184,15 @@ def _apply_collection_state(self, cov: Covariate, index: int) -> Covariate: return out def add(self, covariate: Covariate) -> None: + """Alias for :meth:`addToColl`.""" self.addToColl(covariate) def addCovariate(self, covariate: Covariate) -> None: + """Alias for :meth:`addToColl`.""" self.addToColl(covariate) def addCovCollection(self, covariates: "CovariateCollection") -> None: + """Merge all covariates from another collection into this one.""" self.addToColl(covariates) def addToColl(self, covariates: Sequence[Covariate] | Covariate | "CovariateCollection" | None) -> None: @@ -194,19 +214,29 @@ def addToColl(self, covariates: Sequence[Covariate] | Covariate | "CovariateColl raise TypeError("CovColl can only add Covariate instances or sequences of Covariates.") def removeCovariate(self, identifier: int | str) -> None: + """Remove a covariate by 1-based index or name.""" index = self._covariate_from_identifier(identifier) del self.covArray[index - 1] del self.covMask[index - 1] self._refresh_summary() def copy(self) -> "CovariateCollection": + """Return a deep copy of this collection.""" cov = [self.getCov(i).copySignal() for i in range(1, self.numCov + 1)] return CovariateCollection(cov) def get(self, name: str) -> Covariate: + """Retrieve a covariate by name (convenience alias for :meth:`getCov`).""" return self.getCov(name) def getCov(self, identifier: int | str | Sequence[int] | Sequence[str]): + """Return a covariate copy with collection state (shift, mask, rate) applied. + + Parameters + ---------- + identifier : int, str, or sequence + 1-based index, covariate name, or sequence of either. + """ if isinstance(identifier, str): return self._apply_collection_state(self.covArray[self.getCovIndFromName(identifier) - 1], self.getCovIndFromName(identifier)) if isinstance(identifier, Sequence) and not isinstance(identifier, (str, bytes, np.ndarray)): @@ -219,17 +249,20 @@ def getCov(self, identifier: int | str | Sequence[int] | Sequence[str]): return self._apply_collection_state(self.covArray[index - 1], index) def getCovIndFromName(self, name: str) -> int: + """Return the 1-based index of a covariate by *name*.""" for idx, cov in enumerate(self.covArray, start=1): if cov.name == name: return idx raise KeyError(f"Covariate '{name}' not found") def getCovIndicesFromNames(self, name: Sequence[str] | str): + """Return 1-based index(es) for one or more covariate names.""" if isinstance(name, str): return self.getCovIndFromName(name) return [self.getCovIndFromName(item) for item in name] def isCovPresent(self, cov) -> int: + """Return ``1`` if a covariate is in this collection, ``0`` otherwise.""" if isinstance(cov, Covariate): if not cov.name: return 0 @@ -250,63 +283,77 @@ def isCovPresent(self, cov) -> int: raise TypeError("Need either covariate class or name of covariate or index of covariate") def findMinTime(self) -> float: + """Return the earliest ``minTime`` across all stored covariates.""" if self.numCov == 0: return float("inf") return float(min(cov.minTime for cov in self.covArray)) def findMaxTime(self) -> float: + """Return the latest ``maxTime`` across all stored covariates.""" if self.numCov == 0: return float("-inf") return float(max(cov.maxTime for cov in self.covArray)) def findMaxSampleRate(self) -> float: + """Return the highest sample rate across all stored covariates.""" if self.numCov == 0: return float("nan") return float(max(cov.sampleRate for cov in self.covArray if np.isfinite(cov.sampleRate))) def setMinTime(self, minTime: float | None = None) -> None: + """Set the collection-level minimum time (applies shift if set).""" if minTime is None: minTime = self.findMinTime() + float(self.covShift) self.minTime = float(minTime) def setMaxTime(self, maxTime: float | None = None) -> None: + """Set the collection-level maximum time (applies shift if set).""" if maxTime is None: maxTime = self.findMaxTime() + float(self.covShift) self.maxTime = float(maxTime) def restrictToTimeWindow(self, wMin: float, wMax: float) -> None: + """Set both min and max time to restrict the visible window.""" self.setMinTime(wMin) self.setMaxTime(wMax) def setSampleRate(self, sampleRate: float) -> None: + """Set the collection sample rate and enforce it on all covariates.""" if self.originalSampleRate is None and np.isfinite(self.sampleRate): self.originalSampleRate = float(self.sampleRate) self.sampleRate = float(sampleRate) self.enforceSampleRate() def resample(self, sampleRate: float) -> None: + """Alias for :meth:`setSampleRate`.""" self.setSampleRate(sampleRate) def enforceSampleRate(self) -> None: + """Ensure the collection's sample rate is finite and positive.""" if not np.isfinite(self.sampleRate) or self.sampleRate <= 0: self.sampleRate = self.findMaxSampleRate() def resetMask(self) -> None: + """Enable all covariate dimensions (clear any masking).""" self.covMask = [np.ones(cov.dimension, dtype=int) for cov in self.covArray] def getCovDataMask(self, identifier: int | str) -> np.ndarray: + """Return the binary dimension mask for a single covariate.""" index = self._covariate_from_identifier(identifier) return np.asarray(self.covMask[index - 1], dtype=int).copy() def isCovMaskSet(self) -> bool: + """Return ``True`` if any covariate dimension is currently masked out.""" return any(np.any(mask == 0) for mask in self.covMask) def flattenCovMask(self) -> np.ndarray: + """Concatenate all per-covariate masks into a single 1-D binary array.""" if not self.covMask: return np.array([], dtype=int) return np.concatenate([np.asarray(mask, dtype=int).reshape(-1) for mask in self.covMask]) def getSelectorFromMasks(self, covMask: list[np.ndarray] | None = None) -> list[list[int]]: + """Convert per-covariate binary masks to lists of active 1-based indices.""" current = self.covMask if covMask is None else covMask selector: list[list[int]] = [] for mask in current: @@ -344,6 +391,11 @@ def _selector_cell_from_names(self, dataSelector: Sequence[Any]) -> list[list[in return selectorCell def generateSelectorCell(self, dataSelector) -> list[list[int]]: + """Parse a heterogeneous *dataSelector* into per-covariate index lists. + + Accepts name-based (``[['covName', 'label1', ...], ...]``) or + numeric (``[[1,2], [3], ...]``) selectors. + """ if dataSelector is None: return [[] for _ in range(self.numCov)] if isinstance(dataSelector, str): @@ -395,9 +447,14 @@ def _selector_to_cov_mask(self, selectorCell: list[list[int]]) -> list[np.ndarra return masks def setMasksFromSelector(self, selectorCell: list[list[int]]) -> None: + """Set covariate masks from a list of 1-based index lists.""" self.covMask = self._selector_to_cov_mask(selectorCell) def setMask(self, cellInput) -> None: + """Set the covariate mask from a selector or ``'all'`` to reset. + + Accepts the same formats as :meth:`generateSelectorCell`. + """ if isinstance(cellInput, str) and cellInput == "all": self.resetMask() return @@ -405,9 +462,11 @@ def setMask(self, cellInput) -> None: self.setMasksFromSelector(selectorCell) def nActCovar(self) -> int: + """Return the number of covariates with at least one active dimension.""" return int(sum(1 for selector in self.getSelectorFromMasks() if selector)) def maskAwayCov(self, identifier: int | str | Sequence[int] | Sequence[str]) -> None: + """Zero-out the mask for the specified covariate(s).""" identifiers = identifier if isinstance(identifier, (int, str)): identifiers = [identifier] @@ -416,10 +475,12 @@ def maskAwayCov(self, identifier: int | str | Sequence[int] | Sequence[str]) -> self.covMask[index - 1] = np.zeros(self.covArray[index - 1].dimension, dtype=int) def maskAwayOnlyCov(self, identifier: int | str | Sequence[int] | Sequence[str]) -> None: + """Reset all masks then mask away only the specified covariate(s).""" self.resetMask() self.maskAwayCov(identifier) def maskAwayAllExcept(self, identifier: int | str | Sequence[int] | Sequence[str]) -> None: + """Mask away every covariate *except* the ones specified.""" if isinstance(identifier, (int, str)): keep = {self._covariate_from_identifier(identifier)} else: @@ -429,6 +490,7 @@ def maskAwayAllExcept(self, identifier: int | str | Sequence[int] | Sequence[str self.covMask[idx - 1] = np.zeros(cov.dimension, dtype=int) def setCovShift(self, deltaT: float, identifier=None) -> "CovariateCollection": + """Apply a temporal shift *deltaT* to the collection's time axis.""" self.covShift = float(deltaT) if np.isfinite(self.minTime): self.minTime = float(self.minTime + self.covShift) @@ -437,11 +499,13 @@ def setCovShift(self, deltaT: float, identifier=None) -> "CovariateCollection": return self def resetCovShift(self) -> None: + """Remove the temporal shift and recompute time bounds.""" self.covShift = 0.0 self.setMinTime() self.setMaxTime() def restoreToOriginal(self) -> None: + """Restore original sample rate, time bounds, shift, and masks.""" self.covShift = 0.0 if self.originalSampleRate is not None: self.sampleRate = float(self.originalSampleRate) @@ -452,6 +516,7 @@ def restoreToOriginal(self) -> None: self.resetMask() def plot(self, *_, handle=None, **__): + """Plot each covariate in a vertically stacked panel layout.""" selected = [idx for idx in range(1, self.numCov + 1)] fig = handle if handle is not None else plt.figure(figsize=(8.5, max(2.5, 2.2 * max(len(selected), 1)))) fig.clear() @@ -466,12 +531,14 @@ def plot(self, *_, handle=None, **__): return fig def getAllCovLabels(self) -> list[str]: + """Return the data-labels of every covariate dimension (no mask filtering).""" labels: list[str] = [] for index in range(1, self.numCov + 1): labels.extend(self.getCov(index).dataLabels) return labels def getCovLabelsFromMask(self) -> list[str]: + """Return data-labels only for dimensions that are currently unmasked.""" labels: list[str] = [] for index in range(1, self.numCov + 1): cov = self.getCov(index) @@ -497,6 +564,15 @@ def getCovDimension(self, identifier=None) -> np.ndarray: return np.array([int(c.dimension) for c in covs], dtype=int) def matrixWithTime(self, repType: str = "standard", dataSelector=None) -> tuple[np.ndarray, np.ndarray, list[str]]: + """Return ``(time, data_matrix, labels)`` for active covariate dimensions. + + Parameters + ---------- + repType : {'standard', 'zero-mean'} + Signal representation type. + dataSelector : optional + Name-based or numeric selector; ``None`` uses the current mask. + """ if self.numCov == 0: raise ValueError("CovariateCollection is empty") if dataSelector is None: @@ -526,6 +602,7 @@ def matrixWithTime(self, repType: str = "standard", dataSelector=None) -> tuple[ return time.copy(), np.hstack(parts) if parts else np.zeros((time.size, 0), dtype=float), labels def dataToMatrix(self, repType: str | Sequence[str] | None = "standard", dataSelector=None, *_) -> np.ndarray: + """Return the covariate data matrix (no time column) for active dimensions.""" if repType not in {"standard", "zero-mean"}: dataSelector = repType repType = "standard" @@ -539,6 +616,7 @@ def dataToStructure( minTime: float | None = None, maxTime: float | None = None, ) -> dict[str, Any]: + """Serialize active covariate data to a ``{'time': ..., 'signals': ...}`` dict.""" del binwidth, minTime, maxTime if selectorCell is None: if self.isCovMaskSet(): @@ -552,6 +630,7 @@ def dataToStructure( } def toStructure(self) -> dict[str, Any]: + """Serialize to a plain dict (Matlab ``CovColl.toStructure``).""" self.resetMask() structure: dict[str, Any] = { "numCov": int(self.numCov), @@ -570,6 +649,7 @@ def toStructure(self) -> dict[str, Any]: @staticmethod def fromStructure(structure) -> "CovariateCollection" | list["CovariateCollection"]: + """Reconstruct from a dict produced by :meth:`toStructure`.""" if isinstance(structure, list): return [CovariateCollection.fromStructure(item) for item in structure] if not isinstance(structure, dict): @@ -596,7 +676,23 @@ def fromStructure(structure) -> "CovariateCollection" | list["CovariateCollectio class SpikeTrainCollection: - """MATLAB-style nstColl implementation.""" + """Ordered collection of :class:`~nstat.core.nspikeTrain` objects (Matlab ``nstColl``). + + Provides a neuron mask, neighbour graph, and methods for PSTH, + GLM-PSTH, state-space GLM, raster plots, and data-matrix export. + Spike trains are accessed via 1-based indexing (``getNST(1)``) to + match Matlab conventions. + + Parameters + ---------- + trains : nspikeTrain, sequence of nspikeTrain, or None + Initial spike train(s) to add. + + See Also + -------- + nspikeTrain : Single-neuron point-process representation. + Trial : Combines a ``SpikeTrainCollection`` with covariates. + """ def __init__(self, trains: Sequence[nspikeTrain] | nspikeTrain | None = None) -> None: self.nstrain: list[nspikeTrain] = [] @@ -611,10 +707,12 @@ def __init__(self, trains: Sequence[nspikeTrain] | nspikeTrain | None = None) -> @property def num_spike_trains(self) -> int: + """Number of spike trains in this collection.""" return self.numSpikeTrains @property def uniqueNeuronNames(self) -> list[str]: + """Unique, insertion-ordered neuron names in the collection.""" return self.getUniqueNSTnames() def __iter__(self): @@ -640,6 +738,7 @@ def _refresh_summary(self) -> None: self.neuronMask = np.ones(self.numSpikeTrains, dtype=int) def addSingleSpikeToColl(self, nst: nspikeTrain) -> None: + """Append a single spike train (deep-copied) to the collection.""" train = nst.nstCopy() if not getattr(train, "name", ""): train.setName(str(self.numSpikeTrains + 1)) @@ -658,6 +757,7 @@ def addSingleSpikeToColl(self, nst: nspikeTrain) -> None: self.neighbors = [] def addToColl(self, nst: Sequence[nspikeTrain] | nspikeTrain | "SpikeTrainCollection") -> None: + """Add one or more spike trains (or another collection) to this collection.""" if isinstance(nst, SpikeTrainCollection): for train in nst.nstrain: self.addSingleSpikeToColl(train) @@ -674,24 +774,30 @@ def addToColl(self, nst: Sequence[nspikeTrain] | nspikeTrain | "SpikeTrainCollec raise TypeError("nstColl can only add nspikeTrain instances or sequences of nspikeTrain.") def merge(self, nstColl2: "SpikeTrainCollection") -> "SpikeTrainCollection": + """Merge another collection into this one (in-place).""" self.addToColl(nstColl2) return self def length(self) -> int: + """Return the number of spike trains (Matlab ``nstColl.length``).""" return int(self.numSpikeTrains) def getFirstSpikeTime(self) -> float: + """Return the earliest time boundary across all trains.""" return float(self.minTime) def getLastSpikeTime(self) -> float: + """Return the latest time boundary across all trains.""" return float(self.maxTime) def get_nst(self, idx: int) -> nspikeTrain: + """Return a spike train by 0-based index (Pythonic API).""" if idx < 0 or idx >= self.numSpikeTrains: raise IndexError("SpikeTrainCollection index out of bounds (0-based indexing).") return self.nstrain[idx] def getNST(self, idx) -> nspikeTrain | list[nspikeTrain]: + """Return spike train(s) by 1-based index (Matlab ``nstColl.getNST``).""" if isinstance(idx, Sequence) and not isinstance(idx, (str, bytes, np.ndarray)): return [self.getNST(int(item)) for item in idx] index = int(idx) @@ -714,10 +820,12 @@ def getNSTnames(self, selectorArray=None) -> list[str]: return [all_names[i] for i in indices if 0 <= i < len(all_names)] def getUniqueNSTnames(self, selectorArray=None) -> list[str]: + """Return unique, insertion-ordered neuron names.""" names = [name for name in self.getNSTnames(selectorArray) if name] return list(dict.fromkeys(names)) def getNSTIndicesFromName(self, name: Sequence[str] | str): + """Return 1-based index(es) for a neuron name (or list of names).""" if isinstance(name, str): matches = [i + 1 for i, value in enumerate(self.getNSTnames()) if value == name] if not matches: @@ -726,18 +834,21 @@ def getNSTIndicesFromName(self, name: Sequence[str] | str): return [self.getNSTIndicesFromName(item) for item in name] def getNSTnameFromInd(self, ind: int) -> str: + """Return the neuron name for 1-based index *ind*.""" index = int(ind) if index < 1 or index > self.numSpikeTrains: raise IndexError("Index is out of bounds!") return str(self.nstrain[index - 1].name) def getNSTFromName(self, neuronName=None): + """Return spike train(s) matching the given neuron name(s).""" if neuronName is None: neuronName = self.getUniqueNSTnames() indices = self.getNSTIndicesFromName(neuronName) return self.getNST(indices) def getFieldVal(self, fieldName: str): + """Collect a named field from every spike train (Matlab ``nstColl.getFieldVal``).""" fieldVal: list[float] = [] neuronNumbers: list[int] = [] cnt = 1 @@ -757,6 +868,7 @@ def getFieldVal(self, fieldName: str): return np.asarray(fieldVal, dtype=float), np.asarray(neuronNumbers, dtype=int) def shiftTime(self, timeShift: float | None = None) -> "SpikeTrainCollection": + """Return a new collection with spike times shifted by *timeShift*.""" if timeShift is None: timeShift = -float(self.minTime) shifted = [nspikeTrain(np.asarray(train.spikeTimes, dtype=float) + float(timeShift)) for train in self.nstrain] @@ -769,6 +881,11 @@ def toSpikeTrain( maxTime: float | None = None, windowTimes: Sequence[float] | None = None, ) -> nspikeTrain: + """Collapse selected spike trains into a single :class:`nspikeTrain`. + + Concatenates spike times end-to-end, optionally rescaling + each trial into windows defined by *windowTimes*. + """ if self.numSpikeTrains == 0: raise ValueError("nstColl.toSpikeTrain requires at least one spike train") @@ -824,6 +941,7 @@ def toSpikeTrain( return collapsed def setMinTime(self, value: float | None = None) -> None: + """Set the minimum time for every train in the collection.""" if value is None: value = self.minTime for train in self.nstrain: @@ -831,6 +949,7 @@ def setMinTime(self, value: float | None = None) -> None: self.minTime = float(value) def setMaxTime(self, value: float | None = None) -> None: + """Set the maximum time for every train in the collection.""" if value is None: value = self.maxTime for train in self.nstrain: @@ -838,6 +957,7 @@ def setMaxTime(self, value: float | None = None) -> None: self.maxTime = float(value) def resample(self, sampleRate: float) -> None: + """Resample all trains to *sampleRate* and align time bounds.""" self.sampleRate = float(sampleRate) for train in self.nstrain: train.resample(sampleRate) @@ -845,17 +965,20 @@ def resample(self, sampleRate: float) -> None: train.setMaxTime(float(self.maxTime)) def enforceSampleRate(self) -> None: + """Resample any train whose rate differs from the collection rate.""" for index in range(1, self.numSpikeTrains + 1): currSpike = self.getNST(index) if round(float(currSpike.sampleRate), 9) != round(float(self.sampleRate), 9): currSpike.resample(float(self.sampleRate)) def findMaxSampleRate(self) -> float: + """Return the highest sample rate among all trains.""" if self.numSpikeTrains == 0: return float("-inf") return float(max(train.sampleRate for train in self.nstrain)) def setMask(self, mask: Sequence[int] | np.ndarray) -> None: + """Set the neuron mask from a binary array or 1-based indices.""" arr = np.asarray(mask, dtype=int).reshape(-1) if arr.size == self.numSpikeTrains and np.all(np.isin(arr, [0, 1])): self.setNeuronMask(arr) @@ -863,6 +986,7 @@ def setMask(self, mask: Sequence[int] | np.ndarray) -> None: self.setNeuronMaskFromInd(arr) def setNeuronMaskFromInd(self, mask: Sequence[int] | np.ndarray) -> None: + """Set the neuron mask from 1-based neuron indices.""" arr = np.asarray(mask, dtype=int).reshape(-1) newMask = np.zeros(self.numSpikeTrains, dtype=int) if arr.size: @@ -872,24 +996,34 @@ def setNeuronMaskFromInd(self, mask: Sequence[int] | np.ndarray) -> None: self.setNeuronMask(newMask) def setNeuronMask(self, mask: Sequence[int] | np.ndarray) -> None: + """Set the binary neuron mask directly (length must match ``numSpikeTrains``).""" arr = np.asarray(mask, dtype=int).reshape(-1) if arr.size != self.numSpikeTrains: raise ValueError("neuronMask length must match number of spike trains.") self.neuronMask = arr.astype(int) def resetMask(self) -> None: + """Enable all neurons (ones-mask).""" self.neuronMask = np.ones(self.numSpikeTrains, dtype=int) def getIndFromMask(self) -> list[int]: + """Return 1-based indices of neurons currently enabled by the mask.""" return (np.flatnonzero(self.neuronMask == 1) + 1).astype(int).tolist() def getIndFromMaskMinusOne(self, neuron: int) -> list[int]: + """Return active indices excluding *neuron* (1-based).""" return [idx for idx in self.getIndFromMask() if idx != int(neuron)] def isNeuronMaskSet(self) -> bool: + """Return ``True`` if any neuron is currently masked out.""" return bool(np.any(self.neuronMask == 0)) def setNeighbors(self, neighborArray: Sequence[Sequence[int]] | np.ndarray | None = None) -> None: + """Set or auto-generate the neuron neighbour matrix. + + If *neighborArray* is ``None``, every neuron is a neighbour of + every other neuron (all-to-all minus self). + """ if neighborArray is None: if self.numSpikeTrains == 0: self.neighbors = [] @@ -907,9 +1041,11 @@ def setNeighbors(self, neighborArray: Sequence[Sequence[int]] | np.ndarray | Non self.neighbors = arr def areNeighborsSet(self) -> bool: + """Return ``True`` if the neighbour matrix has been initialized.""" return np.size(self.neighbors) > 0 def getNeighbors(self, neuronNum: int | Sequence[int]): + """Return the 1-based neighbour indices for one or more neurons.""" if isinstance(neuronNum, Sequence) and not isinstance(neuronNum, (str, bytes, np.ndarray)): rows = [self.getNeighbors(int(item)) for item in neuronNum] if rows and all(len(row) == len(rows[0]) for row in rows): @@ -926,6 +1062,7 @@ def getNeighbors(self, neuronNum: int | Sequence[int]): return [value for value in row if value in available and value > 0] def getMaxBinSizeBinary(self) -> float: + """Return the largest bin-width that keeps all active trains binary.""" selectorArray = self.getIndFromMask() if self.isNeuronMaskSet() else list(range(1, self.numSpikeTrains + 1)) if not selectorArray: return np.inf @@ -933,9 +1070,11 @@ def getMaxBinSizeBinary(self) -> float: return float(np.min(values)) def BinarySigRep(self) -> bool: + """Return ``True`` if every train's signal representation is binary.""" return bool(all(self.getNST(index).isSigRepBinary() for index in range(1, self.numSpikeTrains + 1))) def isSigRepBinary(self) -> bool: + """Alias for :meth:`BinarySigRep`.""" return self.BinarySigRep() def dataToMatrix( @@ -945,6 +1084,7 @@ def dataToMatrix( minTime: float | None = None, maxTime: float | None = None, ) -> np.ndarray: + """Return an ``(nTimeBins, nNeurons)`` binary spike-count matrix.""" if self.numSpikeTrains == 0: return np.zeros((0, 0), dtype=float) if maxTime is None: @@ -974,6 +1114,7 @@ def dataToMatrix( return dataMat def getEnsembleNeuronCovariates(self, neuronNum: int = 1, neighborIndex=None, windowTimes=None): + """Build ensemble-history covariates for *neuronNum* from its neighbours.""" if neighborIndex is None or ( isinstance(neighborIndex, (list, tuple, np.ndarray)) and np.asarray(neighborIndex).size == 0 ): @@ -991,6 +1132,7 @@ def getEnsembleNeuronCovariates(self, neuronNum: int = 1, neighborIndex=None, wi return ensembleCovariates def addNeuronNamesToEnsCovColl(self, ensembleCovariates: CovariateCollection) -> None: + """Prefix ensemble-covariate labels with their neuron name.""" for i in range(1, ensembleCovariates.numCov + 1): tempCov = ensembleCovariates.covArray[i - 1] name = self.getNST(i).name @@ -1000,6 +1142,7 @@ def addNeuronNamesToEnsCovColl(self, ensembleCovariates: CovariateCollection) -> tempCov.setDataLabels(dataLabels) def restoreToOriginal(self, rMask: int = 0) -> None: + """Restore all trains to their original state; optionally reset the mask.""" for train in self.nstrain: train.restoreToOriginal() self._refresh_summary() @@ -1009,11 +1152,13 @@ def restoreToOriginal(self, rMask: int = 0) -> None: self.resetMask() def ensureConsistancy(self) -> None: + """Enforce consistent sample rate and time bounds across all trains.""" self.enforceSampleRate() self.setMinTime() self.setMaxTime() def updateTimes(self, nst: nspikeTrain) -> None: + """Expand collection time bounds to include *nst*, or clamp *nst*.""" if float(nst.minTime) <= float(self.minTime): self.setMinTime(float(nst.minTime)) else: @@ -1064,10 +1209,12 @@ def plot(self, selectorArray: Sequence[int] | None = None, return ax def getMinISIs(self, selectorArray: Sequence[int] | None = None, minTime: float | None = None, maxTime: float | None = None) -> np.ndarray: + """Return the minimum ISI for each selected neuron.""" isis = self.getISIs(selectorArray, minTime, maxTime) return np.asarray([float(np.min(values)) if values.size else 0.0 for values in isis], dtype=float) def getISIs(self, selectorArray: Sequence[int] | None = None, minTime: float | None = None, maxTime: float | None = None) -> list[np.ndarray]: + """Return a list of ISI arrays, one per selected neuron.""" if maxTime is None: maxTime = self.maxTime if minTime is None: @@ -1077,6 +1224,7 @@ def getISIs(self, selectorArray: Sequence[int] | None = None, minTime: float | N return [self.getNST(int(neuron)).getISIs(minTime, maxTime) for neuron in selectorArray] def plotISIHistogram(self, selectorArray: Sequence[int] | None = None, minTime: float | None = None, maxTime: float | None = None, handle=None): + """Plot ISI histograms for each selected neuron in stacked subplots.""" if maxTime is None: maxTime = self.maxTime if minTime is None: @@ -1101,6 +1249,7 @@ def plotExponentialFit( numBins: int | None = None, handle=None, ): + """Plot exponential-distribution fits of ISIs for selected neurons.""" if maxTime is None: maxTime = self.maxTime if minTime is None: @@ -1118,6 +1267,7 @@ def plotExponentialFit( return fig def getSpikeTimes(self, minTime: float | None = None, maxTime: float | None = None) -> list[np.ndarray]: + """Return a list of spike-time arrays, one per active neuron.""" del minTime, maxTime selector = self.getIndFromMask() if self.isNeuronMaskSet() else list(range(1, self.numSpikeTrains + 1)) return [self.getNST(int(index)).getSpikeTimes() for index in selector] @@ -1129,6 +1279,10 @@ def psth( minTime: float | None = None, maxTime: float | None = None, ) -> Covariate: + """Compute the peri-stimulus time histogram (standard binned PSTH). + + Returns a :class:`Covariate` with firing rate in Hz. + """ if binwidth <= 0: raise ValueError("binwidth must be > 0") min_time = self.minTime if minTime is None else float(minTime) @@ -1438,6 +1592,11 @@ def estimateVarianceAcrossTrials( numIter: int | None = None, fitType: str | None = None, ) -> np.ndarray: + """Estimate the state-noise covariance ``Q`` from bootstrap GLM fits. + + Used internally by :meth:`ssglm` / :meth:`ssglmFB` to initialise + the EM algorithm's state-noise prior. + """ if fitType is None or fitType == "": fitType = "poisson" if numIter is None: @@ -1501,6 +1660,11 @@ def estimateVarianceAcrossTrials( @staticmethod def generateUnitImpulseBasis(basisWidth: float, minTime: float, maxTime: float, sampleRate: float = 1000.0) -> Covariate: + """Create a piecewise-constant (unit impulse) basis :class:`Covariate`. + + Each column is a rectangular pulse spanning one *basisWidth* + interval, used as the design matrix for GLM-PSTH estimation. + """ windowTimes = np.arange(float(minTime), float(maxTime), float(basisWidth)) if windowTimes.size == 0 or not np.isclose(windowTimes[-1], maxTime): windowTimes = np.append(windowTimes, float(maxTime)) @@ -1733,7 +1897,31 @@ def fromStructure(structure: dict[str, Any]) -> "SpikeTrainCollection": class TrialConfig: - """MATLAB-style TrialConfig with configuration-application semantics.""" + """Description of a single GLM fit configuration. + + A ``TrialConfig`` specifies which covariates, history, ensemble + history, and sample rate to apply to a :class:`Trial` before fitting. + Multiple ``TrialConfig`` objects are collected in a + :class:`ConfigCollection` to run a batch of nested-model comparisons. + + Parameters + ---------- + covMask : sequence of str or nested sequences, or None + Covariate labels to include in the design matrix. + ``'all'`` includes every covariate. + sampleRate : float or None + If provided, the trial is resampled to this rate before fitting. + history : History or array_like or None + Self-history specification (History object or window-times). + ensCovHist : History or array_like or None + Ensemble-history specification. + ensCovMask : array_like or None + Binary mask selecting which neighbours contribute ensemble history. + covLag : array_like or None + Covariate shift / lag specification. + name : str + Human-readable name for this configuration. + """ def __init__( self, @@ -1755,6 +1943,7 @@ def __init__( @property def covariate_names(self) -> list[str]: + """Return the name of each covariate group in the mask.""" if not self.covMask: return [] names: list[str] = [] @@ -1766,12 +1955,19 @@ def covariate_names(self) -> list[str]: return names def getName(self) -> str: + """Return this configuration's human-readable name.""" return self.name def setName(self, name: str) -> None: + """Set this configuration's human-readable name.""" self.name = str(name) def setConfig(self, trial: "Trial") -> None: + """Apply this configuration to a Trial (in place). + + Sets the covariate mask, history, ensemble history, sample rate, + and covariate lag on the trial. + """ if not _is_empty_config_value(self.history): trial.setHistory(self.history) else: @@ -1795,6 +1991,7 @@ def setConfig(self, trial: "Trial") -> None: trial.setEnsCovMask() def toStructure(self) -> dict[str, Any]: + """Serialize to a plain dict (Matlab ``TrialConfig.toStructure``).""" return { "covMask": self.covMask, "sampleRate": self.sampleRate, @@ -1807,6 +2004,9 @@ def toStructure(self) -> dict[str, Any]: @staticmethod def fromStructure(structure: dict[str, Any]) -> "TrialConfig": + """Reconstruct from a dict produced by :meth:`toStructure`. + + .. note:: Follows Matlab's omission of ``ensCovMask``.""" # MATLAB's `TrialConfig.fromStructure` omits `ensCovMask` and shifts # the remaining trailing arguments left by one position. return TrialConfig( @@ -1820,7 +2020,18 @@ def fromStructure(structure: dict[str, Any]) -> "TrialConfig": class ConfigCollection: - """MATLAB-style ConfigColl implementation.""" + """Ordered collection of :class:`TrialConfig` objects. + + Used by :class:`Analysis` to iterate over multiple model + specifications (e.g. baseline, baseline + stimulus, + baseline + stimulus + history) and compare their fits. + + Parameters + ---------- + configs : TrialConfig, sequence of TrialConfig, or None + Initial configuration(s). ``None`` creates a single + ``"Empty Config"`` entry (Matlab parity). + """ def __init__(self, configs: Sequence[TrialConfig] | TrialConfig | str | None = None) -> None: self.numConfigs = 0 @@ -1832,12 +2043,15 @@ def __init__(self, configs: Sequence[TrialConfig] | TrialConfig | str | None = N @property def configs(self) -> list[TrialConfig]: + """List of actual ``TrialConfig`` entries (excludes empty placeholders).""" return [cfg for cfg in self.configArray if isinstance(cfg, TrialConfig)] def add_config(self, cfg: TrialConfig) -> None: + """Pythonic alias for :meth:`addConfig`.""" self.addConfig(cfg) def addConfig(self, cfg: Sequence[TrialConfig] | TrialConfig | str | None) -> None: + """Append one or more configurations to this collection.""" if isinstance(cfg, Sequence) and not isinstance(cfg, (str, bytes, TrialConfig, np.ndarray)): if len(cfg) == 0: self.numConfigs += 1 @@ -1863,16 +2077,19 @@ def addConfig(self, cfg: Sequence[TrialConfig] | TrialConfig | str | None) -> No raise TypeError("ConfigColl can only add TrialConfig objects, strings, or sequences of them.") def get_config(self, idx: int) -> TrialConfig | str | list[str]: + """Return a config by 0-based index (Pythonic API).""" if idx < 0 or idx >= self.numConfigs: raise IndexError("ConfigCollection index out of bounds (0-based indexing).") return self.configArray[idx] def getConfig(self, idx: int): + """Return a config by 1-based index (Matlab ``ConfigColl.getConfig``).""" if idx < 1 or idx > self.numConfigs: raise IndexError("Index Out of Bounds") return self.configArray[idx - 1] def setConfig(self, trial: "Trial", index: int) -> None: + """Apply configuration *index* (1-based) to the given Trial.""" config = self.getConfig(index) if isinstance(config, TrialConfig): config.setConfig(trial) @@ -1880,6 +2097,7 @@ def setConfig(self, trial: "Trial", index: int) -> None: raise ValueError("Cannot Set Empty Configs") def getConfigNames(self, index: Sequence[int] | None = None) -> list[str]: + """Return the names for selected configs (1-based), or all if *index* is ``None``.""" if index is None: index = list(range(1, self.numConfigs + 1)) out: list[str] = [] @@ -1891,6 +2109,7 @@ def getConfigNames(self, index: Sequence[int] | None = None) -> list[str]: return out def setConfigNames(self, names, index: Sequence[int] | None = None) -> None: + """Set the human-readable names for configs at 1-based *index* positions.""" if index is None: index = list(range(1, self.numConfigs + 1)) if isinstance(names, str): @@ -1910,10 +2129,12 @@ def setConfigNames(self, names, index: Sequence[int] | None = None) -> None: raise TypeError("names must be a string or sequence of strings.") def getSubsetConfigs(self, subset: Sequence[int]) -> "ConfigCollection": + """Return a new collection containing only configs at 1-based *subset* indices.""" tempconfigs = [self.getConfig(int(i)) for i in subset] return ConfigCollection(tempconfigs) def toStructure(self) -> dict[str, Any]: + """Serialize to a plain dict (Matlab ``ConfigColl.toStructure``).""" structure = { "numConfigs": self.numConfigs, "configNames": list(self.configNames), @@ -1928,6 +2149,7 @@ def toStructure(self) -> dict[str, Any]: @staticmethod def fromStructure(structure: dict[str, Any]) -> "ConfigCollection": + """Reconstruct from a dict produced by :meth:`toStructure`.""" configs = [] for row in structure.get("configArray", []): if isinstance(row, dict): @@ -1938,7 +2160,32 @@ def fromStructure(structure: dict[str, Any]) -> "ConfigCollection": class Trial: - """MATLAB-style Trial object preserving collection-level workflow semantics.""" + """Single-trial data container binding spikes, covariates, and events (Matlab ``Trial``). + + A ``Trial`` enforces consistent time bounds and sample rate across + its spike collection, covariate collection, and optional event stream. + It provides the design-matrix construction used by :class:`Analysis` + to fit point-process GLMs. + + Parameters + ---------- + spike_collection : SpikeTrainCollection + Neural spike data. + covariate_collection : CovariateCollection + Stimulus or task covariates. + events : Events, optional + Discrete event markers. + hist : History or array_like, optional + Self-history specification. + ensCovHist : History or array_like, optional + Ensemble-history specification. + ensCovMask : array_like, optional + Binary mask for ensemble neighbours. + + See Also + -------- + SpikeTrainCollection, CovariateCollection, Analysis + """ def __init__( self, @@ -1991,20 +2238,25 @@ def __init__( @property def spike_collection(self) -> SpikeTrainCollection: + """The trial's spike-train collection.""" return self.nspikeColl @property def covariate_collection(self) -> CovariateCollection: + """The trial's covariate collection.""" return self.covarColl @property def spikeColl(self) -> SpikeTrainCollection: + """Alias for :attr:`spike_collection` (Matlab compat).""" return self.nspikeColl def setTrialEvents(self, event: Events | None) -> None: + """Attach an :class:`Events` object (or ``None`` to clear).""" self.ev = event if isinstance(event, Events) else None def getEvents(self) -> Events | None: + """Return the attached Events, or ``None``.""" return self.ev @property @@ -2016,6 +2268,7 @@ def covarColl(self, value: CovariateCollection) -> None: self._covarColl = value def getTrialPartition(self) -> np.ndarray: + """Return ``[trainMin, trainMax, valMin, valMax]`` partition times.""" training = [] if self.trainingWindow is None else list(self.trainingWindow) validation = [] if self.validationWindow is None else list(self.validationWindow) p = training + validation @@ -2024,6 +2277,7 @@ def getTrialPartition(self) -> np.ndarray: return np.asarray(p, dtype=float) def setTrialPartition(self, partitionTimes) -> None: + """Set training and validation time windows from a 3- or 4-element array.""" if partitionTimes is None or len(partitionTimes) == 0: partitionTimes = self.getTrialPartition() values = np.asarray(partitionTimes, dtype=float).reshape(-1) @@ -2041,6 +2295,7 @@ def setTrialPartition(self, partitionTimes) -> None: self.setMaxTime(trainingWindow[1]) def setTrialTimesFor(self, partitionName: str = "training") -> None: + """Set trial time bounds to either the ``'training'`` or ``'validation'`` window.""" p = self.getTrialPartition() if partitionName == "training": timeWindow = p[:2] @@ -2052,6 +2307,7 @@ def setTrialTimesFor(self, partitionName: str = "training") -> None: self.setMaxTime(float(timeWindow[1])) def setMinTime(self, minTime: float | None = None) -> None: + """Set minimum time across spikes, covariates, and ensemble covariates.""" if minTime is None: minTime = self.findMinTime() self.nspikeColl.setMinTime(float(minTime)) @@ -2061,6 +2317,7 @@ def setMinTime(self, minTime: float | None = None) -> None: self.minTime = float(minTime) def setMaxTime(self, maxTime: float | None = None) -> None: + """Set maximum time across spikes, covariates, and ensemble covariates.""" if maxTime is None: maxTime = self.findMaxTime() self.nspikeColl.setMaxTime(float(maxTime)) @@ -2070,6 +2327,7 @@ def setMaxTime(self, maxTime: float | None = None) -> None: self.maxTime = float(maxTime) def updateTimePartitions(self) -> None: + """Clamp training/validation windows to current min/max time.""" if not (np.isfinite(self.minTime) and np.isfinite(self.maxTime)): return p = self.getTrialPartition() @@ -2081,7 +2339,84 @@ def updateTimePartitions(self) -> None: newValMax = min(self.maxTime, validation[1]) self.setTrialPartition([newTrainMin, newTrainMax, newValMin, newValMax]) + def plotRaster(self, handle=None): + """Plot only the spike raster for this trial. + + Parameters + ---------- + handle : matplotlib Figure or Axes, optional + If an ``Axes`` is provided the raster is drawn there. + If a ``Figure`` is provided a new axes is added. + If *None* a new figure is created. + + Returns + ------- + matplotlib.figure.Figure + """ + if handle is None: + fig, ax = plt.subplots(figsize=(9.0, 3.0)) + elif isinstance(handle, plt.Axes): + ax = handle + fig = ax.figure + else: + fig = handle + fig.clear() + ax = fig.add_subplot(111) + self.nspikeColl.plot(handle=ax) + ax.set_title("Trial Spike Raster") + fig.tight_layout() + return fig + + def plotCovariates(self, handle=None): + """Plot covariates (and events, if set) for this trial. + + Layout adapts to the number of active covariates, following the + Matlab ``Trial.plotCovariates`` behaviour. + + Parameters + ---------- + handle : matplotlib Figure, optional + Figure to draw on. If *None* a new figure is created. + + Returns + ------- + matplotlib.figure.Figure + """ + numCovars = self.covarColl.nActCovar() + if handle is None: + fig = plt.figure(figsize=(9.0, max(4.0, 2.2 * max(numCovars, 1)))) + else: + fig = handle + fig.clear() + + if numCovars <= 1: + ax = fig.add_subplot(111) + self.covarColl.plot(handle=ax) + if self.ev is not None and self.ev.eventTimes.size: + self.ev.plot(handle=ax) + elif numCovars == 2: + ax1 = fig.add_subplot(1, 2, 1) + ax2 = fig.add_subplot(1, 2, 2) + self.covarColl.plot(handle=[ax1, ax2]) + if self.ev is not None and self.ev.eventTimes.size: + self.ev.plot(handle=[ax1, ax2]) + else: + axes = [fig.add_subplot(numCovars, 1, i + 1) + for i in range(numCovars)] + self.covarColl.plot(handle=axes) + if self.ev is not None and self.ev.eventTimes.size: + self.ev.plot(handle=axes) + + fig.tight_layout() + return fig + def plot(self, *_, handle=None, **__): + """Plot spike raster, covariates, and events in a multi-panel figure. + + Returns + ------- + matplotlib.figure.Figure + """ cov_count = max(self.covarColl.numCov, 1) event_count = 1 if self.ev is not None and self.ev.eventTimes.size else 0 panel_count = 1 + cov_count + event_count @@ -2110,21 +2445,25 @@ def plot(self, *_, handle=None, **__): return fig def setSampleRate(self, sampleRate: float) -> None: + """Resample spikes, covariates, and ensemble covariates to *sampleRate*.""" self.sampleRate = float(sampleRate) self.nspikeColl.resample(sampleRate) self.covarColl.resample(sampleRate) self.resampleEnsColl() def resample(self, sampleRate: float) -> None: + """Alias for :meth:`setSampleRate`.""" self.setSampleRate(sampleRate) def setEnsCovMask(self, mask=None) -> None: + """Set the ensemble-covariate neighbour mask (default: all-to-all minus self).""" if _is_empty_config_value(mask): nSpikes = self.nspikeColl.numSpikeTrains mask = np.ones((nSpikes, nSpikes), dtype=int) - np.eye(nSpikes, dtype=int) self.ensCovMask = np.asarray(mask, dtype=int) def setCovMask(self, mask) -> None: + """Set the covariate mask; ``'all'`` resets to full visibility.""" if isinstance(mask, str) and mask == "all": self.covarColl.resetMask() else: @@ -2132,21 +2471,33 @@ def setCovMask(self, mask) -> None: self.covMask = self.covarColl.covMask def resetCovMask(self) -> None: + """Reset the covariate mask to all-visible.""" self.covarColl.resetMask() self.covMask = self.covarColl.covMask def setNeuronMask(self, mask) -> None: + """Set the neuron (spike-train) mask and sync to ``self.neuronMask``.""" self.nspikeColl.setMask(mask) self.neuronMask = np.asarray(self.nspikeColl.neuronMask, dtype=int).copy() def resetNeuronMask(self) -> None: + """Reset the neuron mask to all-visible.""" self.nspikeColl.resetMask() self.neuronMask = np.asarray(self.nspikeColl.neuronMask, dtype=int).copy() def setNeighbors(self, *args) -> None: + """Set the neighbour structure for ensemble-history covariates.""" self.nspikeColl.setNeighbors(*args) def setHistory(self, hist) -> None: + """Set the spike-history configuration. + + Parameters + ---------- + hist : History, array-like, or list[History] + A ``History`` object, an array of window-edge times (seconds), or + a list of ``History`` objects for per-neuron history orders. + """ if _is_empty_config_value(hist): self.history = [] return @@ -2175,9 +2526,18 @@ def setHistory(self, hist) -> None: raise TypeError("Can only set trial history by using History objects or windowTimes") def resetHistory(self) -> None: + """Clear the spike-history configuration.""" self.history = [] def setEnsCovHist(self, hist=None) -> None: + """Set the ensemble-covariate history and rebuild the ensemble collection. + + Parameters + ---------- + hist : History or array-like, optional + A ``History`` object or window-edge array. Passing ``None`` + clears the ensemble history and removes the ``ensCovColl``. + """ if _is_empty_config_value(hist): self.ensCovHist = [] self.ensCovColl = None @@ -2203,15 +2563,19 @@ def setEnsCovHist(self, hist=None) -> None: self.ensCovColl = self.getEnsembleNeuronCovariates(1, [], self.ensCovHist) def isNeuronMaskSet(self) -> bool: + """Return ``True`` if any neuron is currently masked out.""" return self.nspikeColl.isNeuronMaskSet() def isCovMaskSet(self) -> bool: + """Return ``True`` if any covariate dimension is currently masked out.""" return self.covarColl.isCovMaskSet() def isMaskSet(self) -> bool: + """Return ``True`` if either the neuron or covariate mask is active.""" return self.isNeuronMaskSet() or self.isCovMaskSet() def isHistSet(self) -> bool: + """Return ``True`` if a spike-history configuration has been set.""" if self.history in (None, []): return False from .history import History @@ -2221,6 +2585,7 @@ def isHistSet(self) -> bool: return isinstance(self.history, list) and bool(self.history) and all(isinstance(item, History) for item in self.history) def isEnsCovHistSet(self) -> bool: + """Return ``True`` if an ensemble-covariate history has been set.""" from .history import History return isinstance(self.ensCovHist, History) @@ -2254,6 +2619,7 @@ def getNumHist(self) -> int | list[int]: return 0 def addCov(self, cov: Covariate) -> None: + """Add a covariate and enforce consistent sample rate / time bounds.""" self.covarColl.addToColl(cov) self.covMask = self.covarColl.covMask if not self.isSampleRateConsistent(): @@ -2261,6 +2627,7 @@ def addCov(self, cov: Covariate) -> None: self.makeConsistentTime() def removeCov(self, identifier: int | str) -> None: + """Remove a covariate by 1-based index or name.""" self.covarColl.removeCovariate(identifier) self.covMask = self.covarColl.covMask if not self.isSampleRateConsistent(): @@ -2268,6 +2635,17 @@ def removeCov(self, identifier: int | str) -> None: self.makeConsistentTime() def getSpikeVector(self, *args, neuron_index: int = 1) -> np.ndarray: + """Return the spike data as a column matrix. + + Parameters + ---------- + *args + When empty, returns all neurons via ``dataToMatrix()``. An int + selects a single neuron (1-based). A sequence of bin edges + returns binned counts for the neuron given by *neuron_index*. + neuron_index : int, default 1 + Neuron to bin when *args* provides bin edges (1-based). + """ if not args: return self.nspikeColl.dataToMatrix() first = args[0] @@ -2282,9 +2660,16 @@ def getSpikeVector(self, *args, neuron_index: int = 1) -> np.ndarray: return self.nspikeColl.dataToMatrix(*args) def get_covariate_matrix(self, selected_covariates: Sequence[str] | None = None) -> tuple[np.ndarray, np.ndarray, list[str]]: + """Return ``(time, data, names)`` for the covariate collection.""" return self.covarColl.matrixWithTime("standard", selected_covariates) def getDesignMatrix(self, neuronNum: int, dataSelector=None) -> np.ndarray: + """Build the full design matrix for neuron *neuronNum* (1-based). + + Horizontally concatenates covariates, spike-history columns, and + ensemble-history columns — the complete regressor matrix used by + the GLM fitter. + """ X = self.covarColl.dataToMatrix("standard", dataSelector) if self.isHistSet(): H = self.getHistMatrices(neuronNum) @@ -2305,6 +2690,19 @@ def getDesignMatrix(self, neuronNum: int, dataSelector=None) -> np.ndarray: return X def getHistForNeurons(self, neuronIndex) -> CovariateCollection: + """Compute the spike-history covariates for one neuron. + + Parameters + ---------- + neuronIndex : int + 1-based neuron index whose spike train supplies the history. + + Returns + ------- + CovariateCollection + Collection of history-basis covariates aligned to the trial + time grid. + """ if not self.isHistSet(): raise ValueError("Set Trial history and retry") nst = self.nspikeColl.getNST(neuronIndex) @@ -2319,6 +2717,7 @@ def getHistForNeurons(self, neuronIndex) -> CovariateCollection: return self.history.computeHistory(nst, time_grid=target_time) def getHistMatrices(self, neuronIndex: int) -> np.ndarray: + """Return the spike-history columns as a 2-D array for *neuronIndex* (1-based).""" if not self.isHistSet(): time = self.nspikeColl.getNST(neuronIndex).getSigRep().time return np.zeros((time.size, 0), dtype=float) @@ -2326,9 +2725,15 @@ def getHistMatrices(self, neuronIndex: int) -> np.ndarray: return histCovColl.dataToMatrix("standard") def getEnsembleNeuronCovariates(self, *args): + """Delegate to ``SpikeTrainCollection.getEnsembleNeuronCovariates``.""" return self.nspikeColl.getEnsembleNeuronCovariates(*args) def getEnsCovMatrix(self, neuronNum: int, includedNeurons=None) -> np.ndarray: + """Return the ensemble-covariate design-matrix columns for *neuronNum*. + + Uses ``ensCovMask`` to exclude self-history and applies neighbour + filtering when *includedNeurons* is not specified. + """ if not self.isEnsCovHistSet() or self.ensCovColl is None: return np.zeros((self.nspikeColl.getNST(neuronNum).getSigRep().time.size, 0), dtype=float) if includedNeurons is None: @@ -2339,18 +2744,23 @@ def getEnsCovMatrix(self, neuronNum: int, includedNeurons=None) -> np.ndarray: return ensCovCollTemp.dataToMatrix("standard") def getNeuronIndFromMask(self) -> list[int]: + """Return 1-based indices of currently unmasked neurons.""" return self.nspikeColl.getIndFromMask() def getNumUniqueNeurons(self) -> int: + """Return the number of distinct neuron names in the collection.""" return len(self.nspikeColl.uniqueNeuronNames) def getNeuronNames(self) -> list[str]: + """Return all neuron names (may contain duplicates for repeated trials).""" return self.nspikeColl.getNSTnames() def getUniqueNeuronNames(self) -> list[str]: + """Return deduplicated neuron names.""" return self.nspikeColl.getUniqueNSTnames() def getNeuronIndFromName(self, neuronName: str): + """Return 1-based indices matching *neuronName*, filtered by the neuron mask.""" tempInd = self.nspikeColl.getNSTIndicesFromName(neuronName) currMask = set(self.neuronMask_indices()) if isinstance(tempInd, list): @@ -2358,39 +2768,49 @@ def getNeuronIndFromName(self, neuronName: str): return [tempInd] if tempInd in currMask else [] def neuronMask_indices(self) -> list[int]: + """Return 1-based indices of unmasked neurons (alias for ``getNeuronIndFromMask``).""" return self.nspikeColl.getIndFromMask() def getNeuronNeighbors(self, neuronNum=None): + """Return the neighbour list for *neuronNum* (defaults to all unmasked neurons).""" if neuronNum is None: neuronNum = self.getNeuronIndFromMask() return self.nspikeColl.getNeighbors(neuronNum) def getCovSelectorFromMask(self): + """Return the per-covariate selector list derived from the current mask.""" return self.covarColl.getSelectorFromMasks() def getCov(self, identifier): + """Return a ``Covariate`` by 1-based index or name.""" return self.covarColl.getCov(identifier) def getNeuron(self, identifier): + """Return an ``nspikeTrain`` by 1-based index or name.""" return self.nspikeColl.getNST(identifier) def getAllCovLabels(self) -> list[str]: + """Return labels for all covariate dimensions (ignoring mask).""" return self.covarColl.getAllCovLabels() def getCovLabelsFromMask(self) -> list[str]: + """Return labels for only the currently unmasked covariate dimensions.""" return self.covarColl.getCovLabelsFromMask() def getHistLabels(self) -> list[str]: + """Return string labels for all spike-history basis columns.""" if not self.isHistSet(): return [] return self.getHistForNeurons(1).getAllCovLabels() def getEnsCovLabels(self) -> list[str]: + """Return string labels for all ensemble-covariate columns.""" if not self.isEnsCovHistSet() or self.ensCovColl is None: return [] return self.ensCovColl.getAllCovLabels() def getEnsCovLabelsFromMask(self, neuronNum: int) -> list[str]: + """Return ensemble-covariate labels for *neuronNum*, filtered by ``ensCovMask``.""" if not self.isEnsCovHistSet() or self.ensCovColl is None: return [] included = np.flatnonzero(self.ensCovMask[:, neuronNum - 1] == 1) + 1 @@ -2410,15 +2830,18 @@ def getAllLabels(self) -> list[str]: return labels def getLabelsFromMask(self, neuronNum: int) -> list[str]: + """Return all design-matrix labels for *neuronNum*, respecting masks.""" labels = list(self.getCovLabelsFromMask()) labels.extend(self.getHistLabels()) labels.extend(self.getEnsCovLabelsFromMask(neuronNum)) return labels def flattenCovMask(self) -> np.ndarray: + """Flatten the per-covariate mask list into a single 1-D int array.""" return self.covarColl.flattenCovMask() def flattenMask(self) -> np.ndarray: + """Flatten the full mask (covariates + history + ensemble) into 1-D.""" flat = self.flattenCovMask() if self.isHistSet(): flat = np.concatenate([flat, np.ones(len(self.getHistLabels()), dtype=int)]) @@ -2427,19 +2850,23 @@ def flattenMask(self) -> np.ndarray: return flat def shiftCovariates(self, *args) -> None: + """Apply a time shift to covariates and re-synchronize time bounds.""" self.covarColl.setCovShift(*args) self.makeConsistentTime() def resetEnsCovMask(self) -> None: + """Reset the ensemble-covariate mask to the default (all-to-all minus self).""" self.setEnsCovMask() def resampleEnsColl(self) -> None: + """Rebuild the ensemble-covariate collection at the current sample rate.""" if self.ensCovColl is not None and self.ensCovHist not in (None, []): self.ensCovColl = self.getEnsembleNeuronCovariates(1, [], self.ensCovHist) else: self.setEnsCovHist() def restoreToOriginal(self) -> None: + """Reset all collections to their original state and re-synchronize.""" self.nspikeColl.restoreToOriginal() self.covarColl.restoreToOriginal() if not self.isSampleRateConsistent(): @@ -2496,13 +2923,16 @@ def fromStructure(structure: dict[str, Any]) -> "Trial": return trial def makeConsistentSampleRate(self) -> None: + """Resample all collections to the maximum sample rate found.""" self.resample(self.findMaxSampleRate()) def makeConsistentTime(self) -> None: + """Set all collections to the union of min/max time across sub-collections.""" self.setMinTime(self.findMinTime()) self.setMaxTime(self.findMaxTime()) def isSampleRateConsistent(self) -> bool: + """Return ``True`` if spike and covariate collections share the same sample rate.""" if self.nspikeColl.numSpikeTrains == 0 or self.covarColl.numCov == 0: return True target = round(float(self.findMaxSampleRate()), 3) @@ -2510,6 +2940,7 @@ def isSampleRateConsistent(self) -> bool: return all(value == target for value in values) def findMaxSampleRate(self) -> float: + """Return the maximum sample rate across spike and covariate collections.""" values = [value for value in [self.nspikeColl.findMaxSampleRate(), self.covarColl.findMaxSampleRate()] if np.isfinite(value)] return float(max(values)) if values else float("nan") @@ -2536,9 +2967,11 @@ def findMinSampleRate(self) -> float: return float(min(candidates)) if candidates else float("nan") def findMinTime(self) -> float: + """Return the earliest start time across sub-collections.""" return float(min(self.nspikeColl.minTime, self.covarColl.minTime)) def findMaxTime(self) -> float: + """Return the latest end time across sub-collections.""" return float(max(self.nspikeColl.maxTime, self.covarColl.maxTime))