Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion nstat/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def run_analysis_for_neuron(
merged_lambda = merged_lambda.merge(part)

_restore_trial_partition(trial, original_partition)
return FitResult(
fit_result = FitResult(
spike_train,
labels,
numHist,
Expand All @@ -346,6 +346,10 @@ def run_analysis_for_neuron(
distributions,
fits=fits,
)
# MATLAB returns fits with KS diagnostics already populated, and
# downstream summary classes read those cached fields directly.
fit_result.computeKSStats()
return fit_result

@staticmethod
def run_analysis_for_all_neurons(
Expand Down
111 changes: 80 additions & 31 deletions nstat/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,9 @@ def setPlotProps(self, plotProps: Sequence[Any] | str | None, index: int | None
self.plotProps = [plotProps for _ in range(self.dimension)]
else:
props = list(plotProps)
if len(props) == 0:
self.plotProps = [None for _ in range(self.dimension)]
return
if len(props) == 1 and self.dimension > 1:
props = props * self.dimension
if len(props) != self.dimension:
Expand Down Expand Up @@ -945,6 +948,9 @@ def setConfInterval(self, bounds: tuple[np.ndarray, np.ndarray]) -> None:

def dataToStructure(self, selectorArray: Sequence[int] | np.ndarray | None = None) -> dict[str, Any]:
data = self.dataToMatrix(selectorArray)
plot_props = list(self.plotProps)
if all(prop is None for prop in plot_props):
plot_props = []
return {
"time": self.time.tolist(),
"data": data.tolist(),
Expand All @@ -953,7 +959,7 @@ def dataToStructure(self, selectorArray: Sequence[int] | np.ndarray | None = Non
"xunits": self.xunits,
"yunits": self.yunits,
"dataLabels": list(self.dataLabels),
"plotProps": list(self.plotProps),
"plotProps": plot_props,
}

def toStructure(self) -> dict[str, Any]:
Expand All @@ -974,6 +980,7 @@ def signalFromStruct(structure: dict[str, Any]) -> "SignalObj":

def plot(self, selectorArray=None, plotPropsIn=None, handle=None):
import matplotlib.pyplot as plt
from .confidence_interval import MATLAB_COLOR_ORDER

ax = plt.gca() if handle is None else handle
signal = self.getSubSignal(selectorArray) if selectorArray is not None else self.getSubSignal(self.findIndFromDataMask() or list(range(1, self.dimension + 1)))
Expand All @@ -989,6 +996,8 @@ def plot(self, selectorArray=None, plotPropsIn=None, handle=None):
prop = props[index]
if isinstance(prop, str) and prop:
kwargs["fmt"] = prop
elif prop is None:
kwargs["color"] = MATLAB_COLOR_ORDER[index % MATLAB_COLOR_ORDER.shape[0]]
if "fmt" in kwargs:
fmt = kwargs.pop("fmt")
line = ax.plot(signal.time, signal.data[:, index], fmt, **kwargs)
Expand Down Expand Up @@ -1075,6 +1084,7 @@ def plot(self, selectorArray=None, plotPropsIn=None, handle=None):
lines = super().plot(selectorArray, plotPropsIn, handle)
if self.isConfIntervalSet():
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

ax = plt.gca() if handle is None else handle
selectors = self.findIndFromDataMask() if selectorArray is None else (
Expand All @@ -1086,6 +1096,8 @@ def plot(self, selectorArray=None, plotPropsIn=None, handle=None):
selectors = [item[0] for item in selectors]
for line_index, selector in enumerate(selectors):
color = getattr(lines[line_index], "get_color", lambda: "b")()
if isinstance(color, (str, bytes)):
color = mcolors.to_rgb(color)
self.ci[selector - 1].plot(color, ax=ax)
return lines

Expand Down Expand Up @@ -1176,17 +1188,37 @@ def toStructure(self) -> dict[str, Any]:
if self.isConfIntervalSet():
ci_payload: list[dict[str, Any]] = []
for item in self.ci or []:
if hasattr(item, "time") and hasattr(item, "bounds"):
ci_payload.append(
{
"time": np.asarray(item.time, dtype=float).tolist(),
"bounds": np.asarray(item.bounds, dtype=float).tolist(),
"color": getattr(item, "color", "b"),
}
)
structure["ci"] = ci_payload
if hasattr(item, "dataToStructure"):
ci_payload.append(item.dataToStructure())
if ci_payload:
structure["ci"] = ci_payload[0] if len(ci_payload) == 1 else ci_payload
return structure

@staticmethod
def fromStructure(structure: dict[str, Any]) -> "Covariate":
from .confidence_interval import ConfidenceInterval

cov = Covariate(
structure["time"],
structure["data"],
structure.get("name", ""),
structure.get("xlabelval", "time"),
structure.get("xunits", "s"),
structure.get("yunits", ""),
structure.get("dataLabels"),
structure.get("plotProps"),
)
ci_payload = structure.get("ci")
if ci_payload is None:
return cov
if isinstance(ci_payload, list):
cov.setConfInterval([ConfidenceInterval.fromStructure(item) for item in ci_payload])
elif isinstance(ci_payload, tuple):
cov.setConfInterval([ConfidenceInterval.fromStructure(item) for item in ci_payload])
else:
cov.setConfInterval(ConfidenceInterval.fromStructure(ci_payload))
return cov


class nspikeTrain:
"""Closer MATLAB-style spike-train object with cached signal representation."""
Expand Down Expand Up @@ -1405,11 +1437,15 @@ def _build_sigrep(self, binwidth: float, minTime: float, maxTime: float) -> Sign
return sig

def setSigRep(self, binwidth: float | None = None, minTime: float | None = None, maxTime: float | None = None) -> SignalObj:
self.sigRep = self.getSigRep(binwidth, minTime, maxTime)
self.isSigRepBin = self.isSigRepBinary()
self.sampleRate = float(self.sigRep.sampleRate)
self.minTime = float(self.sigRep.minTime)
self.maxTime = float(self.sigRep.maxTime)
sig = self.getSigRep(binwidth, minTime, maxTime)
self.sigRep = sig.copySignal()
self.sampleRate = float(sig.sampleRate)
self.isSigRepBin = bool(np.max(np.asarray(sig.data, dtype=float)) <= 1.0)
# Keep the freshly-built cached representation alive instead of
# clearing it through the public min/max setters.
self.minTime = float(sig.minTime)
self.maxTime = float(sig.maxTime)
self.computeStatistics(-1)
return self.sigRep

def clearSigRep(self) -> None:
Expand All @@ -1420,14 +1456,12 @@ def clearSigRep(self) -> None:
def setMinTime(self, minTime: float) -> None:
self.minTime = float(minTime)
self.clearSigRep()
if self.avgFiringRate is not None:
self.computeStatistics(-1)
self.computeStatistics(-1)

def setMaxTime(self, maxTime: float) -> None:
self.maxTime = float(maxTime)
self.clearSigRep()
if self.avgFiringRate is not None:
self.computeStatistics(-1)
self.computeStatistics(-1)

def resample(self, sampleRate: float) -> "nspikeTrain":
self.setSigRep(1.0 / float(sampleRate), self.minTime, self.maxTime)
Expand Down Expand Up @@ -1476,8 +1510,9 @@ def getMaxBinSizeBinary(self) -> float:
return float(np.min(isi))

def isSigRepBinary(self) -> bool:
if self.isSigRepBin is None:
self.getSigRep()
default_key = self._cache_key(1.0 / float(self.sampleRate), float(self.minTime), float(self.maxTime))
if self._sigrep_cache_key != default_key or self.isSigRepBin is None:
self.getSigRep(1.0 / float(self.sampleRate), float(self.minTime), float(self.maxTime))
return bool(self.isSigRepBin)

def computeRate(self) -> SignalObj:
Expand Down Expand Up @@ -1553,14 +1588,19 @@ def plotJointISIHistogram(self):
ax = plt.subplots(1, 1, figsize=(4.5, 4.0))[1]
isi = self.getISIs()
if isi.size >= 2:
ax.loglog(isi[:-1], isi[1:], ".")
xvals = np.asarray(isi[:-1], dtype=float).reshape(-1)
yvals = np.asarray(isi[1:], dtype=float).reshape(-1)
ax.loglog(xvals, yvals, ".")
mean_isi = float(np.mean(isi))
ln = isi[isi < mean_isi]
ml = float(np.mean(ln)) if ln.size else np.nan
if np.isfinite(ml) and ml > 0:
v = ax.axis()
ax.loglog([ml, ml], [v[2], v[3]], "k--")
ax.loglog([v[0], v[1]], [ml, ml], "k--")
ymin = float(np.min(yvals))
ymax = float(np.max(yvals))
xmin = float(np.min(xvals))
xmax = float(np.max(xvals))
ax.loglog([ml, ml], [ymin, ymax], "k--")
ax.loglog([xmin, xmax], [ml, ml], "k--")
ax.set_xlabel("ISI(t) [s]")
ax.set_ylabel("ISI(t+1) [s]")
return ax
Expand All @@ -1582,14 +1622,21 @@ def plotISIHistogram(self, minTime: float | None = None, maxTime: float | None =
bins = np.arange(0.0, float(np.max(isi)) + bin_width, bin_width, dtype=float)
if bins.size < 2:
bins = np.array([0.0, bin_width], dtype=float)
counts, edges = np.histogram(isi, bins=bins)
centers = edges[:-1]
idx = np.searchsorted(bins, isi, side="right") - 1
idx = np.where(
np.isclose(isi, bins[-1], rtol=0.0, atol=max(1e-12, bin_width * 1e-9)),
bins.size - 1,
idx,
)
idx = np.clip(idx, 0, bins.size - 1)
counts = np.bincount(idx, minlength=bins.size).astype(float)
centers = bins
ax.bar(
centers,
counts,
width=bin_width,
align="edge",
edgecolor=(0.0, 0.0, 0.0),
edgecolor="none",
linewidth=2.0,
color=(0.831372559070587, 0.815686285495758, 0.7843137383461),
)
Expand All @@ -1600,7 +1647,6 @@ def plotISIHistogram(self, minTime: float | None = None, maxTime: float | None =

def plotProbPlot(self, minTime: float | None = None, maxTime: float | None = None, handle=None):
import matplotlib.pyplot as plt
from scipy import stats

ax = plt.gca() if handle is None else handle
if maxTime is None:
Expand All @@ -1610,8 +1656,11 @@ def plotProbPlot(self, minTime: float | None = None, maxTime: float | None = Non
isi = self.getISIs(minTime, maxTime)
ax.clear()
if isi.size:
stats.probplot(isi, dist=stats.expon, plot=ax)
ax.set_title(ax.get_title() or "Probability Plot")
sorted_isi = np.sort(np.asarray(isi, dtype=float).reshape(-1))
n = sorted_isi.size
p = (np.arange(1, n + 1, dtype=float) - 0.5) / float(n)
exp_quantiles = -np.log(1.0 - p)
ax.plot(sorted_isi, exp_quantiles, linestyle="none", marker=".")
return ax

def plotExponentialFit(self, minTime: float | None = None, maxTime: float | None = None, numBins: int | None = None, handle=None):
Expand Down
Loading