diff --git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py
index 1ec210541e..9000069805 100644
--- a/src/spikeinterface/extractors/neoextractors/maxwell.py
+++ b/src/spikeinterface/extractors/neoextractors/maxwell.py
@@ -91,7 +91,9 @@ def install_maxwell_plugin(self, force_download=False):
         auto_install_maxwell_hdf5_compression_plugin(force_download=False)
 
 
-_maxwell_event_dtype = np.dtype([("frame", "int64"), ("state", "int8"), ("time", "float64")])
+_maxwell_event_dtype = np.dtype(
+    [("id", "int8"), ("frame", "uint32"), ("time", "float64"), ("state", "uint32"), ("message", "object")]
+)
 
 
 class MaxwellEventExtractor(BaseEvent):
@@ -107,9 +109,41 @@ def __init__(self, file_path):
         version = int(h5_file["version"][0].decode())
         fs = 20000
 
+        if version < 20190530:
+            raise NotImplementedError(f"Version {version} not supported")
+
+        # get ttl events
         bits = h5_file["bits"]
-        bit_states = bits["bits"]
-        channel_ids = np.unique(bit_states[bit_states != 0])
+
+        channel_ids = np.zeros((0), dtype=np.int8)
+        if len(bits) > 0:
+            bit_state = bits["bits"]
+            channel_ids = np.int8(np.unique(bit_state[bit_state != 0]))
+            if -1 in channel_ids or 1 in channel_ids:
+                raise ValueError("TTL bits cannot be -1 or 1.")
+
+        # access data_store from h5_file
+        data_store_keys = [x for x in h5_file["data_store"].keys()]
+        data_store_keys_id = [
+            ("events" in h5_file["data_store"][x].keys()) and ("groups" in h5_file["data_store"][x].keys())
+            for x in data_store_keys
+        ]
+        data_store = data_store_keys[data_store_keys_id.index(True)]
+
+        # get stim events
+        event_raw = h5_file["data_store"][data_store]["events"]
+        channel_ids_stim = np.int8(np.unique([x[1] for x in event_raw]))
+        if -1 in channel_ids_stim or 0 in channel_ids_stim:
+            raise ValueError("Stimulation bits cannot be -1 or 0.")
+        if len(channel_ids) > 0:
+            if set(channel_ids) & set(channel_ids_stim):
+                raise ValueError("TTL and stimulation bits overlap.")
+        channel_ids = np.concatenate((channel_ids, channel_ids_stim), dtype=np.int8)
+
+        # set spike events channel == -1
+        spike_raw = h5_file["data_store"][data_store]["spikes"]
+        if len(spike_raw) > 0:
+            channel_ids = np.concatenate((channel_ids, [-1]), dtype=np.int8)
 
         BaseEvent.__init__(self, channel_ids, structured_dtype=_maxwell_event_dtype)
         event_segment = MaxwellEventSegment(h5_file, version, fs)
@@ -125,22 +159,73 @@ def __init__(self, h5_file, version, fs):
         self.fs = fs
 
     def get_events(self, channel_id, start_time, end_time):
-        if self.version != 20160704:
-            raise NotImplementedError(f"Version {self.version} not supported")
+        bits = self.bits
+
+        # get ttl events
+        channel_ids = np.zeros((0), dtype=np.int8)
+        bit_channel = np.zeros((0), dtype=np.int8)
+        bit_frameno = np.zeros((0), dtype=np.uint32)
+        bit_state = np.zeros((0), dtype=np.uint32)
+        bit_message = np.zeros((0), dtype=object)
+        if len(bits) > 0:
+            good_idx = np.where(bits["bits"] != 0)[0]
+            channel_ids = np.concatenate((channel_ids, np.int8(np.unique(bits["bits"][good_idx]))))
+            if 1 in channel_ids:
+                raise ValueError("TTL bits cannot be 1.")
+            bit_channel = np.concatenate((bit_channel, np.uint8(bits["bits"][good_idx])))
+            bit_frameno = np.concatenate((bit_frameno, np.uint32(bits["frameno"][good_idx])))
+            bit_state = np.concatenate((bit_state, np.uint32(bits["bits"][good_idx])))
+            bit_message = np.concatenate((bit_message, [b"{}\n"] * len(bit_state)), dtype=object)
+
+        # access data_store from h5_file
+        h5_file = self.h5_file
+        data_store_keys = [x for x in h5_file["data_store"].keys()]
+        data_store_keys_id = [
+            ("events" in h5_file["data_store"][x].keys()) and ("groups" in h5_file["data_store"][x].keys())
+            for x in data_store_keys
+        ]
+        data_store = data_store_keys[data_store_keys_id.index(True)]
+
+        # get stim events
+        event_raw = h5_file["data_store"][data_store]["events"]
+        channel_ids_stim = np.int8(np.unique([x[1] for x in event_raw]))
+        stim_arr = np.array(event_raw)
+        bit_channel_stim = stim_arr["eventtype"]
+        bit_frameno_stim = stim_arr["frameno"]
+        bit_state_stim = stim_arr["eventid"]
+        bit_message_stim = stim_arr["eventmessage"]
+
+        # get spike events
+        spike_raw = h5_file["data_store"][data_store]["spikes"]
+        if len(spike_raw) > 0:
+            channel_ids_spike = np.int8([-1])
+            spike_arr = np.array(spike_raw)
+            bit_channel_spike = -np.ones(len(spike_arr), dtype=np.int8)
+            bit_frameno_spike = spike_arr["frameno"]
+            bit_state_spike = spike_arr["channel"]
+            bit_message_spike = spike_arr["amplitude"]
+
+        # final array in order: spikes, stims, ttl
+        bit_channel = np.concatenate((bit_channel_spike, bit_channel_stim, bit_channel))
+        bit_frameno = np.concatenate((bit_frameno_spike, bit_frameno_stim, bit_frameno))
+        bit_state = np.concatenate((bit_state_spike, bit_state_stim, bit_state))
+        bit_message = np.concatenate((bit_message_spike, bit_message_stim, bit_message))
+
+        first_frame = h5_file["data_store"][data_store]["groups/routed/frame_nos"][0]
+        bit_frameno = bit_frameno - first_frame
 
-        framevals = self.h5_file["sig"][-2:, 0]
-        first_frame = framevals[1] << 16 | framevals[0]
-        ttl_frames = self.bits["frameno"] - first_frame
-        ttl_states = self.bits["bits"]
         if channel_id is not None:
-            bits_channel_idx = np.where((ttl_states == channel_id) | (ttl_states == 0))[0]
-            ttl_frames = ttl_frames[bits_channel_idx]
-            ttl_states = ttl_states[bits_channel_idx]
-            ttl_states[ttl_states == 0] = -1
-        event = np.zeros(len(ttl_frames), dtype=_maxwell_event_dtype)
-        event["frame"] = ttl_frames
-        event["time"] = ttl_frames / self.fs
-        event["state"] = ttl_states
+            good_idx = np.where(bit_channel == channel_id)[0]
+            bit_channel = bit_channel[good_idx]
+            bit_frameno = bit_frameno[good_idx]
+            bit_state = bit_state[good_idx]
+            bit_message = bit_message[good_idx]
+        event = np.zeros(len(bit_channel), dtype=_maxwell_event_dtype)
+        event["id"] = bit_channel
+        event["frame"] = bit_frameno
+        event["time"] = np.float64(bit_frameno) / self.fs
+        event["state"] = bit_state
+        event["message"] = bit_message
 
         if start_time is not None:
             event = event[event["time"] >= start_time]
diff --git a/src/spikeinterface/extractors/neoextractors/neobaseextractor.py b/src/spikeinterface/extractors/neoextractors/neobaseextractor.py
index d66ce79aa3..f0156942de 100644
--- a/src/spikeinterface/extractors/neoextractors/neobaseextractor.py
+++ b/src/spikeinterface/extractors/neoextractors/neobaseextractor.py
@@ -228,6 +228,29 @@ def __init__(
         # need neo 0.10.0
         signal_channels = self.neo_reader.header["signal_channels"]
         mask = signal_channels["stream_id"] == stream_id
+
+        # remove all duplicate channel-to-electrode assignments
+        mask_id = np.argwhere(mask).flatten()
+        [u, u_i, u_v, u_c] = np.unique(
+            signal_channels[mask]["name"], return_index=True, return_inverse=True, return_counts=True
+        )
+        for i in u_v[u_i[u_c > 1]]:
+            mask[mask_id[np.argwhere(signal_channels[mask]["name"] == u[i])[1:].flatten()]] = False
+
+        # remove all duplicate channel assignments corresponding to different electrodes (channel is a mix of multiple electrode signals)
+        mask_id = np.argwhere(mask).flatten()
+        signal_channels_chan, _ = map(list, zip(*(x.split(" ") for x in signal_channels[mask]["name"])))
+        [u, u_i, u_v, u_c] = np.unique(signal_channels_chan, return_index=True, return_inverse=True, return_counts=True)
+        for i in u_v[u_i[u_c > 1]]:
+            mask[mask_id[np.argwhere(signal_channels_chan == u[i])[:].flatten()]] = False
+
+        # remove subsequent duplicated electrodes (single electrode saved to multiple channels)
+        mask_id = np.argwhere(mask).flatten()
+        _, signal_channels_elec = map(list, zip(*(x.split(" ") for x in signal_channels[mask]["name"])))
+        [u, u_i, u_v, u_c] = np.unique(signal_channels_elec, return_index=True, return_inverse=True, return_counts=True)
+        for i in u_v[u_i[u_c > 1]]:
+            mask[mask_id[np.argwhere(signal_channels_elec == u[i])[1:].flatten()]] = False
+
         signal_channels = signal_channels[mask]
 
         if use_names_as_ids:
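Note (not part of the patch): the duplicate-channel masking added to neobaseextractor.py leans on np.unique with return_index, return_inverse and return_counts. The standalone sketch below illustrates that mechanism on a toy array of made-up "<channel> <electrode>" names; it is a rephrasing of the pattern for readability, not the extractor code itself.

import numpy as np

# Toy "name" column; values are hypothetical.
names = np.array(["ch0 el0", "ch0 el0", "ch1 el1", "ch2 el2"])

u, u_i, u_v, u_c = np.unique(names, return_index=True, return_inverse=True, return_counts=True)
# u: unique names, u_i: index of each name's first occurrence,
# u_v: inverse map, u_c: occurrence counts.
# u_v[u_i[u_c > 1]] (as used in the patch) resolves to the indices into u
# of the names that occur more than once, i.e. the duplicated assignments.
dup_idx = u_v[u_i[u_c > 1]]
print(u[dup_idx])  # ['ch0 el0']

# Keep only the first occurrence of each duplicated name, as the [1:] slice in the patch does.
keep = np.ones(len(names), dtype=bool)
for i in dup_idx:
    keep[np.argwhere(names == u[i])[1:].flatten()] = False
print(names[keep])  # ['ch0 el0' 'ch1 el1' 'ch2 el2']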
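Note (not part of the patch): a minimal usage sketch of the extended event extractor, assuming spikeinterface's BaseEvent.get_events(channel_id=..., segment_index=..., start_time=..., end_time=...) API; the file path and the stimulation bit used below are hypothetical.

from spikeinterface.extractors import MaxwellEventExtractor

event_extractor = MaxwellEventExtractor("recording.raw.h5")  # hypothetical path

# With this patch, events come back in the richer structured dtype
# ("id", "frame", "time", "state", "message"); spike events use channel id -1.
spike_events = event_extractor.get_events(channel_id=-1)
stim_events = event_extractor.get_events(channel_id=2)  # a stimulation bit, if present
print(spike_events["time"][:10], spike_events["state"][:10])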