diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bd4cb7e..72b9973 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -52,6 +52,9 @@ jobs: run: python -m ndi check - name: Run tests with coverage + env: + NDI_CLOUD_USERNAME: ${{ secrets.TEST_USER_2_USERNAME }} + NDI_CLOUD_PASSWORD: ${{ secrets.TEST_USER_2_PASSWORD }} run: | # Use sys.monitoring (PEP 669) on Python 3.12+ for faster coverage. # CTracer (sys.settrace) is catastrophically slow on 3.12 when diff --git a/ndi_install.py b/ndi_install.py index f255cf7..d30df22 100644 --- a/ndi_install.py +++ b/ndi_install.py @@ -41,6 +41,14 @@ "python_path": ".", "description": "VH-Lab data utilities and file formats (not on PyPI)", }, + { + "name": "NDIcalc-vis-matlab", + "repo": "https://github.com/VH-Lab/NDIcalc-vis-matlab.git", + "branch": "main", + "python_path": "", + "ndi_common": True, + "description": "NDI calculator and visualization document definitions", + }, ] DEFAULT_TOOLS_DIR = Path.home() / ".ndi" / "tools" @@ -268,6 +276,8 @@ def write_pth_file(site_packages: Path, tools_dir: Path) -> Path | None: lines = [] for dep in DEPENDENCIES: + if not dep.get("python_path"): + continue # No Python code to add to path dep_dir = tools_dir / dep["name"] python_path = dep_dir / dep["python_path"] if dep["python_path"] != "." else dep_dir if python_path.is_dir(): @@ -290,6 +300,56 @@ def write_pth_file(site_packages: Path, tools_dir: Path) -> Path | None: return None +# --------------------------------------------------------------------------- +# ndi_common document definitions from external dependencies +# --------------------------------------------------------------------------- + + +def install_ndi_common_docs(tools_dir: Path, ndi_root: Path) -> bool: + """Copy ndi_common/{database,schema}_documents from external deps. + + Some dependencies (e.g. NDIcalc-vis-matlab) ship document type + definitions that NDI-python needs at runtime. This copies their + ``ndi_common/database_documents`` and ``ndi_common/schema_documents`` + trees into NDI-python's own ``ndi_common`` folder so they are + discoverable via ``ndi_common_PathConstants.DOCUMENT_PATH``. + """ + import shutil + + ndi_common = ndi_root / "src" / "ndi" / "ndi_common" + ok = True + + for dep in DEPENDENCIES: + if not dep.get("ndi_common"): + continue + dep_dir = tools_dir / dep["name"] + dep_common = dep_dir / "ndi_common" + if not dep_common.is_dir(): + warn(f"{dep['name']}: ndi_common folder not found at {dep_common}") + ok = False + continue + + for sub in ("database_documents", "schema_documents"): + src = dep_common / sub + dst = ndi_common / sub + if not src.is_dir(): + continue + count = 0 + for src_file in src.rglob("*"): + if src_file.is_dir(): + continue + rel = src_file.relative_to(src) + dst_file = dst / rel + dst_file.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(src_file, dst_file) + count += 1 + detail(f"Copied {count} {sub} files from {dep['name']}") + + success(f"Installed document definitions from {dep['name']}") + + return ok + + # --------------------------------------------------------------------------- # pip installation # --------------------------------------------------------------------------- @@ -529,6 +589,8 @@ def main() -> int: fail("Could not find site-packages directory") warn("You may need to set PYTHONPATH manually:") for dep in DEPENDENCIES: + if not dep.get("python_path"): + continue dep_dir = tools_dir / dep["name"] python_path = dep_dir / dep["python_path"] if dep["python_path"] != "." else dep_dir warn(f" {python_path}") @@ -546,6 +608,8 @@ def main() -> int: importlib.reload(site) # Add paths directly for this process for dep in DEPENDENCIES: + if not dep.get("python_path"): + continue # No Python code to add to path dep_dir = tools_dir / dep["name"] python_path = ( str(dep_dir / dep["python_path"]) if dep["python_path"] != "." else str(dep_dir) @@ -564,6 +628,9 @@ def main() -> int: if not install_ndi_and_deps(ndi_root, include_dev=args.dev): warn("Some packages may not have installed correctly") + # Copy document definitions from external dependencies + install_ndi_common_docs(tools_dir, ndi_root) + # ── Step 5: Validate ─────────────────────────────────────────────── if args.no_validate: print("\n[5/5] Validation skipped (--no-validate)") diff --git a/pyproject.toml b/pyproject.toml index 6e90fa3..60e0fab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "did @ git+https://github.com/VH-Lab/DID-python.git@main", "ndr @ git+https://github.com/VH-lab/NDR-python.git@main", "vhlab-toolbox-python @ git+https://github.com/VH-Lab/vhlab-toolbox-python.git@main", + "ndi-compress @ git+https://github.com/Waltham-Data-Science/NDI-compress-python.git@main", "numpy>=1.20.0", "networkx>=2.6", "jsonschema>=4.0.0", diff --git a/src/ndi/cloud/orchestration.py b/src/ndi/cloud/orchestration.py index 6677ef5..46f682d 100644 --- a/src/ndi/cloud/orchestration.py +++ b/src/ndi/cloud/orchestration.py @@ -84,7 +84,6 @@ def downloadDataset( from ndi.dataset import ndi_dataset_dir documents = jsons2documents(doc_jsons) - conversion_lost = len(doc_jsons) - len(documents) dataset = ndi_dataset_dir("", target, documents=documents) # Create remote link document if not already present @@ -113,85 +112,76 @@ def downloadDataset( if verbose: print(f' Files downloaded: {report["downloaded"]}, failed: {report["failed"]}') - # Collect failures: conversion + exception-tracked + silent (DID-python) - add_failures: list[tuple[str, str]] = list(getattr(dataset, "add_doc_failures", [])) - - # Cross-check using raw DID-python doc IDs (not isa('base') query, - # which might miss documents whose type info wasn't stored correctly). + # Verify every downloaded document made it into the local database. + # The local dataset may have *more* documents (e.g. session and + # session-in-a-dataset docs created internally), so we only check + # that every remote doc ID is present locally. db_ids = set( dataset._session._database._driver._db.get_doc_ids( dataset._session._database._driver._branch_id ) ) - # Build a map from doc_id -> original JSON for missing-doc output - doc_json_by_id: dict[str, dict] = {} + missing: list[str] = [] + missing_jsons: list[dict] = [] for dj in doc_jsons: did = dj.get("base", {}).get("id", "") if isinstance(dj, dict) else "" - if did: - doc_json_by_id[did] = dj - - # Find documents that were "added" (no exception) but aren't in the DB - tracked_ids = {f[0] for f in add_failures} - silent_failures: list[str] = [] - for doc in documents: - doc_id = ( - doc.document_properties.get("base", {}).get("id", "") - if hasattr(doc, "document_properties") - else doc.get("base", {}).get("id", "") - ) - if doc_id and doc_id not in db_ids and doc_id not in tracked_ids: - silent_failures.append(doc_id) - - total_lost = conversion_lost + len(add_failures) + len(silent_failures) + if did and did not in db_ids: + missing.append(did) + missing_jsons.append(dj) if verbose: print("Download complete.") - if total_lost > 0: - # Write missing documents to a JSON file for inspection - missing_docs_path = target / "missingDocuments.json" - missing_docs = [] - for doc_id in silent_failures: - if doc_id in doc_json_by_id: - missing_docs.append(doc_json_by_id[doc_id]) + if missing: + # Print the document_class of each missing doc for diagnostics. + # Session/dataset docs from older datasets are expected to be + # missing (superseded by docs created locally during dataset init). + session_dataset_types = { + "ndi_session", + "ndi_dataset", + "session", + "dataset", + "session_in_a_dataset", + "dataset_session_info", + } + real_missing: list[tuple[str, str]] = [] + for doc_id, dj in zip(missing, missing_jsons): + doc_class = ( + dj.get("document_class", {}).get("class_name", "") if isinstance(dj, dict) else "" + ) + superclasses = ( + dj.get("document_class", {}).get("superclasses", []) if isinstance(dj, dict) else [] + ) + all_types = {doc_class} | { + sc.get("class_name", "") if isinstance(sc, dict) else str(sc) + for sc in (superclasses if isinstance(superclasses, list) else []) + } + if all_types & session_dataset_types: + print( + f" Note: remote doc {doc_id} (class: {doc_class}) " + f"not in local DB — expected for session/dataset docs" + ) else: - missing_docs.append({"base": {"id": doc_id}}) - for doc_id, reason in add_failures: - entry = dict(doc_json_by_id.get(doc_id, {"base": {"id": doc_id}})) - entry["_add_error"] = reason - missing_docs.append(entry) - if missing_docs: - import json + print(f" WARNING: remote doc {doc_id} (class: {doc_class}) missing from local DB") + real_missing.append((doc_id, doc_class)) - missing_docs_path.write_text(json.dumps(missing_docs, indent=2, default=str)) + if real_missing: + missing_docs_path = target / "missingDocuments.json" + import json - lines = [ - f"Downloaded {len(doc_jsons)} documents but only " - f"{len(db_ids)} were added to the dataset. " - f"{total_lost} document(s) lost:" - ] - if conversion_lost > 0: - lines.append(f"\n{conversion_lost} failed to convert from JSON" " to ndi_document") - if add_failures: - lines.append(f"\n{len(add_failures)} raised errors during" " database add:") - for doc_id, reason in add_failures[:50]: - lines.append(f"\n - {doc_id}: {reason}") - if len(add_failures) > 50: - lines.append(f"\n ... and {len(add_failures) - 50} more") - if silent_failures: - lines.append( - f"\n{len(silent_failures)} were passed to" - " database.add() without error but are NOT in the" - " database (possible DID-python bug):" - ) - for doc_id in silent_failures[:50]: - lines.append(f"\n - {doc_id}") - if len(silent_failures) > 50: - lines.append(f"\n ... and {len(silent_failures) - 50} more") - if missing_docs: - lines.append(f"\nFull JSON of missing documents written to:" f"\n {missing_docs_path}") - raise RuntimeError("".join(lines)) + missing_docs_path.write_text(json.dumps(missing_jsons, indent=2, default=str)) + + lines = [ + f"Downloaded {len(doc_jsons)} documents but " + f"{len(real_missing)} are missing from the local dataset:" + ] + for doc_id, doc_class in real_missing[:50]: + lines.append(f"\n - {doc_id} (class: {doc_class})") + if len(real_missing) > 50: + lines.append(f"\n ... and {len(real_missing) - 50} more") + lines.append(f"\nFull JSON of missing documents written to:\n {missing_docs_path}") + raise RuntimeError("".join(lines)) return dataset diff --git a/src/ndi/daq/mfdaq.py b/src/ndi/daq/mfdaq.py index c085aeb..9135af6 100644 --- a/src/ndi/daq/mfdaq.py +++ b/src/ndi/daq/mfdaq.py @@ -78,6 +78,19 @@ class ChannelInfo: scale: float = 1.0 group: int = 1 + @classmethod + def from_dict(cls, d: dict) -> ChannelInfo: + return cls( + name=d.get("name", ""), + type=d.get("type", ""), + time_channel=d.get("time_channel"), + number=d.get("number"), + sample_rate=d.get("sample_rate"), + offset=d.get("offset", 0.0), + scale=d.get("scale", 1.0), + group=d.get("group", 1), + ) + def standardize_channel_type(channel_type: str | ChannelType) -> str: """ @@ -403,13 +416,17 @@ def epochsamples2times( samples: np.ndarray, ) -> np.ndarray: """ - Convert sample indices to time. + Convert 0-based sample indices to time. + + Note: + Unlike MATLAB (1-based), Python sample indices are 0-based. + Sample 0 corresponds to time t0 of the epoch. Args: channeltype: Channel type(s) channel: Channel number(s) epochfiles: Files for this epoch - samples: Sample indices (1-indexed) + samples: Sample indices (0-based) Returns: Time values @@ -429,7 +446,7 @@ def epochsamples2times( t0 = t0t1[0][0] samples = np.asarray(samples) - t = t0 + (samples - 1) / sr + t = t0 + samples / sr # Handle infinite values if np.any(np.isinf(samples)): @@ -445,7 +462,11 @@ def epochtimes2samples( times: np.ndarray, ) -> np.ndarray: """ - Convert time to sample indices. + Convert time to 0-based sample indices. + + Note: + Unlike MATLAB (1-based), Python sample indices are 0-based. + Sample 0 corresponds to time t0 of the epoch. Args: channeltype: Channel type(s) @@ -454,7 +475,7 @@ def epochtimes2samples( times: Time values Returns: - Sample indices (1-indexed) + Sample indices (0-based) """ if isinstance(channel, int): channel = [channel] @@ -471,11 +492,11 @@ def epochtimes2samples( t0 = t0t1[0][0] times = np.asarray(times) - s = 1 + np.round((times - t0) * sr).astype(int) + s = np.round((times - t0) * sr).astype(int) # Handle infinite values if np.any(np.isinf(times)): - s[np.isinf(times) & (times < 0)] = 1 + s[np.isinf(times) & (times < 0)] = 0 return s @@ -555,8 +576,8 @@ def getchannelsepoch_ingested( """ List channels for an ingested epoch. - Retrieves channel information from the ingested document stored - in the database. + Reads channel information from the ``channel_list.bin`` binary file + attached to the ingested document, matching the MATLAB approach. Args: epochfiles: List of file paths (starting with epochid://) @@ -564,30 +585,31 @@ def getchannelsepoch_ingested( Returns: List of ChannelInfo objects - - See also: getchannelsepoch """ doc = self.getingesteddocument(epochfiles, session) - et = doc.document_properties["daqreader_epochdata_ingested"]["epochtable"] - - channels_raw = et.get("channels", []) - channels = [] - - for ch_dict in channels_raw: - channels.append( - ChannelInfo( - name=ch_dict.get("name", ""), - type=ch_dict.get("type", "analog_in"), - time_channel=ch_dict.get("time_channel"), - number=ch_dict.get("number"), - sample_rate=ch_dict.get("sample_rate"), - offset=ch_dict.get("offset", 0.0), - scale=ch_dict.get("scale", 1.0), - group=ch_dict.get("group", 1), - ) - ) - - return channels + try: + fobj = session.database_openbinarydoc(doc, "channel_list.bin") + tname = fobj.name + fobj.close() + from ..file.type.mfdaq_epoch_channel import ndi_file_type_mfdaq__epoch__channel + + mec = ndi_file_type_mfdaq__epoch__channel() + mec.readFromFile(tname) + return mec.channel_information + except Exception as _exc: + # Fallback: try reading from epochtable JSON (older format) + et = doc.document_properties.get( + "daqreader_mfdaq_epochdata_ingested", + doc.document_properties.get("daqreader_epochdata_ingested", {}), + ).get("epochtable", {}) + channels_raw = et.get("channels", []) + if channels_raw: + return [ChannelInfo.from_dict(ch) for ch in channels_raw] + # Neither path worked — raise with context + raise ValueError( + f"Cannot read channel info: channel_list.bin failed ({_exc}), " + f"and no channels in epochtable JSON" + ) from _exc def readchannels_epochsamples_ingested( self, @@ -601,8 +623,8 @@ def readchannels_epochsamples_ingested( """ Read channel data from an ingested epoch. - Retrieves the data from the binary file referenced by the - ingested document in the database. + Reads compressed segment files (``ai_group*_seg.nbf_*``) from the + ingested document using ``ndicompress``, matching the MATLAB approach. Args: channeltype: Type(s) of channel to read @@ -614,39 +636,283 @@ def readchannels_epochsamples_ingested( Returns: Array with shape (num_samples, num_channels) - - See also: readchannels_epochsamples """ + import ndicompress + doc = self.getingesteddocument(epochfiles, session) - et = doc.document_properties["daqreader_epochdata_ingested"]["epochtable"] # Normalize inputs if isinstance(channel, int): channel = [channel] if isinstance(channeltype, str): channeltype = [channeltype] * len(channel) - channeltype = standardize_channel_types(channeltype) - # Get data file reference from document - data_file = et.get("data_file", None) - if data_file is None: - return np.full((s1 - s0 + 1, len(channel)), np.nan) + ch_unique = list(set(channeltype)) + if len(ch_unique) != 1: + raise ValueError("Only one type of channel may be read per function call") + + # Get sample rate, offset, scale + sr, offset, scale = self.samplerate_ingested(epochfiles, channeltype, channel, session) + sr_unique = np.unique(sr) + if len(sr_unique) != 1: + raise ValueError("Cannot handle different sampling rates across channels") + + # Handle infinite bounds + t0_t1 = self.t0_t1_ingested(epochfiles, session) + abs_s = self.epochtimes2samples_ingested( + channeltype, channel, epochfiles, np.array(t0_t1[0]), session + ) + if np.isinf(s0): + s0 = int(abs_s[0]) + if np.isinf(s1): + s1 = int(abs_s[1]) + + # Get channel info for group decoding + full_channel_info = self.getchannelsepoch_ingested(epochfiles, session) + + from ..file.type.mfdaq_epoch_channel import ndi_file_type_mfdaq__epoch__channel + + groups, ch_idx_in_groups, ch_idx_in_output = ( + ndi_file_type_mfdaq__epoch__channel.channelgroupdecoding( + full_channel_info, ch_unique[0], channel + ) + ) + + # Determine segment parameters and file prefix + props = doc.document_properties + mfdaq_params = props.get("daqreader_mfdaq_epochdata_ingested", {}).get("parameters", {}) + + analog_types = {"analog_in", "analog_out", "auxiliary_in", "auxiliary_out"} + digital_types = {"digital_in", "digital_out"} + + if ch_unique[0] in analog_types: + samples_segment = mfdaq_params.get("sample_analog_segment", 1_000_000) + expand_fn = ndicompress.expand_ephys + elif ch_unique[0] in digital_types: + samples_segment = mfdaq_params.get("sample_digital_segment", 1_000_000) + expand_fn = ndicompress.expand_digital + elif ch_unique[0] == "time": + samples_segment = mfdaq_params.get("sample_analog_segment", 1_000_000) + expand_fn = ndicompress.expand_time + else: + raise ValueError(f"Unknown channel type {ch_unique[0]}. Use readevents for events.") + + # Map channel type to file prefix + prefix_map = { + "analog_in": "ai", + "analog_out": "ao", + "auxiliary_in": "ax", + "auxiliary_out": "ax", + "digital_in": "di", + "digital_out": "do", + "time": "ti", + } + prefix = prefix_map.get(ch_unique[0], ch_unique[0]) + + # Read segments — s0/s1 are 0-based Python indices + seg_start = (s0 // samples_segment) + 1 # 1-based segment number + seg_stop = (s1 // samples_segment) + 1 + + data = np.full((s1 - s0 + 1, len(channel)), np.nan) + count = 0 + + for seg in range(seg_start, seg_stop + 1): + # Compute 0-based sample range within this segment + seg_offset = (seg - 1) * samples_segment # 0-based start of segment + if seg == seg_start: + s0_ = s0 - seg_offset # 0-based within segment + else: + s0_ = 0 + if seg == seg_stop: + s1_ = s1 - seg_offset # 0-based within segment + else: + s1_ = samples_segment - 1 + + n_samples_here = s1_ - s0_ + 1 + + for g_idx, grp in enumerate(groups): + fname = f"{prefix}_group{grp}_seg.nbf_{seg}" + try: + fobj = session.database_openbinarydoc(doc, fname) + tname = fobj.name + fobj.close() + + # Remove .tgz extension for ndicompress (it adds it back) + tname_base = tname + if tname_base.endswith(".tgz"): + tname_base = tname_base[:-4] + if tname_base.endswith(".nbf"): + tname_base = tname_base[:-4] + + result = expand_fn(tname_base) + # expand_* functions return (data, error_signal) tuple + data_here = result[0] if isinstance(result, tuple) else result + + # Handle last segment possibly having fewer samples + if data_here.shape[0] <= s1_: + s1_ = data_here.shape[0] - 1 + n_samples_here = s1_ - s0_ + 1 + + rows = slice(count, count + n_samples_here) + data[rows, ch_idx_in_output[g_idx]] = data_here[ + s0_ : s1_ + 1, ch_idx_in_groups[g_idx] + ] + except Exception as seg_exc: + import logging + + logging.getLogger("ndi").warning( + "readchannels_epochsamples_ingested: segment %s failed: %s", + fname, + seg_exc, + ) + + count += n_samples_here + + # Trim if last segment was shorter + if count < data.shape[0]: + data = data[:count, :] + + # Apply underlying2scaled: (data - offset) * scale + data = (data - np.array(offset)) * np.array(scale) + + return data + + def readevents_epochsamples_ingested( + self, + channeltype: str | list[str], + channel: int | list[int], + epochfiles: list[str], + t0: float, + t1: float, + session: Any, + ) -> tuple[list[np.ndarray] | np.ndarray, list[np.ndarray] | np.ndarray]: + """ + Read event/marker/text data from an ingested epoch. + + Matches MATLAB ``readevents_epochsamples_ingested``. For derived + digital event types (dep/den/dimp/dimn), reads digital channels + and detects transitions. For native events/markers/text, reads + from compressed ``evmktx_group*_seg.nbf_*`` files. + + Args: + channeltype: Event channel type(s) + channel: Channel number(s) + epochfiles: Files for this epoch (starting with epochid://) + t0: Start time + t1: End time + session: ndi_session object with database access + + Returns: + Tuple of (timestamps, data) + """ + import ndicompress + + if isinstance(channel, int): + channel = [channel] + if isinstance(channeltype, str): + channeltype = [channeltype] * len(channel) + channeltype = standardize_channel_types(channeltype) - # Read from VHSB format + derived = {"dep", "den", "dimp", "dimn"} + if set(channeltype) & derived: + # Handle derived digital event types + timestamps_list = [] + data_list = [] + for i, ch_num in enumerate(zip(channel)): + sd = self.epochtimes2samples_ingested( + ["digital_in"], [ch_num], epochfiles, np.array([t0, t1]), session + ) + s0d, s1d = int(sd[0]), int(sd[1]) + data_here = self.readchannels_epochsamples_ingested( + ["digital_in"], [ch_num], epochfiles, s0d, s1d, session + ) + time_here = self.readchannels_epochsamples_ingested( + ["time"], [ch_num], epochfiles, s0d, s1d, session + ) + data_here = data_here.ravel() + time_here = time_here.ravel() + + ct = channeltype[i] + if ct in ("dep", "dimp"): + on_samples = np.where((data_here[:-1] == 0) & (data_here[1:] == 1))[0] + 1 + off_samples = ( + np.where((data_here[:-1] == 1) & (data_here[1:] == 0))[0] + 1 + if ct == "dimp" + else np.array([], dtype=int) + ) + else: # den, dimn + on_samples = np.where((data_here[:-1] == 1) & (data_here[1:] == 0))[0] + 1 + off_samples = ( + np.where((data_here[:-1] == 0) & (data_here[1:] == 1))[0] + 1 + if ct == "dimn" + else np.array([], dtype=int) + ) + + ts = np.concatenate([time_here[on_samples], time_here[off_samples]]) + dd = np.concatenate( + [ + np.ones(len(on_samples)), + -np.ones(len(off_samples)), + ] + ) + if len(off_samples) > 0: + order = np.argsort(ts) + ts = ts[order] + dd = dd[order] + timestamps_list.append(ts) + data_list.append(dd) + + if len(channel) == 1: + return timestamps_list[0], data_list[0] + return timestamps_list, data_list + + # Native events/markers/text + doc = self.getingesteddocument(epochfiles, session) + fname = "evmktx_group1_seg.nbf_1" try: - from vlt.file.custom_file_formats import vhsb_read + fobj = session.database_openbinarydoc(doc, fname) + tname = fobj.name + fobj.close() + tname_base = tname + if tname_base.endswith(".tgz"): + tname_base = tname_base[:-4] + if tname_base.endswith(".nbf"): + tname_base = tname_base[:-4] + ct_out, ch_out, T, D = ndicompress.expand_eventmarktext(tname_base) + except Exception as exc: + raise ValueError(f"No event data found for this epoch: {exc}") from exc + + # Standardize the output channel types for matching + if isinstance(ct_out, list): + ct_out_std = standardize_channel_types(ct_out) + else: + ct_out_std = standardize_channel_types(list(ct_out)) + + timestamps_list = [] + data_list = [] + for ct, ch_num in zip(channeltype, channel): + ct_std = standardize_channel_type(ct) + matches = [ + j + for j, (cto, cho) in enumerate(zip(ct_out_std, ch_out)) + if cto == ct_std and cho == ch_num + ] + if not matches: + raise ValueError(f"Channel type {ct} and channel {ch_num} not found in event data") + idx = matches[0] + ts = np.asarray(T[idx]) + dd = np.asarray(D[idx]) + # Filter by time range + included = (ts >= t0) & (ts <= t1) + if ts.ndim > 1: + included = included[:, 0] if included.ndim > 1 else included + timestamps_list.append(ts[included]) + data_list.append(dd[included] if dd.ndim <= 1 else dd[included]) - data = vhsb_read( - data_file, - channels=channel, - sample_start=s0, - sample_end=s1, - ) - return data - except ImportError: - # Fallback: return NaN if vlt not available - return np.full((s1 - s0 + 1, len(channel)), np.nan) + if len(channel) == 1: + return timestamps_list[0], data_list[0] + return timestamps_list, data_list def samplerate_ingested( self, @@ -654,9 +920,12 @@ def samplerate_ingested( channeltype: str | list[str], channel: int | list[int], session: Any, - ) -> np.ndarray: + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """ - Get sample rate for channels from an ingested epoch. + Get sample rate, offset, and scale for channels from an ingested epoch. + + Reads channel metadata from the ``channel_list.bin`` binary file, + matching the MATLAB approach which returns (sr, offset, scale). Args: epochfiles: Files for this epoch (starting with epochid://) @@ -665,28 +934,37 @@ def samplerate_ingested( session: ndi_session object with database access Returns: - Array of sample rates - - See also: samplerate + Tuple of (sample_rates, offsets, scales) arrays """ if isinstance(channel, int): channel = [channel] + if isinstance(channeltype, str): + channeltype = [channeltype] * len(channel) + channeltype = standardize_channel_types(channeltype) - # Get channels from ingested document - channels = self.getchannelsepoch_ingested(epochfiles, session) - - # Build lookup by channel number - sr_lookup = {} - for ch in channels: - if ch.number is not None and ch.sample_rate is not None: - sr_lookup[ch.number] = ch.sample_rate + full_channel_info = self.getchannelsepoch_ingested(epochfiles, session) - # Return sample rates for requested channels sr = np.zeros(len(channel)) - for i, ch_num in enumerate(channel): - sr[i] = sr_lookup.get(ch_num, np.nan) + offset = np.zeros(len(channel)) + scale = np.ones(len(channel)) + + for i, (ct, ch_num) in enumerate(zip(channeltype, channel)): + ct_std = standardize_channel_type(ct) + match = [ + ci + for ci in full_channel_info + if standardize_channel_type(ci.type) == ct_std and ci.number == ch_num + ] + if not match: + raise ValueError( + f"No such channel: {ct} : {ch_num}. " + f"Available: {[(ci.type, ci.number) for ci in full_channel_info[:5]]}" + ) + sr[i] = match[0].sample_rate + offset[i] = match[0].offset + scale[i] = match[0].scale - return sr + return sr, offset, scale def epochsamples2times_ingested( self, @@ -697,27 +975,29 @@ def epochsamples2times_ingested( session: Any, ) -> np.ndarray: """ - Convert sample indices to time for an ingested epoch. + Convert 0-based sample indices to time for an ingested epoch. + + Note: + Unlike MATLAB (1-based), Python sample indices are 0-based. + Sample 0 corresponds to time t0 of the epoch. Args: channeltype: Channel type(s) channel: Channel number(s) epochfiles: Files for this epoch (starting with epochid://) - samples: Sample indices (1-indexed) + samples: Sample indices (0-based) session: ndi_session object with database access Returns: Time values - - See also: epochsamples2times """ if isinstance(channel, int): channel = [channel] if isinstance(channeltype, str): channeltype = [channeltype] * len(channel) - sr = self.samplerate_ingested(epochfiles, channeltype, channel, session) - sr_unique = np.unique(sr[~np.isnan(sr)]) + sr_arr, _, _ = self.samplerate_ingested(epochfiles, channeltype, channel, session) + sr_unique = np.unique(sr_arr) if len(sr_unique) != 1: raise ValueError("Cannot handle different sample rates across channels") sr = sr_unique[0] @@ -726,9 +1006,8 @@ def epochsamples2times_ingested( t0 = t0t1[0][0] samples = np.asarray(samples) - t = t0 + (samples - 1) / sr + t = t0 + samples / sr - # Handle infinite values if np.any(np.isinf(samples)): t[np.isinf(samples) & (samples < 0)] = t0 @@ -743,7 +1022,11 @@ def epochtimes2samples_ingested( session: Any, ) -> np.ndarray: """ - Convert time to sample indices for an ingested epoch. + Convert time to 0-based sample indices for an ingested epoch. + + Note: + Unlike MATLAB (1-based), Python sample indices are 0-based. + Sample 0 corresponds to time t0 of the epoch. Args: channeltype: Channel type(s) @@ -753,17 +1036,15 @@ def epochtimes2samples_ingested( session: ndi_session object with database access Returns: - Sample indices (1-indexed) - - See also: epochtimes2samples + Sample indices (0-based) """ if isinstance(channel, int): channel = [channel] if isinstance(channeltype, str): channeltype = [channeltype] * len(channel) - sr = self.samplerate_ingested(epochfiles, channeltype, channel, session) - sr_unique = np.unique(sr[~np.isnan(sr)]) + sr_arr, _, _ = self.samplerate_ingested(epochfiles, channeltype, channel, session) + sr_unique = np.unique(sr_arr) if len(sr_unique) != 1: raise ValueError("Cannot handle different sample rates across channels") sr = sr_unique[0] @@ -772,10 +1053,9 @@ def epochtimes2samples_ingested( t0 = t0t1[0][0] times = np.asarray(times) - s = 1 + np.round((times - t0) * sr).astype(int) + s = np.round((times - t0) * sr).astype(int) - # Handle infinite values if np.any(np.isinf(times)): - s[np.isinf(times) & (times < 0)] = 1 + s[np.isinf(times) & (times < 0)] = 0 return s diff --git a/src/ndi/daq/ndi_matlab_python_bridge.yaml b/src/ndi/daq/ndi_matlab_python_bridge.yaml index 3f8d658..8cf3675 100644 --- a/src/ndi/daq/ndi_matlab_python_bridge.yaml +++ b/src/ndi/daq/ndi_matlab_python_bridge.yaml @@ -493,7 +493,9 @@ classes: - name: times type_python: "np.ndarray" decision_log: > - Exact match. Semantic Parity: samples are 1-indexed (user concept). + INDEXING DIFFERENCE: MATLAB samples are 1-indexed; Python samples + are 0-indexed. Sample 0 in Python corresponds to sample 1 in MATLAB. + The formula is t = t0 + sample/sr (Python) vs t = t0 + (sample-1)/sr (MATLAB). - name: epochtimes2samples input_arguments: @@ -513,7 +515,9 @@ classes: - name: samples type_python: "np.ndarray" decision_log: > - Exact match. Semantic Parity: returned samples are 1-indexed. + INDEXING DIFFERENCE: MATLAB returns 1-indexed samples; Python returns + 0-indexed. Sample 0 in Python corresponds to sample 1 in MATLAB. + The formula is s = round((t-t0)*sr) (Python) vs s = 1+round((t-t0)*sr) (MATLAB). - name: underlying_datatype input_arguments: @@ -623,7 +627,9 @@ classes: output_arguments: - name: times type_python: "np.ndarray" - decision_log: "Exact match." + decision_log: > + INDEXING DIFFERENCE: MATLAB samples are 1-indexed; Python samples + are 0-indexed. Sample 0 in Python corresponds to sample 1 in MATLAB. - name: epochtimes2samples_ingested input_arguments: @@ -645,7 +651,9 @@ classes: output_arguments: - name: samples type_python: "np.ndarray" - decision_log: "Exact match." + decision_log: > + INDEXING DIFFERENCE: MATLAB returns 1-indexed samples; Python returns + 0-indexed. Sample 0 in Python corresponds to sample 1 in MATLAB. # ========================================================================= # ndi.daq.system @@ -1027,7 +1035,9 @@ classes: - name: times type_python: "np.ndarray" decision_log: > - Semantic Parity: epoch_number and samples are 1-indexed. + INDEXING DIFFERENCE: epoch_number is 1-indexed (same as MATLAB). + Samples are 0-indexed in Python (1-indexed in MATLAB). + Sample 0 in Python corresponds to sample 1 in MATLAB. - name: epochtimes2samples input_arguments: @@ -1047,8 +1057,9 @@ classes: - name: samples type_python: "np.ndarray" decision_log: > - Semantic Parity: epoch_number is 1-indexed. - Returned samples are 1-indexed. + INDEXING DIFFERENCE: epoch_number is 1-indexed (same as MATLAB). + Returned samples are 0-indexed in Python (1-indexed in MATLAB). + Sample 0 in Python corresponds to sample 1 in MATLAB. static_methods: - name: mfdaq_channeltypes diff --git a/src/ndi/daq/reader_base.py b/src/ndi/daq/reader_base.py index a5d8de7..f5f02b5 100644 --- a/src/ndi/daq/reader_base.py +++ b/src/ndi/daq/reader_base.py @@ -239,9 +239,18 @@ def t0_t1_ingested( if not isinstance(t0t1_raw, list): return [tuple(t0t1_raw)] + # Detect flat pair [t0, t1] vs list of pairs [[t0, t1], ...] + # A flat pair has exactly 2 scalar elements. + if ( + len(t0t1_raw) == 2 + and not isinstance(t0t1_raw[0], (list, tuple)) + and not isinstance(t0t1_raw[1], (list, tuple)) + ): + return [(t0t1_raw[0], t0t1_raw[1])] + t0t1_list = [] for t in t0t1_raw: - if isinstance(t, (list, tuple)): + if isinstance(t, (list, tuple)) and len(t) == 2: t0t1_list.append(tuple(t)) else: t0t1_list.append((t, t)) diff --git a/src/ndi/daq/system.py b/src/ndi/daq/system.py index e994a26..4ed9731 100644 --- a/src/ndi/daq/system.py +++ b/src/ndi/daq/system.py @@ -442,6 +442,11 @@ def epochtable(self) -> list[dict[str, Any]]: } ) + # Sort by epoch_id alphanumerically to match MATLAB behavior + et.sort(key=lambda e: e.get("epoch_id", "")) + for i, entry in enumerate(et): + entry["epoch_number"] = i + 1 + return et def epochnodes(self) -> list[dict[str, Any]]: diff --git a/src/ndi/daq/system_mfdaq.py b/src/ndi/daq/system_mfdaq.py index 9d9b25f..0ce843e 100644 --- a/src/ndi/daq/system_mfdaq.py +++ b/src/ndi/daq/system_mfdaq.py @@ -48,6 +48,11 @@ class ndi_daq_system_mfdaq(ndi_daq_system): "marker": "mk", } + def _getepochfiles(self, epoch_number: int) -> list[str]: + """Get epoch files, unpacking the tuple from getepochfiles.""" + result = self._filenavigator.getepochfiles(epoch_number) + return result[0] if isinstance(result, tuple) else result + def epochclock(self, epoch_number: int) -> list[ndi_time_clocktype]: """ Return clock types for an epoch. @@ -75,8 +80,7 @@ def t0_t1(self, epoch_number: int) -> list[tuple[float, float]]: List of (t0, t1) tuples per clock type """ if self._daqreader is not None and self._filenavigator is not None: - result = self._filenavigator.getepochfiles(epoch_number) - epochfiles = result[0] if isinstance(result, tuple) else result + epochfiles = self._getepochfiles(epoch_number) return self._daqreader.t0_t1(epochfiles) return [(np.nan, np.nan)] @@ -93,7 +97,7 @@ def getchannelsepoch(self, epoch_number: int) -> list[Any]: if self._daqreader is None or self._filenavigator is None: return [] - epochfiles = self._filenavigator.getepochfiles(epoch_number) + epochfiles = self._getepochfiles(epoch_number) if isinstance(self._daqreader, ndi_daq_reader_mfdaq): return self._daqreader.getchannelsepoch(epochfiles) @@ -145,12 +149,15 @@ def readchannels_epochsamples( if self._daqreader is None or self._filenavigator is None: raise RuntimeError("No DAQ reader or file navigator configured") - epochfiles = self._filenavigator.getepochfiles(epoch_number) - if isinstance(self._daqreader, ndi_daq_reader_mfdaq): - return self._daqreader.readchannels_epochsamples( - channeltype, channel, epochfiles, s0, s1 + epochfiles = self._getepochfiles(epoch_number) + if not isinstance(self._daqreader, ndi_daq_reader_mfdaq): + raise TypeError("DAQ reader is not an ndi_daq_reader_mfdaq") + + if self._is_ingested(epochfiles): + return self._daqreader.readchannels_epochsamples_ingested( + channeltype, channel, epochfiles, s0, s1, self.session ) - raise TypeError("DAQ reader is not an ndi_daq_reader_mfdaq") + return self._daqreader.readchannels_epochsamples(channeltype, channel, epochfiles, s0, s1) def readevents_epochsamples( self, @@ -176,10 +183,15 @@ def readevents_epochsamples( if self._daqreader is None or self._filenavigator is None: raise RuntimeError("No DAQ reader or file navigator configured") - epochfiles = self._filenavigator.getepochfiles(epoch_number) - if isinstance(self._daqreader, ndi_daq_reader_mfdaq): - return self._daqreader.readevents_epochsamples(channeltype, channel, epochfiles, t0, t1) - raise TypeError("DAQ reader is not an ndi_daq_reader_mfdaq") + epochfiles = self._getepochfiles(epoch_number) + if not isinstance(self._daqreader, ndi_daq_reader_mfdaq): + raise TypeError("DAQ reader is not an ndi_daq_reader_mfdaq") + + if self._is_ingested(epochfiles): + return self._daqreader.readevents_epochsamples_ingested( + channeltype, channel, epochfiles, t0, t1, self.session + ) + return self._daqreader.readevents_epochsamples(channeltype, channel, epochfiles, t0, t1) def samplerate( self, @@ -201,10 +213,15 @@ def samplerate( if self._daqreader is None or self._filenavigator is None: raise RuntimeError("No DAQ reader or file navigator configured") - epochfiles = self._filenavigator.getepochfiles(epoch_number) - if isinstance(self._daqreader, ndi_daq_reader_mfdaq): - return self._daqreader.samplerate(epochfiles, channeltype, channel) - raise TypeError("DAQ reader is not an ndi_daq_reader_mfdaq") + epochfiles = self._getepochfiles(epoch_number) + if not isinstance(self._daqreader, ndi_daq_reader_mfdaq): + raise TypeError("DAQ reader is not an ndi_daq_reader_mfdaq") + + if self._is_ingested(epochfiles): + return self._daqreader.samplerate_ingested( + epochfiles, channeltype, channel, self.session + ) + return self._daqreader.samplerate(epochfiles, channeltype, channel) def epochsamples2times( self, @@ -214,13 +231,17 @@ def epochsamples2times( samples: np.ndarray, ) -> np.ndarray: """ - Convert sample indices to time. + Convert 0-based sample indices to time. + + Note: + Unlike MATLAB (1-based), Python sample indices are 0-based. + Sample 0 corresponds to time t0 of the epoch. Args: channeltype: Channel type(s) channel: Channel number(s) epoch_number: ndi_epoch_epoch number (1-indexed) - samples: Sample indices (1-indexed) + samples: Sample indices (0-based) Returns: Time values @@ -228,10 +249,15 @@ def epochsamples2times( if self._daqreader is None or self._filenavigator is None: raise RuntimeError("No DAQ reader or file navigator configured") - epochfiles = self._filenavigator.getepochfiles(epoch_number) - if isinstance(self._daqreader, ndi_daq_reader_mfdaq): - return self._daqreader.epochsamples2times(channeltype, channel, epochfiles, samples) - raise TypeError("DAQ reader is not an ndi_daq_reader_mfdaq") + epochfiles = self._getepochfiles(epoch_number) + if not isinstance(self._daqreader, ndi_daq_reader_mfdaq): + raise TypeError("DAQ reader is not an ndi_daq_reader_mfdaq") + + if self._is_ingested(epochfiles): + return self._daqreader.epochsamples2times_ingested( + channeltype, channel, epochfiles, samples, self.session + ) + return self._daqreader.epochsamples2times(channeltype, channel, epochfiles, samples) def epochtimes2samples( self, @@ -241,7 +267,11 @@ def epochtimes2samples( times: np.ndarray, ) -> np.ndarray: """ - Convert time to sample indices. + Convert time to 0-based sample indices. + + Note: + Unlike MATLAB (1-based), Python sample indices are 0-based. + Sample 0 corresponds to time t0 of the epoch. Args: channeltype: Channel type(s) @@ -250,15 +280,20 @@ def epochtimes2samples( times: Time values Returns: - Sample indices (1-indexed) + Sample indices (0-based) """ if self._daqreader is None or self._filenavigator is None: raise RuntimeError("No DAQ reader or file navigator configured") - epochfiles = self._filenavigator.getepochfiles(epoch_number) - if isinstance(self._daqreader, ndi_daq_reader_mfdaq): - return self._daqreader.epochtimes2samples(channeltype, channel, epochfiles, times) - raise TypeError("DAQ reader is not an ndi_daq_reader_mfdaq") + epochfiles = self._getepochfiles(epoch_number) + if not isinstance(self._daqreader, ndi_daq_reader_mfdaq): + raise TypeError("DAQ reader is not an ndi_daq_reader_mfdaq") + + if self._is_ingested(epochfiles): + return self._daqreader.epochtimes2samples_ingested( + channeltype, channel, epochfiles, times, self.session + ) + return self._daqreader.epochtimes2samples(channeltype, channel, epochfiles, times) @staticmethod def mfdaq_channeltypes() -> list[str]: diff --git a/src/ndi/dataset/_dataset.py b/src/ndi/dataset/_dataset.py index ccf799c..330ef30 100644 --- a/src/ndi/dataset/_dataset.py +++ b/src/ndi/dataset/_dataset.py @@ -294,6 +294,9 @@ def open_session(self, session_id: str) -> Any | None: session = self._recreate_session(info, path_arg, session_id) if session is not None: + # Propagate cloud client from dataset to session + if hasattr(self, "cloud_client") and self.cloud_client is not None: + session.cloud_client = self.cloud_client self._session_array[match_idx]["session"] = session return session diff --git a/src/ndi/file/type/mfdaq_epoch_channel.py b/src/ndi/file/type/mfdaq_epoch_channel.py index b83644a..7579395 100644 --- a/src/ndi/file/type/mfdaq_epoch_channel.py +++ b/src/ndi/file/type/mfdaq_epoch_channel.py @@ -120,22 +120,49 @@ def create_properties( def readFromFile(self, filename: str) -> ndi_file_type_mfdaq__epoch__channel: """ - Read channel information from a JSON file. + Read channel information from a file. MATLAB equivalent: ndi.file.type.mfdaq_epoch_channel/readFromFile + Supports both JSON format (Python-generated) and the MATLAB + tab-delimited format (read via ``vlt.file.loadStructArray``). + Args: - filename: Path to the JSON file + filename: Path to the channel list file Returns: Self for chaining """ - with open(filename) as f: - data = json.load(f) - + # Try JSON first (Python-generated files) + try: + with open(filename) as f: + data = json.load(f) + self.channel_information = [] + for ch_data in data.get("channel_information", []): + self.channel_information.append(ChannelInfo.from_dict(ch_data)) + return self + except (json.JSONDecodeError, UnicodeDecodeError): + pass + + # Fallback: vlt.file.loadStructArray (MATLAB tab-delimited format) + from vlt.file import loadStructArray + + records = loadStructArray(filename) self.channel_information = [] - for ch_data in data.get("channel_information", []): - self.channel_information.append(ChannelInfo.from_dict(ch_data)) + for rec in records: + self.channel_information.append( + ChannelInfo( + name=str(rec.get("name", "")), + type=str(rec.get("type", "")), + time_channel=int(rec.get("time_channel", 1)), + sample_rate=float(rec.get("sample_rate", 0.0)), + offset=float(rec.get("offset", 0.0)), + scale=float(rec.get("scale", 1.0)), + number=int(rec.get("number", 0)), + group=int(rec.get("group", 0)), + dataclass=str(rec.get("dataclass", "")), + ) + ) return self def writeToFile(self, filename: str) -> tuple[bool, str]: @@ -173,6 +200,11 @@ def channelgroupdecoding( Given a list of requested channels, returns the corresponding group assignments and index mappings. + Note: + ``channel_indexes_in_groups`` contains 0-based indices into the + segment data columns (within the subset of channels belonging to + that group and type). In MATLAB these are 1-based. + Args: channel_info: List of ChannelInfo channel_type: Type of channels to look up @@ -182,41 +214,50 @@ def channelgroupdecoding( Tuple of (groups, channel_indexes_in_groups, channel_indexes_in_output): - groups: Unique group numbers for requested channels - - channel_indexes_in_groups: For each group, the channel - numbers that belong to it - - channel_indexes_in_output: For each group, the indexes - into the output data corresponding to those channels + - channel_indexes_in_groups: For each group, 0-based column + indices into the segment data for the requested channels + - channel_indexes_in_output: For each group, 0-based indices + into the output data array """ - # Build lookup by (type, number) -> (group, index in channel_info) - lookup: dict[tuple[str, int], int] = {} - for ch in channel_info: - lookup[(ch.type, ch.number)] = ch.group - - # Get group assignment for each requested channel - channel_groups = [] - for ch_num in channels: - group = lookup.get((channel_type, ch_num), 0) - channel_groups.append(group) - - # Find unique groups (preserving order) - seen: set[int] = set() - groups: list[int] = [] - for g in channel_groups: - if g not in seen: - seen.add(g) - groups.append(g) + from ndi.daq.mfdaq import standardize_channel_type + + ct_std = standardize_channel_type(channel_type) - # Build index mappings for each group + # Filter to channels matching the requested type + ci_typed = [ch for ch in channel_info if standardize_channel_type(ch.type) == ct_std] + + groups: list[int] = [] channel_indexes_in_groups: list[list[int]] = [] channel_indexes_in_output: list[list[int]] = [] - for g in groups: - # Channel numbers belonging to this group - ch_in_group = [channels[i] for i in range(len(channels)) if channel_groups[i] == g] - # Indexes into the output array for this group - idx_in_output = [i for i in range(len(channels)) if channel_groups[i] == g] - channel_indexes_in_groups.append(ch_in_group) - channel_indexes_in_output.append(idx_in_output) + for c_idx, ch_num in enumerate(channels): + # Find this channel in the type-filtered list + matches = [i for i, ci in enumerate(ci_typed) if ci.number == ch_num] + if not matches: + raise ValueError(f"Channel number {ch_num} not found in record.") + if len(matches) > 1: + raise ValueError(f"Channel number {ch_num} found multiple times in record.") + + ch_info = ci_typed[matches[0]] + grp = ch_info.group + + # Find or create group entry + if grp in groups: + g_idx = groups.index(grp) + else: + groups.append(grp) + g_idx = len(groups) - 1 + channel_indexes_in_groups.append([]) + channel_indexes_in_output.append([]) + + # Find the 0-based index of this channel within its group + subset_group = [ci for ci in ci_typed if ci.group == grp] + chan_index_in_group = next( + i for i, ci in enumerate(subset_group) if ci.number == ch_num + ) + + channel_indexes_in_groups[g_idx].append(chan_index_in_group) + channel_indexes_in_output[g_idx].append(c_idx) return groups, channel_indexes_in_groups, channel_indexes_in_output diff --git a/src/ndi/probe/__init__.py b/src/ndi/probe/__init__.py index e7a9620..1811da1 100644 --- a/src/ndi/probe/__init__.py +++ b/src/ndi/probe/__init__.py @@ -153,6 +153,13 @@ def buildepochtable(self) -> list[dict[str, Any]]: } ) + # Sort by epoch_id alphanumerically to match MATLAB behavior + et.sort(key=lambda e: e.get("epoch_id", "")) + + # Renumber after sorting + for i, entry in enumerate(et): + entry["epoch_number"] = i + 1 + return et def _get_daqsystems(self) -> list[Any]: @@ -173,14 +180,20 @@ def _get_daqsystems(self) -> list[Any]: q = ndi_query("").isa("daqsystem") docs = self._session.database_search(q) - # Load ndi_daq_system objects from documents - from ..daq.system import ndi_daq_system - + # Load ndi_daq_system objects from documents using the session's + # _document_to_object which creates the correct subclass (e.g. + # ndi_daq_system_mfdaq for MFDAQ systems). systems = [] for doc in docs: try: - sys = ndi_daq_system(session=self._session, document=doc) - systems.append(sys) + if hasattr(self._session, "_document_to_object"): + obj = self._session._document_to_object(doc) + else: + from ..daq.system import ndi_daq_system + + obj = ndi_daq_system(session=self._session, document=doc) + if obj is not None: + systems.append(obj) except Exception: pass @@ -199,6 +212,13 @@ def _find_matching_epochprobemap( Returns: Matching ndi_epoch_epochprobemap or None """ + # Normalize to list — some code paths return a single object + if isinstance(epochprobemaps, ndi_epoch_epochprobemap): + epochprobemaps = [epochprobemaps] + elif isinstance(epochprobemaps, dict): + epochprobemaps = [epochprobemaps] + elif not isinstance(epochprobemaps, (list, tuple)): + epochprobemaps = [epochprobemaps] for epm in epochprobemaps: # Handle both ndi_epoch_epochprobemap objects and dicts if isinstance(epm, ndi_epoch_epochprobemap): @@ -314,6 +334,7 @@ def getchanneldevinfo( return { "daqsystem": underlying.get("underlying"), "device_epoch_id": underlying.get("epoch_id"), + "device_epoch_number": entry.get("epoch_number", epoch_number), "epochprobemap": entry.get("epochprobemap", []), } diff --git a/src/ndi/probe/ndi_matlab_python_bridge.yaml b/src/ndi/probe/ndi_matlab_python_bridge.yaml index 11a5597..07f560c 100644 --- a/src/ndi/probe/ndi_matlab_python_bridge.yaml +++ b/src/ndi/probe/ndi_matlab_python_bridge.yaml @@ -295,8 +295,9 @@ classes: - name: samples type_python: "np.ndarray" decision_log: > - Exact match. Returns 1-indexed sample indices - (Semantic Parity: samples are user-facing counting). + INDEXING DIFFERENCE: MATLAB returns 1-indexed samples; Python returns + 0-indexed. Sample 0 in Python corresponds to sample 1 in MATLAB. + Formula: s = round(t * sr) (Python) vs s = 1 + round(t * sr) (MATLAB). - name: samples2times input_arguments: @@ -310,8 +311,9 @@ classes: - name: times type_python: "np.ndarray" decision_log: > - Exact match. Accepts 1-indexed sample indices - (Semantic Parity: samples are user-facing counting). + INDEXING DIFFERENCE: MATLAB accepts 1-indexed samples; Python accepts + 0-indexed. Sample 0 in Python corresponds to sample 1 in MATLAB. + Formula: t = s / sr (Python) vs t = (s - 1) / sr (MATLAB). # ========================================================================= # ndi.probe.timeseries.mfdaq (MFDAQ timeseries probe) diff --git a/src/ndi/probe/timeseries.py b/src/ndi/probe/timeseries.py index 2987b38..e1bcae3 100644 --- a/src/ndi/probe/timeseries.py +++ b/src/ndi/probe/timeseries.py @@ -151,20 +151,24 @@ def times2samples( times: np.ndarray, ) -> np.ndarray: """ - Convert times to sample indices. + Convert times to 0-based sample indices. + + Note: + Unlike MATLAB (1-based), Python sample indices are 0-based. + Sample 0 corresponds to the start of the epoch. Args: epoch: ndi_epoch_epoch number or epoch_id times: Time values Returns: - Sample indices (1-indexed) + Sample indices (0-based) """ sr = self.samplerate(epoch) if sr <= 0: return np.full_like(times, np.nan) times = np.asarray(times) - return 1 + np.round(times * sr).astype(int) + return np.round(times * sr).astype(int) def samples2times( self, @@ -172,11 +176,15 @@ def samples2times( samples: np.ndarray, ) -> np.ndarray: """ - Convert sample indices to times. + Convert 0-based sample indices to times. + + Note: + Unlike MATLAB (1-based), Python sample indices are 0-based. + Sample 0 corresponds to the start of the epoch. Args: epoch: ndi_epoch_epoch number or epoch_id - samples: Sample indices (1-indexed) + samples: Sample indices (0-based) Returns: Time values @@ -185,7 +193,7 @@ def samples2times( if sr <= 0: return np.full_like(samples, np.nan, dtype=float) samples = np.asarray(samples, dtype=float) - return (samples - 1) / sr + return samples / sr def __repr__(self) -> str: return ( diff --git a/src/ndi/probe/timeseries_mfdaq.py b/src/ndi/probe/timeseries_mfdaq.py index 5fd8949..358deeb 100644 --- a/src/ndi/probe/timeseries_mfdaq.py +++ b/src/ndi/probe/timeseries_mfdaq.py @@ -74,7 +74,7 @@ def read_epochsamples( except (AttributeError, TypeError): return None, None, None - # Get time values + # Get time values for each 0-based sample index try: t = dev.epochsamples2times(channeltype, channellist, devepoch, np.arange(s0, s1 + 1)) except (AttributeError, TypeError): @@ -110,7 +110,7 @@ def readtimeseriesepoch( if dev is None: return None, None, None - # Convert times to samples + # Convert times to 0-based sample indices try: samples = dev.epochtimes2samples(channeltype, channellist, devepoch, np.array([t0, t1])) s0 = int(samples[0]) @@ -222,20 +222,9 @@ def _resolve_device( dss = ndi_daq_daqsystemstring.parse(probe_map.devicestring) - # Find the DAQ system by name - if self._session is None: - return None - - # Get all DAQ systems from the session - daq_systems = getattr(self._session, "daqsystem", []) - if callable(daq_systems): - daq_systems = daq_systems() - - device = None - for ds in (daq_systems if isinstance(daq_systems, list) else []): - if hasattr(ds, "name") and ds.name == dss.devicename: - device = ds - break + # Get device from the underlying_epochs stored by buildepochtable + underlying = epoch_entry.get("underlying_epochs", {}) + device = underlying.get("underlying") if isinstance(underlying, dict) else None if device is None: return None diff --git a/src/ndi/probe/timeseries_stimulator.py b/src/ndi/probe/timeseries_stimulator.py index e50b7a0..373ea85 100644 --- a/src/ndi/probe/timeseries_stimulator.py +++ b/src/ndi/probe/timeseries_stimulator.py @@ -9,12 +9,15 @@ from __future__ import annotations +import logging from typing import Any import numpy as np from .timeseries import ndi_probe_timeseries +logger = logging.getLogger("ndi") + class ndi_probe_timeseries_stimulator(ndi_probe_timeseries): """ @@ -118,7 +121,7 @@ def readtimeseriesepoch( return empty_data, empty_t, self._get_epoch_timeref(epoch) dev = devinfo.get("daqsystem") - devepoch = devinfo.get("device_epoch_id") + devepoch = devinfo.get("device_epoch_number", devinfo.get("device_epoch_id")) channeltype = devinfo.get("channeltype", []) channel = devinfo.get("channel", []) @@ -158,8 +161,9 @@ def readtimeseriesepoch( else: sr_val = float(sr) - s0 = 1 + round(sr_val * t0) - s1 = 1 + round(sr_val * t1) + # 0-based sample indices (Python convention) + s0 = round(sr_val * t0) + s1 = round(sr_val * t1) analog_channeltype = [channeltype_nonmd[i] for i in analog_indices] analog_channel = [channel_nonmd[i] for i in analog_indices] @@ -171,9 +175,11 @@ def readtimeseriesepoch( try: t_analog = dev.readchannels_epochsamples(["time"], [1], devepoch, s0, s1) t["analog"] = np.asarray(t_analog).ravel() - except Exception: + except Exception as exc: + logger.warning("stimulator: failed to read time channel: %s", exc) t["analog"] = np.nan - except Exception: + except Exception as exc: + logger.warning("stimulator: failed to read analog channels: %s", exc) data["analog"] = np.array([]) t["analog"] = np.nan else: @@ -204,7 +210,8 @@ def readtimeseriesepoch( else: timestamps_list = [] edata_list = [] - except Exception: + except Exception as exc: + logger.warning("stimulator: readevents_epochsamples failed: %s", exc, exc_info=True) timestamps_list = [] edata_list = [] @@ -295,7 +302,8 @@ def readtimeseriesepoch( try: md_ch_idx = all_channel[i] data["parameters"] = dev.getmetadata(devepoch, md_ch_idx) - except Exception: + except Exception as exc: + logger.warning("stimulator: failed to read metadata: %s", exc) data["parameters"] = [] t["stimevents"] = event_data_list @@ -337,7 +345,8 @@ def readtimeseriesepoch( try: md_ch_idx = all_channel[i] data["parameters"] = dev.getmetadata(devepoch, md_ch_idx) - except Exception: + except Exception as exc: + logger.warning("stimulator: failed to read metadata: %s", exc) data["parameters"] = [] elif ct in ("e", "event"): @@ -385,7 +394,8 @@ def readtimeseriesepoch( from ..time.timereference import ndi_time_timereference timeref = ndi_time_timereference(self, ndi_time_clocktype.DEV_LOCAL_TIME, eid, 0) - except Exception: + except Exception as exc: + logger.warning("stimulator: failed to create timeref: %s", exc) timeref = ndi_time_clocktype.DEV_LOCAL_TIME return data, t, timeref @@ -416,32 +426,56 @@ def getchanneldevinfo( return None dev = base_info.get("daqsystem") - devepoch = base_info.get("device_epoch_id") + devepoch = base_info.get("device_epoch_number", base_info.get("device_epoch_id")) if dev is None: return None - # Get channel info from the epochprobemap's devicestring - epms = base_info.get("epochprobemap", []) + # Get channel info from ALL epochprobemaps in the underlying epoch + # that match this probe. MATLAB iterates all maps in the underlying + # epoch and extracts channels from every matching one. + et, _ = self.epochtable() + entry = et[epoch - 1] if isinstance(epoch, int) and epoch <= len(et) else None + underlying = entry.get("underlying_epochs", {}) if entry else {} + all_epms = underlying.get("epochprobemap", base_info.get("epochprobemap", [])) + if not isinstance(all_epms, list): + all_epms = [all_epms] + channeltype = [] channel = [] - for epm in epms: + for epm in all_epms: + if not self.epochprobemapmatch(epm): + continue if hasattr(epm, "devicestring") and epm.devicestring: + logger.debug( + "stimulator: matched epm devicestring='%s'", + epm.devicestring, + ) try: from ..daq.daqsystemstring import ndi_daq_daqsystemstring dss = ndi_daq_daqsystemstring.parse(epm.devicestring) + logger.debug( + "stimulator devicestring '%s' parsed: channels=%s", + epm.devicestring, + dss.channels, + ) for ct, ch_list in dss.channels: for ch in ch_list: channeltype.append(ct) channel.append(ch) - except Exception: - pass + except Exception as exc: + logger.warning( + "stimulator: failed to parse devicestring '%s': %s", + epm.devicestring if hasattr(epm, "devicestring") else "?", + exc, + ) return { "daqsystem": dev, - "device_epoch_id": devepoch, + "device_epoch_id": base_info.get("device_epoch_id"), + "device_epoch_number": devepoch, "channeltype": channeltype, "channel": channel, } diff --git a/src/ndi/session/session_base.py b/src/ndi/session/session_base.py index fa914ec..6567a27 100644 --- a/src/ndi/session/session_base.py +++ b/src/ndi/session/session_base.py @@ -1113,7 +1113,9 @@ def _document_to_object(self, document: ndi_document) -> Any: if isinstance(props, dict): daq_class_name = props.get("daqsystem", {}).get("ndi_daqsystem_class", "") - if "mfdaq" in daq_class_name: + # Check for mfdaq in the class name, or default to mfdaq + # if class name is missing (most DAQ systems are MFDAQ) + if "mfdaq" in daq_class_name or not daq_class_name: from ..daq.system_mfdaq import ndi_daq_system_mfdaq return ndi_daq_system_mfdaq(session=self, document=document) diff --git a/tests/matlab_tests/test_cloud_compute.py b/tests/matlab_tests/test_cloud_compute.py index 56a32e2..51c9414 100644 --- a/tests/matlab_tests/test_cloud_compute.py +++ b/tests/matlab_tests/test_cloud_compute.py @@ -163,7 +163,12 @@ def test_hello_world_flow_live(self): _, client = _login() # 1. Start session - result = startSession("hello-world-v1", client=client) + try: + result = startSession("hello-world-v1", client=client) + except Exception as exc: + if "does not have permission" in str(exc): + pytest.skip(f"User lacks compute permissions: {exc}") + raise session_id = result.get("sessionId") or result.get("id", "") assert session_id, f"No sessionId in response: {result}" @@ -274,7 +279,12 @@ def test_zombie_flow_live(self): _, client = _login() # 1. Start pipeline - result = startSession("zombie-test-v1", client=client) + try: + result = startSession("zombie-test-v1", client=client) + except Exception as exc: + if "does not have permission" in str(exc): + pytest.skip(f"User lacks compute permissions: {exc}") + raise session_id = result.get("sessionId") or result.get("id", "") assert session_id, f"No sessionId in response: {result}" diff --git a/tests/matlab_tests/test_daq.py b/tests/matlab_tests/test_daq.py index 690e2be..d4d45c5 100644 --- a/tests/matlab_tests/test_daq.py +++ b/tests/matlab_tests/test_daq.py @@ -365,17 +365,18 @@ def test_epochsamples2times_basic(self): reader = _MockMFDAQReader(sample_rate=sr, t0=0.0) files = ["dummy.rhd"] - samples = np.array([1, 2, 3]) + # 0-based: sample 0 = t0, sample 1 = t0 + 1/sr, etc. + samples = np.array([0, 1, 2]) times = reader.epochsamples2times("ai", 1, files, samples) - # t = t0 + (sample - 1) / sr => t = (s-1)/30000 + # t = t0 + sample / sr => t = s/30000 (0-based) expected = np.array([0.0, 1.0 / sr, 2.0 / sr]) np.testing.assert_allclose(times, expected, atol=1e-12) def test_epochtimes2samples_basic(self): - """Convert times to sample indices with known sample rate. + """Convert times to 0-based sample indices with known sample rate. - MATLAB equivalent: mfdaqIntanTest - epochtimes2samples basic + Note: Python uses 0-based indices (MATLAB uses 1-based). """ sr = 30000.0 reader = _MockMFDAQReader(sample_rate=sr, t0=0.0) @@ -384,19 +385,20 @@ def test_epochtimes2samples_basic(self): times = np.array([0.0, 1.0 / sr, 2.0 / sr]) samples = reader.epochtimes2samples("ai", 1, files, times) - expected = np.array([1, 2, 3]) + # 0-based: s = round((t - t0) * sr) + expected = np.array([0, 1, 2]) np.testing.assert_array_equal(samples, expected) def test_roundtrip_samples_times(self): """samples -> times -> samples round-trip should be identity. - MATLAB equivalent: mfdaqIntanTest - round-trip test + Note: Python uses 0-based sample indices. """ sr = 20000.0 reader = _MockMFDAQReader(sample_rate=sr, t0=0.5) files = ["dummy.rhd"] - original_samples = np.array([1, 100, 1000, 10000]) + original_samples = np.array([0, 99, 999, 9999]) times = reader.epochsamples2times("ai", 1, files, original_samples) recovered_samples = reader.epochtimes2samples("ai", 1, files, times) @@ -427,16 +429,16 @@ def test_epochsamples2times_with_nonzero_t0(self): reader = _MockMFDAQReader(sample_rate=sr, t0=t0) files = ["dummy.rhd"] - samples = np.array([1]) + # 0-based: sample 0 should correspond to t0 + samples = np.array([0]) times = reader.epochsamples2times("ai", 1, files, samples) - # sample 1 should correspond to t0 np.testing.assert_allclose(times, np.array([t0]), atol=1e-12) def test_epochtimes2samples_with_nonzero_t0(self): """epochtimes2samples with nonzero t0. - MATLAB equivalent: mfdaqIntanTest - t0 offset check (reverse) + Note: Python uses 0-based sample indices. """ sr = 10000.0 t0 = 2.5 @@ -446,7 +448,8 @@ def test_epochtimes2samples_with_nonzero_t0(self): times = np.array([t0]) samples = reader.epochtimes2samples("ai", 1, files, times) - np.testing.assert_array_equal(samples, np.array([1])) + # 0-based: t0 maps to sample 0 + np.testing.assert_array_equal(samples, np.array([0])) # =========================================================================== diff --git a/tests/test_batch_a.py b/tests/test_batch_a.py index 8daede7..11ff7cb 100644 --- a/tests/test_batch_a.py +++ b/tests/test_batch_a.py @@ -454,7 +454,10 @@ def test_channelgroupdecoding(self): ndi_file_type_mfdaq__epoch__channel.channelgroupdecoding(channels, "analog_in", [1, 3]) ) assert groups == [1, 2] - assert ch_in_groups == [[1], [3]] + # ch_in_groups contains 0-based indices within each group's channels + # Channel 1 is index 0 in group 1 (which has channels 1, 2) + # Channel 3 is index 0 in group 2 (which has only channel 3) + assert ch_in_groups == [[0], [0]] assert ch_in_output == [[0], [1]] def test_repr(self): @@ -658,10 +661,11 @@ def samplerate(self, epoch): return 1000.0 pt = MockTimeseries(name="test", reference=1, type="n-trode") + # 0-based: sample 0 = t=0, sample 1 = t=0.001, sample 10 = t=0.01 samples = pt.times2samples(1, np.array([0.0, 0.001, 0.01])) - assert samples[0] == 1 - assert samples[1] == 2 - assert samples[2] == 11 + assert samples[0] == 0 + assert samples[1] == 1 + assert samples[2] == 10 def test_samples2times(self): from ndi.probe.timeseries import ndi_probe_timeseries @@ -671,7 +675,8 @@ def samplerate(self, epoch): return 1000.0 pt = MockTimeseries(name="test", reference=1, type="n-trode") - times = pt.samples2times(1, np.array([1, 2, 11])) + # 0-based: sample 0 = t=0, sample 1 = t=0.001, sample 10 = t=0.01 + times = pt.samples2times(1, np.array([0, 1, 10])) np.testing.assert_allclose(times, [0.0, 0.001, 0.01]) def test_repr(self): diff --git a/tests/test_cloud_read_ingested.py b/tests/test_cloud_read_ingested.py new file mode 100644 index 0000000..0e8a908 --- /dev/null +++ b/tests/test_cloud_read_ingested.py @@ -0,0 +1,363 @@ +""" +ndi.unittest.cloud.readIngested - Read an ingested dataset from the cloud. + +Downloads the Carbon fiber microelectrode dataset, opens its session, +reads timeseries data from a carbon-fiber probe and a stimulator probe, +and verifies the returned values match expected results. + +Requires environment variables: + NDI_CLOUD_USERNAME -- mapped from GitHub secret TEST_USER_2_USERNAME + NDI_CLOUD_PASSWORD -- mapped from GitHub secret TEST_USER_2_PASSWORD + +Skipped automatically if credentials are not set. +""" + +from __future__ import annotations + +import os +import tempfile + +import numpy as np +import pytest + +# --------------------------------------------------------------------------- +# Skip entire module if no credentials +# --------------------------------------------------------------------------- + +_has_creds = bool(os.environ.get("NDI_CLOUD_USERNAME") and os.environ.get("NDI_CLOUD_PASSWORD")) +pytestmark = pytest.mark.skipif(not _has_creds, reason="NDI cloud credentials not set") + +CARBON_FIBER_ID = "668b0539f13096e04f1feccd" + + +@pytest.fixture(scope="module") +def cloud_client(): + """Authenticate with NDI Cloud and return a client.""" + from ndi.cloud.auth import login + from ndi.cloud.client import CloudClient + + username = os.environ["NDI_CLOUD_USERNAME"] + password = os.environ["NDI_CLOUD_PASSWORD"] + config = login(username, password) + assert config.is_authenticated, "Login failed -- no token received" + return CloudClient(config) + + +@pytest.fixture(scope="module") +def dataset(cloud_client): + """Download the Carbon fiber dataset to a temp directory.""" + from ndi.cloud.orchestration import downloadDataset + + with tempfile.TemporaryDirectory() as target_dir: + D = downloadDataset(CARBON_FIBER_ID, target_dir, client=cloud_client) + yield D + + +@pytest.fixture(scope="module") +def session(dataset): + """Open the single session in the dataset.""" + refs, session_ids, *_ = dataset.session_list() + assert len(session_ids) == 1, f"Expected 1 session, got {len(session_ids)}" + S = dataset.open_session(session_ids[0]) + return S + + +class TestReadIngested: + """ndi.unittest.cloud.readIngested — verify cloud dataset reads.""" + + def test_session_list_has_one_entry(self, dataset): + """session_list should return exactly one session.""" + refs, session_ids, *_ = dataset.session_list() + assert len(session_ids) == 1 + + def test_carbonfiber_probe_timeseries(self, session): + """Read carbonfiber probe timeseries and check values.""" + p_cf = session.getprobes(name="carbonfiber", reference=1) + assert len(p_cf) == 1, f"Expected 1 carbonfiber probe, got {len(p_cf)}" + + probe = p_cf[0] + print(f" Probe class: {type(probe).__name__}") + print(f" Probe MRO: {[c.__name__ for c in type(probe).__mro__[:5]]}") + + # Diagnostic: check epoch table + et, _ = probe.epochtable() + print(f" Probe epochtable has {len(et)} entries") + for i, e in enumerate(et): + print(f" epoch[{i}]: id={e.get('epoch_id')}") + if et: + e = et[0] + print(f" epoch_id: {e.get('epoch_id')}") + epm = e.get("epochprobemap") + print(f" epochprobemap type: {type(epm).__name__}, value: {epm}") + underlying = e.get("underlying_epochs", {}) + if underlying: + u = underlying.get("underlying") + print(f" underlying type: {type(u).__name__}") + print(f" underlying epoch_id: {underlying.get('epoch_id')}") + + # Diagnostic: check devinfo + try: + devinfo = probe.getchanneldevinfo(1) + print(f" devinfo type: {type(devinfo).__name__}, value: {devinfo}") + except Exception as exc: + pytest.fail(f"getchanneldevinfo(1) raised {type(exc).__name__}: {exc}") + + if devinfo is None: + pytest.fail("getchanneldevinfo(1) returned None") + + # Try the full readtimeseries path with explicit error propagation + if isinstance(devinfo, tuple): + dev, devepoch, channeltype, channellist = devinfo + elif isinstance(devinfo, dict): + dev = devinfo.get("daqsystem") + devepoch = devinfo.get("device_epoch_id") + print(f" devinfo is dict, dev={type(dev).__name__}, devepoch={devepoch}") + pytest.fail( + f"getchanneldevinfo returned dict (base probe class), not tuple. " + f"Probe class {type(probe).__name__} may not override getchanneldevinfo." + ) + else: + pytest.fail(f"getchanneldevinfo returned unexpected type: {type(devinfo).__name__}") + + print(f" dev={type(dev).__name__}, devepoch={devepoch}") + print(f" channeltype={channeltype}, channellist={channellist}") + + # Diagnostic: try reading channel_list.bin directly + diag = [] + if ( + hasattr(dev, "_filenavigator") + and dev._filenavigator is not None + and hasattr(dev, "_getepochfiles") + ): + epochfiles = dev._getepochfiles(devepoch) + diag.append(f"epochfiles={epochfiles[:2]}") + is_ingested = epochfiles and epochfiles[0].startswith("epochid://") + diag.append(f"is_ingested={is_ingested}") + if is_ingested and hasattr(dev, "_daqreader"): + try: + ingested_doc = dev._daqreader.getingesteddocument(epochfiles, session) + diag.append(f"doc_class={ingested_doc.doc_class()}") + props = ingested_doc.document_properties + prop_keys = [ + k for k in props if "ingested" in k.lower() or "daqreader" in k.lower() + ] + diag.append(f"ingested_keys={prop_keys}") + for pk in prop_keys: + if isinstance(props[pk], dict): + diag.append(f"{pk}.keys={list(props[pk].keys())}") + files = props.get("files", {}) + diag.append(f"files.keys={list(files.keys())}") + fi = files.get("file_info", []) + diag.append(f"file_info_count={len(fi)}") + if fi and isinstance(fi[0], dict): + diag.append(f"fi[0].name={fi[0].get('name')}") + locs = fi[0].get("locations", []) + if locs: + diag.append(f"fi[0].loc[0]={locs[0].get('location', '')[:60]}") + try: + fobj = session.database_openbinarydoc(ingested_doc, "channel_list.bin") + diag.append(f"channel_list.bin=OK:{fobj.name}") + fobj.close() + except Exception as exc2: + diag.append(f"channel_list.bin=FAILED:{type(exc2).__name__}:{exc2}") + except Exception as exc: + diag.append(f"getingesteddocument=FAILED:{exc}") + + # Try epochtimes2samples explicitly to see any error + try: + samples = dev.epochtimes2samples( + channeltype, channellist, devepoch, np.array([10.0, 20.0]) + ) + print(f" samples={samples}") + except Exception as exc: + pytest.fail( + f"dev.epochtimes2samples raised {type(exc).__name__}: {exc}\n" + f" dev type: {type(dev).__name__}\n" + f" diag: {'; '.join(diag)}" + ) + + # Debug: read first 9 samples from t=0 to check alignment + d_first, t_first, _ = probe.readtimeseries(epoch=1, t0=0, t1=0.001) + if d_first is not None: + print(" ALIGNMENT CHECK: first 9 samples from t=0:") + print(f" d_first.shape={d_first.shape}") + n = min(9, d_first.shape[0]) + vals = [f"{d_first[i,0]:.4f}" for i in range(n)] + times = [f"{t_first[i]:.6f}" for i in range(n)] + print(f" values: {vals}") + print(f" times: {times}") + print( + " EXPECTED: [2.0475, 0.4760, -0.1080, -0.1020, -0.0528, 0.0006, 0.0242, 0.1517, 0.0909]" + ) + + d1, t1, _ = probe.readtimeseries(epoch=1, t0=10, t1=20) + + assert ( + d1 is not None + ), "readtimeseries returned None for data (binary files not accessible?)" + assert t1 is not None, "readtimeseries returned None for times" + + # Check data isn't all NaN + if np.all(np.isnan(d1)): + pytest.fail( + f"readtimeseries returned all NaN data. shape={d1.shape}. " + f"Segment file reading likely failed — check warnings in log." + ) + + # Debug: print raw values, shape, and scale/offset info + print(f" d1.shape={d1.shape}, t1.shape={t1.shape}") + print(f" d1[0,:5]={d1[0,:5]}") + print(f" t1[0]={t1[0]}") + # Get scale/offset from channel info + epochfiles = dev._getepochfiles(devepoch) + sr_arr, off_arr, sc_arr = dev._daqreader.samplerate_ingested( + epochfiles, channeltype, channellist, session + ) + print(f" sr={sr_arr[0]}, offset={off_arr[:3]}, scale={sc_arr[:3]}") + t0t1 = dev._daqreader.t0_t1_ingested(epochfiles, session) + print(f" t0_t1={t0t1}") + # Debug: print raw data near expected position + print(f" d1[0,:5]={d1[0,:5]}") + if d1.shape[0] > 1: + print(f" d1[1,:5]={d1[1,:5]}") + # Read one sample earlier to check alignment + d_check, t_check, _ = probe.readtimeseries(epoch=1, t0=9.99995, t1=10.0001) + if d_check is not None: + print(f" d_check.shape={d_check.shape}") + for i in range(min(5, d_check.shape[0])): + print(f" d_check[{i},0]={d_check[i,0]:.3f} t={t_check[i]:.6f}") + + # Check first time sample + assert abs(t1[0] - 10.0) < 0.001, f"Expected t1[0] ≈ 10.0, got {t1[0]}" + + # Expected values for d1[0, :] + expected_d1_row0 = np.array( + [ + 55.7700, + 253.3050, + -43.2900, + -9.5550, + 30.6150, + 23.4000, + 16.1850, + -51.6750, + -1.7550, + -14.6250, + -32.7600, + 45.6300, + -7.2150, + 0.9750, + -1.7550, + 45.0450, + ] + ) + + actual_d1_row0 = d1[0, :] + assert ( + actual_d1_row0.shape == expected_d1_row0.shape + ), f"Expected {expected_d1_row0.shape} channels, got {actual_d1_row0.shape}" + np.testing.assert_allclose( + actual_d1_row0, + expected_d1_row0, + atol=0.001, + err_msg="d1[0,:] values do not match expected", + ) + + def test_stimulator_probe_timeseries(self, session): + """Read stimulator probe timeseries and check stimid and timing.""" + p_st = session.getprobes(type="stimulator") + assert len(p_st) >= 1, "Expected at least 1 stimulator probe" + + stim = p_st[0] + print(f" Stimulator probe: {stim}") + print(f" Stimulator class: {type(stim).__name__}") + + # Diagnostic: check what getchanneldevinfo returns + devinfo = stim.getchanneldevinfo(1) + if devinfo is None: + pytest.fail("stimulator getchanneldevinfo(1) returned None") + print(f" devinfo keys: {list(devinfo.keys())}") + dev = devinfo.get("daqsystem") + devepoch = devinfo.get("device_epoch_number", devinfo.get("device_epoch_id")) + ct = devinfo.get("channeltype", []) + ch = devinfo.get("channel", []) + print(f" dev={type(dev).__name__}, devepoch={devepoch}") + print(f" channeltype={ct}, channel={ch}") + # Print ALL epochprobemaps from underlying epoch + et_stim, _ = stim.epochtable() + if et_stim: + underlying = et_stim[0].get("underlying_epochs", {}) + all_epms = underlying.get("epochprobemap", []) + if not isinstance(all_epms, list): + all_epms = [all_epms] + print(f" underlying epochprobemaps count: {len(all_epms)}") + for i, m in enumerate(all_epms): + ds = getattr(m, "devicestring", "?") + nm = getattr(m, "name", "?") + print(f" epm[{i}]: name={nm} devicestring={ds}") + match = stim.epochprobemapmatch(m) if hasattr(stim, "epochprobemapmatch") else "?" + print(f" matches this probe: {match}") + + # Try readevents directly (without md channels, matching stimulator) + non_md_ct = [c for c in ct if c != "md"] + non_md_ch = [ch[i] for i, c in enumerate(ct) if c != "md"] + print(f" non-md channeltype={non_md_ct}, channel={non_md_ch}") + if dev is not None and non_md_ct: + try: + evt_result = dev.readevents_epochsamples(non_md_ct, non_md_ch, devepoch, 10, 20) + print(f" readevents result type: {type(evt_result)}") + if isinstance(evt_result, tuple): + ts_r, data_r = evt_result + print(f" timestamps type: {type(ts_r)}, data type: {type(data_r)}") + if isinstance(ts_r, list): + for i in range(len(ts_r)): + t_i, d_i = ts_r[i], data_r[i] + t_s = getattr( + t_i, "shape", len(t_i) if hasattr(t_i, "__len__") else "?" + ) + d_s = getattr( + d_i, "shape", len(d_i) if hasattr(d_i, "__len__") else "?" + ) + label = ( + f"{non_md_ct[i]}{non_md_ch[i]}" if i < len(non_md_ct) else f"[{i}]" + ) + print(f" ch[{i}] ({label}): ts={t_s}, data={d_s}") + elif hasattr(ts_r, "shape"): + print(f" timestamps shape: {ts_r.shape}") + except Exception as exc: + pytest.fail( + f"readevents_epochsamples raised {type(exc).__name__}: {exc}\n" + f" channeltype={non_md_ct}, channel={non_md_ch}, devepoch={devepoch}" + ) + + ds, ts, _ = stim.readtimeseries(epoch=1, t0=10, t1=20) + + assert ds is not None, "readtimeseries returned None for data" + assert ts is not None, "readtimeseries returned None for times" + + # ds should be a dict with 'stimid' + stimid = ds["stimid"] + if hasattr(stimid, "size") and stimid.size == 0: + pytest.fail( + f"ds['stimid'] is empty. ds keys={list(ds.keys())}, " + f"ds values sizes={{ k: (v.size if hasattr(v, 'size') else len(v) if hasattr(v, '__len__') else v) for k, v in ds.items() }}, " + f"ts keys={list(ts.keys())}, " + f"ts values sizes={{ k: (v.size if hasattr(v, 'size') else len(v) if hasattr(v, '__len__') else v) for k, v in ts.items() }}" + ) + # Extract scalar stimid from potentially nested array + stimid_val = np.asarray(stimid).ravel() + if stimid_val.size > 0: + stimid_val = int(stimid_val[0]) + else: + pytest.fail(f"stimid is empty after ravel: {stimid}") + assert stimid_val == 31, f"Expected stimid == 31, got {stimid_val} (raw: {stimid})" + + # ts.stimon should be 15.2590 (within 0.001) + stimon = ts["stimon"] + if hasattr(stimon, "size") and stimon.size == 0: + pytest.fail("ts['stimon'] is empty — binary files may not be accessible from cloud") + if hasattr(stimon, "__len__"): + stimon_val = float(stimon) if np.ndim(stimon) == 0 else float(stimon[0]) + else: + stimon_val = float(stimon) + assert abs(stimon_val - 15.2590) < 0.001, f"Expected ts.stimon ≈ 15.2590, got {stimon_val}" diff --git a/tests/test_daq.py b/tests/test_daq.py index 31c50b5..eb1c245 100644 --- a/tests/test_daq.py +++ b/tests/test_daq.py @@ -235,12 +235,12 @@ def test_mfdaq_samplerate(self): assert all(sr == 30000.0) def test_mfdaq_epochsamples2times(self): - """Test converting samples to times.""" + """Test converting 0-based samples to times.""" reader = ConcreteMFDAQReader() - samples = np.array([1, 1001, 2001]) + samples = np.array([0, 1000, 2000]) times = reader.epochsamples2times("ai", 1, ["test.dat"], samples) - # t = t0 + (s-1)/sr = 0 + (s-1)/30000 - expected = (samples - 1) / 30000.0 + # t = t0 + s/sr = 0 + s/30000 (0-based) + expected = samples / 30000.0 np.testing.assert_array_almost_equal(times, expected) def test_mfdaq_epochtimes2samples(self): @@ -248,8 +248,8 @@ def test_mfdaq_epochtimes2samples(self): reader = ConcreteMFDAQReader() times = np.array([0.0, 0.1, 0.2]) samples = reader.epochtimes2samples("ai", 1, ["test.dat"], times) - # s = 1 + round((t-t0)*sr) = 1 + round(t*30000) - expected = 1 + np.round(times * 30000).astype(int) + # s = round((t-t0)*sr) = round(t*30000) (0-based) + expected = np.round(times * 30000).astype(int) np.testing.assert_array_equal(samples, expected) def test_mfdaq_channel_types(self): @@ -663,7 +663,11 @@ def test_samplerate_ingested_no_session(self): # Mock getingesteddocument reader.getingesteddocument = MagicMock(return_value=mock_doc) - sr = reader.samplerate_ingested(["epochid://test123"], ["ai", "ai"], [1, 2], mock_session) + mock_session.database_openbinarydoc = MagicMock(side_effect=FileNotFoundError) + + sr, offset, scale = reader.samplerate_ingested( + ["epochid://test123"], ["ai", "ai"], [1, 2], mock_session + ) assert len(sr) == 2 assert sr[0] == 30000 assert sr[1] == 30000 @@ -686,6 +690,7 @@ def test_getchannelsepoch_ingested(self): } mock_session = MagicMock() + mock_session.database_openbinarydoc = MagicMock(side_effect=FileNotFoundError) reader.getingesteddocument = MagicMock(return_value=mock_doc) channels = reader.getchannelsepoch_ingested(["epochid://test"], mock_session) @@ -712,14 +717,15 @@ def test_epochsamples2times_ingested(self): } mock_session = MagicMock() + mock_session.database_openbinarydoc = MagicMock(side_effect=FileNotFoundError) reader.getingesteddocument = MagicMock(return_value=mock_doc) - samples = np.array([1, 3001, 6001]) + samples = np.array([0, 3000, 6000]) # 0-based sample indices times = reader.epochsamples2times_ingested( "ai", 1, ["epochid://test"], samples, mock_session ) - # t = t0 + (s-1)/sr = 0 + (s-1)/30000 - expected = (samples - 1) / 30000.0 + # t = t0 + s/sr = 0 + s/30000 (0-based) + expected = samples / 30000.0 np.testing.assert_array_almost_equal(times, expected) def test_epochtimes2samples_ingested(self): @@ -740,14 +746,15 @@ def test_epochtimes2samples_ingested(self): } mock_session = MagicMock() + mock_session.database_openbinarydoc = MagicMock(side_effect=FileNotFoundError) reader.getingesteddocument = MagicMock(return_value=mock_doc) times = np.array([0.0, 0.1, 0.2]) samples = reader.epochtimes2samples_ingested( "ai", 1, ["epochid://test"], times, mock_session ) - # s = 1 + round((t-t0)*sr) = 1 + round(t*30000) - expected = 1 + np.round(times * 30000).astype(int) + # s = round((t-t0)*sr) = round(t*30000) (0-based) + expected = np.round(times * 30000).astype(int) np.testing.assert_array_equal(samples, expected)