Skip to content
6 changes: 6 additions & 0 deletions smriprep/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ def get_parser():
action='store_true',
help='treat dataset as longitudinal - may increase runtime',
)
g_conf.add_argument(
'--standardize-with-T2w',
action='store_true',
help='treat dataset as longitudinal - may increase runtime',
)

# ANTs options
g_ants = parser.add_argument_group('Specific options for ANTs registrations')
Expand Down Expand Up @@ -629,6 +634,7 @@ def build_workflow(opts, retval):
fs_no_resume=opts.fs_no_resume,
layout=layout,
longitudinal=opts.longitudinal,
standardize_with_T2w=opts.standardize_with_T2w,
low_mem=opts.low_mem,
msm_sulc=opts.msm_sulc,
omp_nthreads=omp_nthreads,
Expand Down
25 changes: 22 additions & 3 deletions smriprep/interfaces/templateflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,25 +39,29 @@


class _TemplateFlowSelectInputSpec(BaseInterfaceInputSpec):
template = traits.Str('MNI152NLin2009cAsym', mandatory=True, desc='Template ID')
template = traits.Str(mandatory=True, desc='Template ID')
atlas = InputMultiObject(traits.Str, desc='Specify an atlas')
cohort = InputMultiObject(traits.Either(traits.Str, traits.Int), desc='Specify a cohort')
resolution = InputMultiObject(traits.Int, desc='Specify a template resolution index')
template_spec = traits.DictStrAny(
{'atlas': None, 'cohort': None}, usedefault=True, desc='Template specifications'
)
get_T2w = traits.Bool(False, usedefault=True, desc='Get the T2w if available')


class _TemplateFlowSelectOutputSpec(TraitedSpec):
t1w_file = File(exists=True, desc='T1w template')
t2w_file = File(exists=True, desc='T2w template')
brain_mask = File(exists=True, desc="Template's brain mask")


class TemplateFlowSelect(SimpleInterface):
"""
Select TemplateFlow elements.

>>> select = TemplateFlowSelect(resolution=1)
Examples
--------
>>> select = TemplateFlowSelect(resolution=1, get_T2w=True)
>>> select.inputs.template = 'MNI152NLin2009cAsym'
>>> result = select.run()
>>> result.outputs.t1w_file # doctest: +ELLIPSIS
Expand All @@ -66,6 +70,9 @@ class TemplateFlowSelect(SimpleInterface):
>>> result.outputs.brain_mask # doctest: +ELLIPSIS
'.../tpl-MNI152NLin2009cAsym_res-01_desc-brain_mask.nii.gz'

>>> result.outputs.t2w_file # doctest: +ELLIPSIS
'.../tpl-MNI152NLin2009cAsym_res-01_T2w.nii.gz'

>>> select = TemplateFlowSelect()
>>> select.inputs.template = 'MNIPediatricAsym'
>>> select.inputs.template_spec = {'cohort': 5, 'resolution': 1}
Expand Down Expand Up @@ -94,6 +101,9 @@ class TemplateFlowSelect(SimpleInterface):
>>> result.outputs.t1w_file # doctest: +ELLIPSIS
'.../tpl-MNI305_T1w.nii.gz'

>>> bool(result.outputs.t2w_file)
False

"""

input_spec = _TemplateFlowSelectInputSpec
Expand All @@ -108,8 +118,14 @@ def _run_interface(self, runtime):
if isdefined(self.inputs.cohort):
specs['cohort'] = self.inputs.cohort

files = fetch_template_files(self.inputs.template, specs)
files = fetch_template_files(
self.inputs.template,
specs,
get_T2w=self.inputs.get_T2w,
)
self._results['t1w_file'] = files['t1w']
if self.inputs.get_T2w and 't2w' in files:
self._results['t2w_file'] = files['t2w']
self._results['brain_mask'] = files['mask']
return runtime

