Skip to content
119 changes: 102 additions & 17 deletions src/spikeinterface/extractors/neoextractors/maxwell.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ def install_maxwell_plugin(self, force_download=False):
auto_install_maxwell_hdf5_compression_plugin(force_download=False)


_maxwell_event_dtype = np.dtype([("frame", "int64"), ("state", "int8"), ("time", "float64")])
_maxwell_event_dtype = np.dtype(
[("id", "int8"), ("frame", "uint32"), ("time", "float64"), ("state", "uint32"), ("message", "object")]
)


class MaxwellEventExtractor(BaseEvent):
Expand All @@ -107,9 +109,41 @@ def __init__(self, file_path):
version = int(h5_file["version"][0].decode())
fs = 20000

if version < 20190530:
raise NotImplementedError(f"Version {self.version} not supported")

# get ttl events
bits = h5_file["bits"]
bit_states = bits["bits"]
channel_ids = np.unique(bit_states[bit_states != 0])

channel_ids = np.zeros((0), dtype=np.int8)
if len(bits) > 0:
bit_state = bits["bits"]
channel_ids = np.int8(np.unique(bit_state[bit_state != 0]))
if -1 in channel_ids or 1 in channel_ids:
raise ValueError("TTL bits cannot be -1 or 1.")

# access data_store from h5_file
data_store_keys = [x for x in h5_file["data_store"].keys()]
data_store_keys_id = [
("events" in h5_file["data_store"][x].keys()) and ("groups" in h5_file["data_store"][x].keys())
for x in data_store_keys
]
data_store = data_store_keys[data_store_keys_id.index(True)]

# get stim events
event_raw = h5_file["data_store"][data_store]["events"]
channel_ids_stim = np.int8(np.unique([x[1] for x in event_raw]))
if -1 in channel_ids_stim or 0 in channel_ids_stim:
raise ValueError("Stimulation bits cannot be -1 or 0.")
if len(channel_ids) > 0:
if set(channel_ids) & set(channel_ids_stim):
raise ValueError("TTL and stimulation bits overlap.")
channel_ids = np.concatenate((channel_ids, channel_ids_stim), dtype=np.int8)

# set spike events channel == -1
spike_raw = h5_file["data_store"][data_store]["spikes"]
if len(spike_raw) > 0:
channel_ids = np.concatenate((channel_ids, [-1]), dtype=np.int8)

BaseEvent.__init__(self, channel_ids, structured_dtype=_maxwell_event_dtype)
event_segment = MaxwellEventSegment(h5_file, version, fs)
Expand All @@ -125,22 +159,73 @@ def __init__(self, h5_file, version, fs):
self.fs = fs

def get_events(self, channel_id, start_time, end_time):
if self.version != 20160704:
raise NotImplementedError(f"Version {self.version} not supported")
bits = self.bits

# get ttl events
channel_ids = np.zeros((0), dtype=np.int8)
bit_channel = np.zeros((0), dtype=np.int8)
bit_frameno = np.zeros((0), dtype=np.uint32)
bit_state = np.zeros((0), dtype=np.uint32)
bit_message = np.zeros((0), dtype=object)
if len(bits) > 0:
good_idx = np.where(bits["bits"] != 0)[0]
channel_ids = np.concatenate((channel_ids, np.int8(np.unique(bits["bits"][good_idx]))))
if 1 in channel_ids:
raise ValueError("TTL bits cannot be 1.")
bit_channel = np.concatenate((bit_channel, np.uint8(bits["bits"][good_idx])))
bit_frameno = np.concatenate((bit_frameno, np.uint32(bits["frameno"][good_idx])))
bit_state = np.concatenate((bit_state, np.uint32(bits["bits"][good_idx])))
bit_message = np.concatenate((bit_message, [b"{}\n"] * len(bit_state)), dtype=object)

# access data_store from h5_file
h5_file = self.h5_file
data_store_keys = [x for x in h5_file["data_store"].keys()]
data_store_keys_id = [
("events" in h5_file["data_store"][x].keys()) and ("groups" in h5_file["data_store"][x].keys())
for x in data_store_keys
]
data_store = data_store_keys[data_store_keys_id.index(True)]

