diff --git a/src/console/interfaces/rx_data.py b/src/console/interfaces/rx_data.py index f86085ea..d280b64c 100644 --- a/src/console/interfaces/rx_data.py +++ b/src/console/interfaces/rx_data.py @@ -18,6 +18,7 @@ class RxData: # Data characteristics, defined by the ADC event in sequence definition num_samples: int num_samples_raw: int + num_samples_discard: int # Number of samples to be discarded before and after ADC, defined by dead time dwell_time: float dwell_time_raw: float @@ -137,10 +138,10 @@ def process_data(self, store_unprocessed: bool = True) -> None: # Creating the processed data output array first and copying the values of the output of the decimation # avoids an apparent memory leak when using the scipy.decimate with the 'iir' ftype - output_shape = list(np.shape(demod_data)) - output_shape[-1] = self.num_samples + # Note that the processed data may contain samples from pre and post sampling + output_shape = (*np.shape(demod_data)[:-1], self.num_samples + int(2*self.num_samples_discard)) self.processed_data = np.zeros(output_shape, dtype=complex) - self.processed_data[:] = self.decimate_data(demod_data)[:] + self.processed_data[:] = self.decimate_data(demod_data) if not store_unprocessed: self.raw_data = None diff --git a/src/console/pulseq_interpreter/sequence_provider.py b/src/console/pulseq_interpreter/sequence_provider.py index de14e90f..9a796864 100644 --- a/src/console/pulseq_interpreter/sequence_provider.py +++ b/src/console/pulseq_interpreter/sequence_provider.py @@ -2,6 +2,8 @@ import logging import operator from collections.abc import Callable +from dataclasses import dataclass +from math import floor from types import SimpleNamespace from typing import Any @@ -34,6 +36,13 @@ def profile(func: Callable[..., Any]) -> Callable[..., Any]: default_fov_offset: Dimensions = Dimensions(0, 0, 0) default_orientation: Dimensions = Dimensions(1, 2, 3) +@dataclass +class ADCGate: + """Define precalculated attributes of an ADC gate.""" + + start: int + num_samples_discard: int + num_samples_raw: int class SequenceProvider(Sequence): """Sequence provider class. @@ -189,7 +198,7 @@ def dict(self) -> dict: "output_limits": self.output_limits, } - def get_adc_events(self) -> list: + def get_adc_events(self) -> list[ADCGate]: """Extract ADC 'waveforms' from the sequence. TODO: Add error checks @@ -199,18 +208,30 @@ def get_adc_events(self) -> list: list: List of with waveform ID, gate signal and reference signal for each unique ADC event. """ - adc_waveforms = self.adc_library adc_list = [] - for adc_waveform in adc_waveforms.data.items(): - num_samples = adc_waveform[1][0] - dwell_time = adc_waveform[1][1] - delay = adc_waveform[1][2] - delay_samples = round(delay * self.spcm_freq) - gate_duration = num_samples * dwell_time - gate_samples = round(gate_duration * self.spcm_freq) - waveform = np.zeros(delay_samples + gate_samples, dtype=np.uint16) - waveform[delay_samples:] = 2**15 - adc_list.append((adc_waveform[0], waveform, gate_samples)) + for adc_props in self.adc_library.data.values(): + # Implementation compatible to version 1.4.X and 1.5.X -> dead time is always appended + num_samples, adc_dwell_time, delay = adc_props[:3] + dead_time = adc_props[-1] + + # Calculate the number of samples to be discarded from the decimated signal + num_samples_discard = floor(dead_time / adc_dwell_time) + # Calculate the total gate duration, given by number of samples + # and two times the number of discarded samples for symmetric adc dead time + # Note that the total gate duration is only increased if the dead time is a multiple of the adc dwell time + total_gate_duration = (num_samples + 2*num_samples_discard) * adc_dwell_time + num_raw_samples = round(total_gate_duration * self.spcm_freq) + + # Remaining delay = dead_time minus pre- and post-sampling fractions + remaining_delay = delay - num_samples_discard * adc_dwell_time + num_delay_samples = round(remaining_delay * self.spcm_freq) + + adc_list.append(ADCGate( + start=num_delay_samples, + num_samples_raw=num_raw_samples, + num_samples_discard=num_samples_discard, + )) + return adc_list def get_rf_events(self) -> list: @@ -281,7 +302,7 @@ def unroll_sequence(self, parameter: AcquisitionParameter) -> UnrolledSequence: # Get list of all events and list of unique RF and ADC events, since they are frequently reused events_list = self.block_events - adc_events = self.get_adc_events() + adc_events: list[ADCGate] = self.get_adc_events() rf_events = self.get_rf_events() # Calculate rf pulse and unblanking waveforms from RF event @@ -410,22 +431,22 @@ def unroll_sequence(self, parameter: AcquisitionParameter) -> UnrolledSequence: labels[label.label] = label.value if block.adc is not None: # ADC event - # Grab the ADC event from the pre-calculated list + # Grab the ADC gate from the pre-calculated list # Pulseq is 1 indexed, shift idx by -1 for correct event - adc_event = adc_events[event[5] - 1] - adc_waveform = adc_event[1] + adc_gate: ADCGate = adc_events[event[5] - 1] # Calculate ADC start and end positions according to block position - adc_start = block_pos[event_idx] * 4 - adc_end = (block_pos[event_idx] + np.size(adc_waveform)) * 4 + adc_start = (block_pos[event_idx] + adc_gate.start) * 4 + adc_end = (block_pos[event_idx] + adc_gate.start + adc_gate.num_samples_raw) * 4 # Add ADC gate to X gradient - _seq[adc_start + 1:adc_end + 1:4] = _seq[adc_start + 1:adc_end + 1:4] | adc_waveform + _seq[adc_start + 1:adc_end + 1:4] = _seq[adc_start + 1:adc_end + 1:4] | np.uint16(2**15) _rx_data.append(RxData( index=adc_count, num_samples=block.adc.num_samples, - num_samples_raw=adc_event[2], + num_samples_raw=adc_gate.num_samples_raw, + num_samples_discard=adc_gate.num_samples_discard, dwell_time=block.adc.dwell, dwell_time_raw=self.spcm_dwell_time, phase_offset=block.adc.phase_offset, diff --git a/src/console/utilities/data/write_acquisition_to_mrd.py b/src/console/utilities/data/write_acquisition_to_mrd.py index 62125aa3..0bf49850 100644 --- a/src/console/utilities/data/write_acquisition_to_mrd.py +++ b/src/console/utilities/data/write_acquisition_to_mrd.py @@ -56,6 +56,11 @@ def write_acquisition_to_mrd( acq.center_sample = rx_data.num_samples // 2 # Readout bandwidth, as time between samples in microseconds acq.sample_time_us = rx_data.dwell_time * 1e6 + # Number of samples to be discarded, defined by adc_dead_time + # Since adc_dead_time is arrange symmetrically around ADC, the + # number of discarded pre and post sample is identical + acq.discard_pre = rx_data.num_samples_discard + acq.discard_post = rx_data.num_samples_discard # Timestamp of readout if rx_data.time_stamp is not None: acq.acquisition_time_stamp = int(rx_data.time_stamp * 1e6) # timestamp in us diff --git a/tests/conftest.py b/tests/conftest.py index fa626b3c..7d0f6156 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -73,6 +73,7 @@ def _factory(num_samples: int, num_acquisitions: int, num_averages: int = 1, num average_index=k_average, num_samples=num_samples, num_samples_raw=num_raw_samples, + num_samples_discard=0, dwell_time=1 / 20e3, dwell_time_raw=1 / 20e6, phase_offset=0, diff --git a/tests/sequence_tests/test_sequence_provider.py b/tests/sequence_tests/test_sequence_provider.py index 7f665a69..1271cae7 100644 --- a/tests/sequence_tests/test_sequence_provider.py +++ b/tests/sequence_tests/test_sequence_provider.py @@ -6,6 +6,7 @@ import pypulseq as pp import pytest +from console.interfaces.acquisition_parameter import AcquisitionParameter from console.interfaces.rx_data import RxData from console.interfaces.unrolled_sequence import UnrolledSequence from console.pulseq_interpreter.sequence_provider import SequenceProvider @@ -260,7 +261,7 @@ def test_get_rf_events(seq_provider, test_sequence): assert rf_block.type == "rf" -def test_sequence_rx_data(seq_provider: SequenceProvider, acquisition_parameter): +def test_sequence_rx_data(seq_provider: SequenceProvider, acquisition_parameter: AcquisitionParameter): """Labels in blocks must be propagated into RxData.labels for each ADC event.""" n_samples = 1000 bw = 20e3 @@ -282,3 +283,26 @@ def test_sequence_rx_data(seq_provider: SequenceProvider, acquisition_parameter) assert rx0.num_samples == n_samples assert rx0.num_samples_raw == n_samples / (bw*seq_provider.spcm_dwell_time) assert rx0.dwell_time == 1/bw + + +@pytest.mark.parametrize("dead_time", [0., 10e-3]) +def test_adc_presampling(seq_provider: SequenceProvider, acquisition_parameter: AcquisitionParameter, dead_time: float): + """Verify that dead_time is used for pre and post samples which are to be discarded after decimation.""" + adc_bw = 20e3 + adc_dwell = 1/adc_bw + num_samples_discard = round(dead_time/adc_dwell) + num_samples = 100 + # Define adc event + adc = pp.make_adc( + delay=dead_time, + num_samples=num_samples, + dwell=adc_dwell, + system=pp.Opts(adc_dead_time=dead_time), + ) + # Unroll sequence + seq_provider.add_block(adc) + seq_unrolled = seq_provider.unroll_sequence(acquisition_parameter) + rx_data = seq_unrolled.rx_data[0] + + assert rx_data.num_samples == num_samples + assert rx_data.num_samples_discard == num_samples_discard