Expand Down Expand Up @@ -167,6 +183,7 @@ def fetch_template_files(
template: str,
specs: dict | None = None,
sloppy: bool = False,
get_T2w: bool = False,
) -> dict:
if specs is None:
specs = {}
Expand Down Expand Up @@ -203,6 +220,8 @@ def fetch_template_files(

files = {}
files['t1w'] = tf.get(name[0], desc=None, suffix='T1w', **specs)
if get_T2w and (t2w := tf.get(name[0], desc=None, suffix='T2w', **specs)):
files['t2w'] = t2w
files['mask'] = tf.get(name[0], desc='brain', suffix='mask', **specs) or tf.get(
name[0], label='brain', suffix='mask', **specs
)
Expand Down
43 changes: 42 additions & 1 deletion smriprep/workflows/anatomical.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def init_anat_preproc_wf(
name: str = 'anat_preproc_wf',
skull_strip_fixed_seed: bool = False,
fs_no_resume: bool = False,
norm_add_T2w: bool = False,
):
"""
Stage the anatomical preprocessing steps of *sMRIPrep*.
Expand Down Expand Up @@ -190,6 +191,9 @@ def init_anat_preproc_wf(
EXPERT: Import pre-computed FreeSurfer reconstruction without resuming.
The user is responsible for ensuring that all necessary files are present.
(default: ``False``).
norm_add_T2w : :obj:`bool`
Use T2w as a moving image channel in the spatial normalization to template
space(s).

Inputs
------
Expand Down Expand Up @@ -283,6 +287,7 @@ def init_anat_preproc_wf(
omp_nthreads=omp_nthreads,
skull_strip_fixed_seed=skull_strip_fixed_seed,
fs_no_resume=fs_no_resume,
norm_add_T2w=norm_add_T2w,
)
template_iterator_wf = init_template_iterator_wf(spaces=spaces, sloppy=sloppy)
ds_std_volumes_wf = init_ds_anat_volumes_wf(
Expand Down Expand Up @@ -461,6 +466,7 @@ def init_anat_fit_wf(
name='anat_fit_wf',
skull_strip_fixed_seed: bool = False,
fs_no_resume: bool = False,
norm_add_T2w: bool = False,
):
"""
Stage the anatomical preprocessing steps of *sMRIPrep*.
Expand Down Expand Up @@ -541,6 +547,9 @@ def init_anat_fit_wf(
Do not use a random seed for skull-stripping - will ensure
run-to-run replicability when used with --omp-nthreads 1
(default: ``False``).
norm_add_T2w : :obj:`bool`
Use T2w as a moving image channel in the spatial normalization to template
space(s).

Inputs
------
Expand Down Expand Up @@ -994,6 +1003,7 @@ def init_anat_fit_wf(
sloppy=sloppy,
omp_nthreads=omp_nthreads,
templates=templates,
use_T2w=norm_add_T2w and t2w,
)
ds_template_registration_wf = init_ds_template_registration_wf(
output_dir=output_dir, image_type='T1w'
Expand All @@ -1002,7 +1012,6 @@ def init_anat_fit_wf(
# fmt:off
workflow.connect([
(inputnode, register_template_wf, [('roi', 'inputnode.lesion_mask')]),
(t1w_buffer, register_template_wf, [('t1w_preproc', 'inputnode.moving_image')]),
(refined_buffer, register_template_wf, [('t1w_mask', 'inputnode.moving_mask')]),
(sourcefile_buffer, ds_template_registration_wf, [
('source_files', 'inputnode.source_files')
Expand Down Expand Up @@ -1137,6 +1146,12 @@ def init_anat_fit_wf(
image_type='T2w',
name='t2w_template_wf',
)
register_template_wf = init_register_template_wf(
sloppy=sloppy,
omp_nthreads=omp_nthreads,
templates=templates,
use_T2w=norm_add_T2w and t2w,
)
bbreg = pe.Node(
fs.BBRegister(
contrast_type='t2',
Expand Down Expand Up @@ -1166,6 +1181,8 @@ def init_anat_fit_wf(
)
ds_t2w_preproc.inputs.SkullStripped = False

merge_t2w = pe.Node(niu.Merge(2), name='merge_t2w', run_without_submitting=True)

workflow.connect([
(inputnode, t2w_template_wf, [('t2w', 'inputnode.anat_files')]),
(t2w_template_wf, bbreg, [('outputnode.anat_ref', 'source_file')]),
Expand All @@ -1182,10 +1199,34 @@ def init_anat_fit_wf(
(inputnode, ds_t2w_preproc, [('t2w', 'source_file')]),
(t2w_resample, ds_t2w_preproc, [('output_image', 'in_file')]),
(ds_t2w_preproc, outputnode, [('out_file', 't2w_preproc')]),
(t1w_buffer, merge_t2w, [('t1w_preproc', 'in1')]),
(t2w_resample, merge_t2w, [('output_image', 'in2')]),
(merge_t2w, register_template_wf, [('out', 'inputnode.moving_image')]),
]) # fmt:skip
elif not t2w:
register_template_wf = init_register_template_wf(
sloppy=sloppy,
omp_nthreads=omp_nthreads,
templates=templates,
use_T2w=norm_add_T2w and t2w,
)
workflow.connect([
(t1w_buffer, register_template_wf, [('t1w_preproc', 'inputnode.moving_image')]),
])
LOGGER.info('ANAT No T2w images provided - skipping Stage 7')
else:
register_template_wf = init_register_template_wf(
sloppy=sloppy,
omp_nthreads=omp_nthreads,
templates=templates,
use_T2w=norm_add_T2w and t2w,
)
merge_t2w = pe.Node(niu.Merge(2), name='merge_t2w', run_without_submitting=True)
workflow.connect([
(t1w_buffer, merge_t2w, [('t1w_preproc', 'in1')]),
(inputnode, merge_t2w, [('t2w', 'in2')]),
(merge_t2w, register_template_wf, [('out', 'inputnode.moving_image')]),
]) # fmt:skip
LOGGER.info('ANAT Found preprocessed T2w - skipping Stage 7')

# Stages 8-10: Surface conversion and registration
Expand Down
10 changes: 10 additions & 0 deletions smriprep/workflows/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def init_smriprep_wf(
work_dir,
bids_filters,
cifti_output,
standardize_with_T2w
):
"""
Create the execution graph of *sMRIPrep*, with a sub-workflow for each subject.
Expand Down Expand Up @@ -156,6 +157,9 @@ def init_smriprep_wf(
bids_filters : dict
Provides finer specification of the pipeline input files through pybids entities filters.
A dict with the following structure {<suffix>:{<entity>:<filter>,...},...}
standardize_with_T2w : :obj:`bool`
Use T2w as a moving image channel in the spatial normalization to template
space(s).

"""
smriprep_wf = Workflow(name='smriprep_wf')
Expand Down Expand Up @@ -196,6 +200,7 @@ def init_smriprep_wf(
subject_id=subject_id,
bids_filters=bids_filters,
cifti_output=cifti_output,
standardize_with_T2w=standardize_with_T2w,
)

single_subject_wf.config['execution']['crashdump_dir'] = os.path.join(
Expand Down Expand Up @@ -233,6 +238,7 @@ def init_single_subject_wf(
subject_id,
bids_filters,
cifti_output,
standardize_with_T2w,
):
"""
Create a single subject workflow.
Expand Down Expand Up @@ -324,6 +330,9 @@ def init_single_subject_wf(
bids_filters : dict
Provides finer specification of the pipeline input files through pybids entities filters.
A dict with the following structure {<suffix>:{<entity>:<filter>,...},...}
standardize_with_T2w : :obj:`bool`
Use T2w as a moving image channel in the spatial normalization to template
space(s).

Inputs
------
Expand Down Expand Up @@ -441,6 +450,7 @@ def init_single_subject_wf(
skull_strip_template=skull_strip_template,
spaces=spaces,
cifti_output=cifti_output,
norm_add_T2w=standardize_with_T2w,
)

# fmt:off
Expand Down
26 changes: 23 additions & 3 deletions smriprep/workflows/fit/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def init_register_template_wf(
sloppy,
omp_nthreads,
templates,
use_T2w=False,
name='register_template_wf',
):
"""
Expand Down Expand Up @@ -171,8 +172,9 @@ def init_register_template_wf(
)

# With the improvements from nipreps/niworkflows#342 this truncation is now necessary
trunc_mov = pe.Node(
ants.ImageMath(operation='TruncateImageIntensity', op2='0.01 0.999 256'),
trunc_mov = pe.MapNode(
ants.ImageMath(operation='TruncateImageIntensity', op2='0.01 0.999 255'),
iterfield='op1',
name='trunc_mov',
)

Expand All @@ -192,10 +194,19 @@ def init_register_template_wf(
run_without_submitting=True,
)

include_t2w = pe.Node(
niu.Function(function=_include_t2w, output_names=['moving_image', 'get_T2w']),
name='include_t2w',
run_without_submitting=True,
)
include_t2w.inputs.use_T2w = use_T2w

# fmt:off
workflow.connect([
(inputnode, split_desc, [('template', 'template')]),
(inputnode, trunc_mov, [('moving_image', 'op1')]),
(inputnode, include_t2w, [('moving_image', 'moving_image')]),
(include_t2w, tf_select, [('get_T2w', 'get_T2w')]),
(include_t2w, trunc_mov, [('moving_image', 'op1')]),
(inputnode, registration, [
('moving_mask', 'moving_mask'),
('lesion_mask', 'lesion_mask')]),
Expand Down Expand Up @@ -243,3 +254,12 @@ def _fmt_cohort(template, spec):
if cohort is not None:
template = f'{template}:cohort-{cohort}'
return template, spec


def _include_t2w(moving_image, use_T2w=False):
islist = isinstance(moving_image, list)
if not use_T2w:
return moving_image[0] if islist else moving_image, False

return moving_image, islist and len(moving_image) > 1