# get stim events
event_raw = h5_file["data_store"][data_store]["events"]
channel_ids_stim = np.int8(np.unique([x[1] for x in event_raw]))
stim_arr = np.array(event_raw)
bit_channel_stim = stim_arr["eventtype"]
bit_frameno_stim = stim_arr["frameno"]
bit_state_stim = stim_arr["eventid"]
bit_message_stim = stim_arr["eventmessage"]

# get spike events
spike_raw = h5_file["data_store"][data_store]["spikes"]
if len(spike_raw) > 0:
channel_ids_spike = np.int8([-1])
spike_arr = np.array(spike_raw)
bit_channel_spike = -np.ones(len(spike_arr), dtype=np.int8)
bit_frameno_spike = spike_arr["frameno"]
bit_state_spike = spike_arr["channel"]
bit_message_spike = spike_arr["amplitude"]

# final array in order: spikes, stims, ttl
bit_channel = np.concatenate((bit_channel_spike, bit_channel_stim, bit_channel))
bit_frameno = np.concatenate((bit_frameno_spike, bit_frameno_stim, bit_frameno))
bit_state = np.concatenate((bit_state_spike, bit_state_stim, bit_state))
bit_message = np.concatenate((bit_message_spike, bit_message_stim, bit_message))

first_frame = h5_file["data_store"][data_store]["groups/routed/frame_nos"][0]
bit_frameno = bit_frameno - first_frame

framevals = self.h5_file["sig"][-2:, 0]
first_frame = framevals[1] << 16 | framevals[0]
ttl_frames = self.bits["frameno"] - first_frame
ttl_states = self.bits["bits"]
if channel_id is not None:
bits_channel_idx = np.where((ttl_states == channel_id) | (ttl_states == 0))[0]
ttl_frames = ttl_frames[bits_channel_idx]
ttl_states = ttl_states[bits_channel_idx]
ttl_states[ttl_states == 0] = -1
event = np.zeros(len(ttl_frames), dtype=_maxwell_event_dtype)
event["frame"] = ttl_frames
event["time"] = ttl_frames / self.fs
event["state"] = ttl_states
good_idx = np.where(bit_channel == channel_id)[0]
bit_channel = bit_channel[good_idx]
bit_frameno = bit_frameno[good_idx]
bit_state = bit_state[good_idx]
bit_message = bit_message[good_idx]
event = np.zeros(len(bit_channel), dtype=_maxwell_event_dtype)
event["id"] = bit_channel
event["frame"] = bit_frameno
event["time"] = np.float64(bit_frameno) / self.fs
event["state"] = bit_state
event["message"] = bit_message

if start_time is not None:
event = event[event["time"] >= start_time]
Expand Down
23 changes: 23 additions & 0 deletions src/spikeinterface/extractors/neoextractors/neobaseextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,29 @@ def __init__(
# need neo 0.10.0
signal_channels = self.neo_reader.header["signal_channels"]
mask = signal_channels["stream_id"] == stream_id

# remove all duplicate channel-to-electrode assignments
mask_id = np.argwhere(mask).flatten()
[u, u_i, u_v, u_c] = np.unique(
signal_channels[mask]["name"], return_index=True, return_inverse=True, return_counts=True
)
for i in u_v[u_i[u_c > 1]]:
mask[mask_id[np.argwhere(signal_channels[mask]["name"] == u[i])[1:].flatten()]] = False

# remove all duplicate channel assigments corresponding to different electrodes (channel is a mix of mulitple electrode signals)
mask_id = np.argwhere(mask).flatten()
signal_channels_chan, _ = map(list, zip(*(x.split(" ") for x in signal_channels[mask]["name"])))
[u, u_i, u_v, u_c] = np.unique(signal_channels_chan, return_index=True, return_inverse=True, return_counts=True)
for i in u_v[u_i[u_c > 1]]:
mask[mask_id[np.argwhere(signal_channels_chan == u[i])[:].flatten()]] = False

# remove subsequent duplicated electrodes (single electrode saved to multiple channels)
mask_id = np.argwhere(mask).flatten()
_, signal_channels_elec = map(list, zip(*(x.split(" ") for x in signal_channels[mask]["name"])))
[u, u_i, u_v, u_c] = np.unique(signal_channels_elec, return_index=True, return_inverse=True, return_counts=True)
for i in u_v[u_i[u_c > 1]]:
mask[mask_id[np.argwhere(signal_channels_elec == u[i])[1:].flatten()]] = False

signal_channels = signal_channels[mask]

if use_names_as_ids:
Expand Down
Loading