diff --git a/pyproject.toml b/pyproject.toml index b7a65fe49..d014d254c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -285,3 +285,4 @@ scil_volume_reshape = "scilpy.cli.scil_volume_reshape:main" scil_volume_reslice_to_reference = "scilpy.cli.scil_volume_reslice_to_reference:main" scil_volume_stats_in_labels = "scilpy.cli.scil_volume_stats_in_labels:main" scil_volume_stats_in_ROI = "scilpy.cli.scil_volume_stats_in_ROI:main" +scil_volume_validate_correct_strides = "scilpy.cli.scil_volume_validate_correct_strides:main" diff --git a/src/scilpy/cli/scil_gradients_modify_axes.py b/src/scilpy/cli/scil_gradients_modify_axes.py index a34acf715..8ccc93b2d 100755 --- a/src/scilpy/cli/scil_gradients_modify_axes.py +++ b/src/scilpy/cli/scil_gradients_modify_axes.py @@ -12,7 +12,8 @@ import numpy as np -from scilpy.gradients.bvec_bval_tools import (flip_gradient_sampling, +from scilpy.gradients.bvec_bval_tools import (find_flip_swap_from_order, + flip_gradient_axis, swap_gradient_axis) from scilpy.io.utils import (add_overwrite_arg, assert_inputs_exist, assert_outputs_exist, add_verbose_arg) @@ -70,17 +71,7 @@ def main(): "input's. We do not support conversion in this script.") # Format final order - # Our scripts use axes as 0, 1, 2 rather than 1, 2, 3: adding -1. - axes_to_flip = [] - swapped_order = [] - for next_axis in args.final_order: - if next_axis in [1, 2, 3]: - swapped_order.append(next_axis - 1) - elif next_axis in [-1, -2, -3]: - axes_to_flip.append(abs(next_axis) - 1) - swapped_order.append(abs(next_axis) - 1) - else: - parser.error("Sorry, final order not understood.") + axes_to_flip, swapped_order = find_flip_swap_from_order(args.final_order) # Verifying that user did not ask for, ex, -xxy if len(np.unique(swapped_order)) != 3: @@ -95,7 +86,7 @@ def main(): parser.error("b-vectors format for a .b file should be FSL, " "and contain 3 lines (x, y, z), but got {}" .format(bvecs.shape[0])) - bvecs = flip_gradient_sampling(bvecs, axes_to_flip, 'fsl') + bvecs = flip_gradient_axis(bvecs, axes_to_flip, 'fsl') bvecs = swap_gradient_axis(bvecs, swapped_order, 'fsl') np.savetxt(args.out_gradient_sampling_file, bvecs, "%.8f") else: # ext == '.b': @@ -104,7 +95,7 @@ def main(): parser.error("b-vectors format for a .b file should be mrtrix, " "and contain 4 columns (x, y, z, bval), but got {}" .format(bvecs.shape[1])) - bvecs = flip_gradient_sampling(bvecs, axes_to_flip, 'mrtrix') + bvecs = flip_gradient_axis(bvecs, axes_to_flip, 'mrtrix') bvecs = swap_gradient_axis(bvecs, swapped_order, 'mrtrix') np.savetxt(args.out_gradient_sampling_file, bvecs, "%.8f %.8f %.8f %0.6f") diff --git a/src/scilpy/cli/scil_volume_validate_correct_strides.py b/src/scilpy/cli/scil_volume_validate_correct_strides.py new file mode 100644 index 000000000..5efd2469d --- /dev/null +++ b/src/scilpy/cli/scil_volume_validate_correct_strides.py @@ -0,0 +1,185 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Detect when data strides are different from [1, 2, 3] and correct them. +The script takes as input a nifti file and outputs a nifti file with the +corrected strides if needed. + +Input file can be 3D or 4D. Only the first 3 dimensions are considered +for the stride correction. In the case of DWI data, we recommand to also input +the b-values and b-vectors files to correct the b-vectors accordingly. If the +--validate_bvecs is set, the script first detects sign flips and/or axes swaps +in the b-vectors from a fiber coherence index [1] and corrects the b-vectors. +Then, the b-vectors are permuted and sign flipped to match the new strides. + +A typical pipeline could be: +>>> scil_volume_validate_correct_strides t1.nii.gz t1_restride.nii.gz +>>> scil_volume_validate_correct_strides dwi.nii.gz dwi_restride.nii.gz + --in_bvec dwi.bvec --out_bvec dwi_restride.bvec --validate_bvec + --in_bval dwi.bval + +------------------------------------------------------------------------------ +Reference: +[1] Schilling KG, Yeh FC, Nath V, Hansen C, Williams O, Resnick S, Anderson AW, + Landman BA. A fiber coherence index for quality control of B-table + orientation in diffusion MRI scans. Magn Reson Imaging. 2019 May;58:82-89. + doi: 10.1016/j.mri.2019.01.018. +------------------------------------------------------------------------------ +""" + +import argparse +import logging + +from dipy.core.gradients import gradient_table +from dipy.io.gradients import read_bvals_bvecs +from dipy.reconst.dti import TensorModel, fractional_anisotropy +import numpy as np +import nibabel as nib + +from scilpy.gradients.bvec_bval_tools import (check_b0_threshold, + find_flip_swap_from_order, + flip_gradient_axis, + is_normalized_bvecs, + normalize_bvecs, + swap_gradient_axis) +from scilpy.image.utils import verify_strides, find_strides_transform +from scilpy.io.utils import (add_b0_thresh_arg, add_overwrite_arg, + add_skip_b0_check_arg, assert_inputs_exist, + assert_outputs_exist, add_verbose_arg) +from scilpy.reconst.fiber_coherence import compute_coherence_table_for_transforms +from scilpy.version import version_string + + +def _build_arg_parser(): + p = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawTextHelpFormatter, + epilog=version_string) + + p.add_argument('in_data', + help='Path to input nifti file.') + p.add_argument('out_data', + help='Path to output nifti file with corrected strides.') + + p.add_argument('--in_bvec', + help='Path to bvec file (FSL format). If provided, the ' + 'bvecs will \nbe permuted and sign flipped to match ' + 'the new strides.') + p.add_argument('--out_bvec', + help='Path to output bvec file (FSL format). Must be ' + 'provided if --in_bvec is used.') + p.add_argument('--validate_bvec', action='store_true', + help='If set, the script first detects sign flips and/or ' + 'axes swaps \nin the b-vectors from a fiber coherence ' + 'index [1] and corrects \nthe b-vectors before ' + 'permuting/sign flipping them to match the new ' + 'strides. \nIf not set, the b-vectors are only ' + 'permuted and sign flipped to match the new strides.') + p.add_argument('--in_bval', + help='Path to bval file. Must be provided if ' + '--validate_bvecs is used.') + + add_b0_thresh_arg(p) + add_skip_b0_check_arg(p, will_overwrite_with_min=True) + add_verbose_arg(p) + add_overwrite_arg(p) + return p + + +def main(): + parser = _build_arg_parser() + args = parser.parse_args() + logging.getLogger().setLevel(logging.getLevelName(args.verbose)) + + assert_inputs_exist(parser, [args.in_data], + optional=[args.in_bvec, args.in_bval]) + assert_outputs_exist(parser, args, args.out_data, optional=args.out_bvec) + + if args.in_bvec and not args.out_bvec: + parser.error('--out_bvec must be provided if --in_bvec is used.') + if args.validate_bvec and (not args.in_bvec or not args.in_bval): + parser.error('--in_bvec and --in_bval must be provided if ' + '--validate_bvecs is set.') + + # Get the current strides + img = nib.load(args.in_data) + strides, is_stride_correct = verify_strides(img) + if not is_stride_correct: + # Compute the required transform to get to [1, 2, 3] + transform = find_strides_transform(strides) + + # Write the transform in a format compatible with the + # flip_gradient_axis and swap_gradient_axis functions (for bvecs) + axes_to_flip, swapped_order = find_flip_swap_from_order(transform) + + # Write the transform in a format compatible with the nibabel + # as_reoriented function (for image) + ornt = np.column_stack((np.array(swapped_order, dtype=np.int8), + np.where(np.isin(range(len(strides)), + axes_to_flip), + -1, 1))) + # Apply the transform to the image and save it + new_img = img.as_reoriented(ornt) + nib.save(new_img, args.out_data) + + if args.validate_bvec: + logging.info('Validating b-vectors from fiber coherence index...') + # Load and validate the data and bvals/bvecs + data = img.get_fdata().astype(np.float32) + if len(data.shape) != 4: + parser.error('Input data must be DWI (4D) when --validate_bvec ' + 'is set.') + bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec) + if not is_normalized_bvecs(bvecs): + logging.warning('Your b-vectors do not seem normalized...') + bvecs = normalize_bvecs(bvecs) + args.b0_threshold = check_b0_threshold(bvals.min(), + b0_thr=args.b0_threshold, + skip_b0_check=args.skip_b0_check) + gtab = gradient_table(bvals, bvecs=bvecs, + b0_threshold=args.b0_threshold) + tenmodel = TensorModel(gtab, fit_method='WLS', + min_signal=np.min(data[data > 0])) + # Generate a mask to avoid fitting tensor on the whole image + mask = np.zeros(data.shape[:3], dtype=bool) + # Use a small cubic ROI at the center of the volume + interval_i = slice(data.shape[0]//2 - data.shape[0]//4, + data.shape[0]//2 + data.shape[0]//4) + interval_j = slice(data.shape[1]//2 - data.shape[1]//4, + data.shape[1]//2 + data.shape[1]//4) + interval_k = slice(data.shape[2]//2 - data.shape[2]//4, + data.shape[2]//2 + data.shape[2]//4) + mask[interval_i, interval_j, interval_k] = 1 + # Compute the necessary DTI metrics to compute the coherence of bvecs + tenfit = tenmodel.fit(data, mask=mask) + fa = fractional_anisotropy(tenfit.evals) + evecs = tenfit.evecs.astype(np.float32)[..., 0] + evecs[fa < 0.2] = 0 + coherence, transform = compute_coherence_table_for_transforms(evecs, + fa) + # Find the best transform and apply it to the bvecs if needed + best_t = transform[np.argmax(coherence)] + if (best_t == np.eye(3)).all(): + logging.info('The b-vectors are aligned with the original data.') + valid_bvecs = bvecs + else: + logging.warning('Applying correction to b-vectors.') + logging.info('Transform is: \n{0}.'.format(best_t)) + valid_bvecs = np.dot(bvecs, best_t) + # If the data strides were correct, save the bvecs now + if is_stride_correct: + np.savetxt(args.out_bvec, valid_bvecs.T, "%.8f") + + # Apply the permutation and sign flip to the bvecs and save them + if args.in_bvec and not is_stride_correct: + if not args.validate_bvec: + _, bvecs = read_bvals_bvecs(None, args.in_bvec) + else: + bvecs = valid_bvecs + flipped_bvecs = flip_gradient_axis(bvecs.T, axes_to_flip, 'fsl') + swapped_flipped_bvecs = swap_gradient_axis(flipped_bvecs, + swapped_order, 'fsl') + np.savetxt(args.out_bvec, swapped_flipped_bvecs, "%.8f") + + +if __name__ == "__main__": + main() diff --git a/src/scilpy/cli/tests/test_volume_validate_correct_strides.py b/src/scilpy/cli/tests/test_volume_validate_correct_strides.py new file mode 100644 index 000000000..709a95700 --- /dev/null +++ b/src/scilpy/cli/tests/test_volume_validate_correct_strides.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import os +import tempfile + +from scilpy import SCILPY_HOME +from scilpy.io.fetcher import fetch_data, get_testing_files_dict + +# If they already exist, this only takes 5 seconds (check md5sum) +fetch_data(get_testing_files_dict(), keys=['processing.zip']) +tmp_dir = tempfile.TemporaryDirectory() + + +def test_help_option(script_runner): + ret = script_runner.run(['scil_volume_validate_correct_strides', '--help']) + assert ret.success + + +def test_execution_processing_no_restride(script_runner, monkeypatch): + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + in_dwi = os.path.join(SCILPY_HOME, 'processing', + 'dwi_crop_1000.nii.gz') + + ret = script_runner.run(['scil_volume_validate_correct_strides', in_dwi, + 'dwi_restride.nii.gz', '-f']) + assert ret.success + + +def test_execution_processing_restride(script_runner, monkeypatch): + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + in_dwi = os.path.join(SCILPY_HOME, 'processing', + 'dwi_crop_1000_bad_strides.nii.gz') + + ret = script_runner.run(['scil_volume_validate_correct_strides', in_dwi, + 'dwi_restride.nii.gz', '-f']) + assert ret.success + + +def test_execution_processing_bvecs(script_runner, monkeypatch): + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + in_dwi = os.path.join(SCILPY_HOME, 'processing', + 'dwi_crop_1000_bad_strides.nii.gz') + in_bvec = os.path.join(SCILPY_HOME, 'processing', + '1000.bvec') + + ret = script_runner.run(['scil_volume_validate_correct_strides', in_dwi, + 'dwi_restride.nii.gz', '--in_bvec', in_bvec, + '--out_bvec', 'dwi_restride.bvec', '-f']) + assert ret.success + + +def test_execution_processing_validate_bvecs_v1(script_runner, monkeypatch): + # Validate with good data strides and bad bvecs + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + in_dwi = os.path.join(SCILPY_HOME, 'processing', + 'dwi_crop_1000.nii.gz') + in_bvec = os.path.join(SCILPY_HOME, 'processing', + 'dwi_crop_1000_bad_strides.bvec') + in_bval = os.path.join(SCILPY_HOME, 'processing', + '1000.bval') + + ret = script_runner.run(['scil_volume_validate_correct_strides', in_dwi, + 'dwi_restride.nii.gz', '--in_bvec', in_bvec, + '--out_bvec', 'dwi_restride.bvec', + '--validate_bvec', '--in_bval', in_bval, '-f']) + assert ret.success + + +def test_execution_processing_validate_bvecs_v2(script_runner, monkeypatch): + # Validate with bad data strides + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + in_dwi = os.path.join(SCILPY_HOME, 'processing', + 'dwi_crop_1000_bad_strides.nii.gz') + in_bvec = os.path.join(SCILPY_HOME, 'processing', + '1000.bvec') + in_bval = os.path.join(SCILPY_HOME, 'processing', + '1000.bval') + + ret = script_runner.run(['scil_volume_validate_correct_strides', in_dwi, + 'dwi_restride.nii.gz', '--in_bvec', in_bvec, + '--out_bvec', 'dwi_restride.bvec', + '--validate_bvec', '--in_bval', in_bval, '-f']) + assert ret.success + + +def test_execution_processing_validate_bvecs_v3(script_runner, monkeypatch): + # Validate with non-DWI data and bad bvecs + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + in_dwi = os.path.join(SCILPY_HOME, 'processing', + 'nufo.nii.gz') + in_bvec = os.path.join(SCILPY_HOME, 'processing', + 'dwi_crop_1000_bad_strides.bvec') + in_bval = os.path.join(SCILPY_HOME, 'processing', + '1000.bval') + + ret = script_runner.run(['scil_volume_validate_correct_strides', in_dwi, + 'dwi_restride.nii.gz', '--in_bvec', in_bvec, + '--out_bvec', 'dwi_restride.bvec', + '--validate_bvec', '--in_bval', in_bval, '-f']) + assert not ret.success + + +def test_execution_processing_no_out_bvec(script_runner, monkeypatch): + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + in_dwi = os.path.join(SCILPY_HOME, 'processing', + 'dwi_crop_1000_bad_strides.nii.gz') + in_bvec = os.path.join(SCILPY_HOME, 'processing', + '1000.bvec') + + ret = script_runner.run(['scil_volume_validate_correct_strides', in_dwi, + 'dwi_restride.nii.gz', '--in_bvec', in_bvec, + '-f']) + assert not ret.success + + +def test_execution_processing_no_in_bval(script_runner, monkeypatch): + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + in_dwi = os.path.join(SCILPY_HOME, 'processing', + 'dwi_crop_1000_bad_strides.nii.gz') + in_bvec = os.path.join(SCILPY_HOME, 'processing', + '1000.bvec') + + ret = script_runner.run(['scil_volume_validate_correct_strides', in_dwi, + 'dwi_restride.nii.gz', '--in_bvec', in_bvec, + '--out_bvec', 'dwi_restride.bvec', + '--validate_bvec', '-f']) + assert not ret.success diff --git a/src/scilpy/gradients/bvec_bval_tools.py b/src/scilpy/gradients/bvec_bval_tools.py index ba99c3e3a..ec8014d94 100644 --- a/src/scilpy/gradients/bvec_bval_tools.py +++ b/src/scilpy/gradients/bvec_bval_tools.py @@ -228,7 +228,39 @@ def str_to_axis_index(axis): return None -def flip_gradient_sampling(bvecs, axes, sampling_type): +def find_flip_swap_from_order(order): + """ + Find the flip and swap necessary to get the bvecs to the given order. This + assumes the original order of the bvecs is 1,2,3. + + Parameters + ---------- + order: list of int + List of axes to flip and swap. + Ex: to only flip y: [1, -2, 3] + Ex: to only swap x and y: [2, 1, 3] + Ex: to first flip x, then permute all three axes: [3, -1, 2] + + Returns + ------- + axes_to_flip: list of int + List of axes to flip (e.g. [0, 1]). + swapped_order: list of int + List of axes in the given order (e.g. [1, 0, 2]). + """ + # Our scripts use axes as 0, 1, 2 rather than 1, 2, 3: adding -1. + axes_to_flip = [] + swapped_order = [] + for next_axis in order: + if next_axis in [1, 2, 3]: + swapped_order.append(next_axis - 1) + elif next_axis in [-1, -2, -3]: + axes_to_flip.append(abs(next_axis) - 1) + swapped_order.append(abs(next_axis) - 1) + return(axes_to_flip, swapped_order) + + +def flip_gradient_axis(bvecs, axes, sampling_type): """ Flip bvecs on chosen axis. diff --git a/src/scilpy/gradients/tests/test_bvec_bval_tools.py b/src/scilpy/gradients/tests/test_bvec_bval_tools.py index 4085dd5d3..f119411be 100644 --- a/src/scilpy/gradients/tests/test_bvec_bval_tools.py +++ b/src/scilpy/gradients/tests/test_bvec_bval_tools.py @@ -3,8 +3,8 @@ from scilpy.gradients.bvec_bval_tools import ( check_b0_threshold, identify_shells, is_normalized_bvecs, - flip_gradient_sampling, normalize_bvecs, round_bvals_to_shell, - str_to_axis_index, swap_gradient_axis) + flip_gradient_axis, find_flip_swap_from_order, normalize_bvecs, + round_bvals_to_shell, str_to_axis_index, swap_gradient_axis) bvecs = np.asarray([[1.0, 1.0, 1.0], [1.0, 0.0, 1.0], @@ -79,9 +79,16 @@ def test_str_to_axis_index(): assert str_to_axis_index('v') is None -def test_flip_gradient_sampling(): +def test_find_flip_swap_from_order(): + order = [1, -3, -2] + axes_to_flip, swapped_order = find_flip_swap_from_order(order) + assert np.array_equal(axes_to_flip, [2, 1]) + assert np.array_equal(swapped_order, [0, 2, 1]) + + +def test_flip_gradient_axis(): fsl_bvecs = bvecs.T - b = flip_gradient_sampling(fsl_bvecs, axes=[0], sampling_type='fsl') + b = flip_gradient_axis(fsl_bvecs, axes=[0], sampling_type='fsl') assert np.array_equal(b, np.asarray([[-1.0, 1.0, 1.0], [-1.0, 0.0, 1.0], [-0.0, 1.0, 0.0], diff --git a/src/scilpy/image/utils.py b/src/scilpy/image/utils.py index d05a87280..3400fb57a 100644 --- a/src/scilpy/image/utils.py +++ b/src/scilpy/image/utils.py @@ -74,6 +74,60 @@ def extract_affine(input_files): return vol.affine +def verify_strides(vol_img): + """Verify if the strides of the given volume are [1, 2, 3]. + + Parameters + ---------- + vol_img : nib.Nifti1Image + Volume image. + + Returns + ------- + strides : np.array + Current strides of the volume. + is_stride_correct : bool + True if the strides are [1, 2, 3], false otherwise. + """ + strides = nib.io_orientation(vol_img.affine).astype(np.int8) + strides = (strides[:, 0] + 1) * strides[:, 1] + # Check if the strides are correct ([1, 2, 3]) + if np.array_equal(strides, [1, 2, 3]): + is_stride_correct = True + logging.warning('Input data already has the correct strides [1, 2, 3].' + ' No correction on data needed and outputed.') + else: + is_stride_correct = False + logging.warning('Input data has strides {}. ' + 'Correcting to [1, 2, 3].'.format(strides)) + return strides, is_stride_correct + + +def find_strides_transform(strides): + """Find the transform required to get to [1, 2, 3] from the current + strides. + + Parameters + ---------- + strides : np.array + Current strides of the volume. + + Returns + ------- + transform : list of int + Transform to apply to get to [1, 2, 3]. + """ + n = len(strides) + transform = [0]*n + for i, m in enumerate(strides): + # Get the axis (0, 1, 2) and the sign of the current stride + axis = abs(m) - 1 + sign = 1 if m > 0 else -1 + # Set the transform for this axis + transform[axis] = sign * (i + 1) + return transform + + def check_slice_indices(vol_img, axis_name, slice_ids): """Check that the given volume can be sliced at the given slice indices along the requested axis. diff --git a/src/scilpy/io/fetcher.py b/src/scilpy/io/fetcher.py index 87c62645b..48112872f 100644 --- a/src/scilpy/io/fetcher.py +++ b/src/scilpy/io/fetcher.py @@ -56,7 +56,7 @@ def get_testing_files_dict(): "plot.zip": "a1dc54cad7e1d17e55228c2518a1b34e", "others.zip": "82248b4888a63b0aeffc8070cc206995", "fodf_filtering.zip": "5985c0644321ecf81fd694fb91e2c898", - "processing.zip": "1ba6869c9d8b58a9b911ba71fdd50a07", + "processing.zip": "0417df00d97272f5887c31acb8948604", "surface_vtk_fib.zip": "241f3afd6344c967d7176b43e4a99a41", "tractograms.zip": "964113f307213523d784b3dbf3a5117a", "mrds.zip": "5abe6092400e11e9bb2423e2c387e774",