diff --git a/src/scilpy/cli/scil_dti_metrics.py b/src/scilpy/cli/scil_dti_metrics.py index 551cbb373..34acab576 100755 --- a/src/scilpy/cli/scil_dti_metrics.py +++ b/src/scilpy/cli/scil_dti_metrics.py @@ -30,7 +30,6 @@ from dipy.core.gradients import gradient_table import dipy.denoise.noise_estimate as ne -from dipy.io.gradients import read_bvals_bvecs from dipy.reconst.dti import (TensorModel, color_fa, fractional_anisotropy, geodesic_anisotropy, mean_diffusivity, axial_diffusivity, norm, @@ -41,6 +40,7 @@ from scilpy.dwi.operations import compute_residuals, \ compute_residuals_statistics from scilpy.io.image import get_data_as_mask +from scilpy.io.stateful_image import StatefulImage from scilpy.io.utils import (add_b0_thresh_arg, add_overwrite_arg, add_skip_b0_check_arg, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, @@ -184,19 +184,29 @@ def main(): assert_headers_compatible(parser, args.in_dwi, args.mask) # Loading - img = nib.load(args.in_dwi) - data = img.get_fdata(dtype=np.float32) - affine = img.affine - mask = get_data_as_mask(nib.load(args.mask), - dtype=bool) if args.mask else None + simg = StatefulImage.load(args.in_dwi) + simg.load_gradients(args.in_bval, args.in_bvec) - logging.info('Tensor estimation with the {} method...'.format(args.method)) - bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec) + # Reorient to RAS for DIPY + simg.to_ras() + + data = simg.get_fdata(dtype=np.float32) + affine = simg.affine + bvals = simg.bvals + bvecs = simg.bvecs if not is_normalized_bvecs(bvecs): - logging.warning('Your b-vectors do not seem normalized...') + logger.warning('Your b-vectors do not seem normalized...') bvecs = normalize_bvecs(bvecs) + mask = None + if args.mask: + mask_simg = StatefulImage.load(args.mask) + mask_simg.reorient(simg.axcodes) + mask = get_data_as_mask(mask_simg, dtype=bool) + + logging.info('Tensor estimation with the {} method...'.format(args.method)) + # How the b0_threshold is used: gtab.b0s_mask is used # 1) In 
TensorModel in Dipy: # - The S0 images used as any other image in the design matrix and in @@ -231,7 +241,8 @@ def main(): fiber_tensors = nib.Nifti1Image( tensor_vals_reordered.astype(np.float32), affine) - nib.save(fiber_tensors, args.tensor) + # Use StatefulImage.create_from to ensure original orientation + StatefulImage.create_from(fiber_tensors, simg).save(args.tensor) del tensor_vals, fiber_tensors, tensor_vals_reordered @@ -240,29 +251,34 @@ def main(): FA[np.isnan(FA)] = 0 FA = np.clip(FA, 0, 1) if args.fa: - nib.save(nib.Nifti1Image(FA.astype(np.float32), affine), args.fa) + fa_img = nib.Nifti1Image(FA.astype(np.float32), affine) + StatefulImage.create_from(fa_img, simg).save(args.fa) if args.rgb: RGB = color_fa(FA, tenfit.evecs) - nib.save(nib.Nifti1Image(np.array(255 * RGB, 'uint8'), affine), - args.rgb) + rgb_img = nib.Nifti1Image(np.array(255 * RGB, 'uint8'), affine) + StatefulImage.create_from(rgb_img, simg).save(args.rgb) if args.ga: GA = geodesic_anisotropy(tenfit.evals) GA[np.isnan(GA)] = 0 - nib.save(nib.Nifti1Image(GA.astype(np.float32), affine), args.ga) + ga_img = nib.Nifti1Image(GA.astype(np.float32), affine) + StatefulImage.create_from(ga_img, simg).save(args.ga) if args.md: MD = mean_diffusivity(tenfit.evals) - nib.save(nib.Nifti1Image(MD.astype(np.float32), affine), args.md) + md_img = nib.Nifti1Image(MD.astype(np.float32), affine) + StatefulImage.create_from(md_img, simg).save(args.md) if args.ad: AD = axial_diffusivity(tenfit.evals) - nib.save(nib.Nifti1Image(AD.astype(np.float32), affine), args.ad) + ad_img = nib.Nifti1Image(AD.astype(np.float32), affine) + StatefulImage.create_from(ad_img, simg).save(args.ad) if args.rd: RD = radial_diffusivity(tenfit.evals) - nib.save(nib.Nifti1Image(RD.astype(np.float32), affine), args.rd) + rd_img = nib.Nifti1Image(RD.astype(np.float32), affine) + StatefulImage.create_from(rd_img, simg).save(args.rd) if args.mode: # Compute tensor mode @@ -271,31 +287,37 @@ def main(): # Since the mode computation 
can generate NANs when not masked, # we need to remove them. non_nan_indices = np.isfinite(inter_mode) - mode = np.zeros(inter_mode.shape) - mode[non_nan_indices] = inter_mode[non_nan_indices] - nib.save(nib.Nifti1Image(mode.astype(np.float32), affine), args.mode) + mode_data = np.zeros(inter_mode.shape) + mode_data[non_nan_indices] = inter_mode[non_nan_indices] + mode_img = nib.Nifti1Image(mode_data.astype(np.float32), affine) + StatefulImage.create_from(mode_img, simg).save(args.mode) if args.norm: NORM = norm(tenfit.quadratic_form) - nib.save(nib.Nifti1Image(NORM.astype(np.float32), affine), args.norm) + norm_img = nib.Nifti1Image(NORM.astype(np.float32), affine) + StatefulImage.create_from(norm_img, simg).save(args.norm) if args.evecs: - evecs = tenfit.evecs.astype(np.float32) - nib.save(nib.Nifti1Image(evecs, affine), args.evecs) + evecs_data = tenfit.evecs.astype(np.float32) + evecs_img = nib.Nifti1Image(evecs_data, affine) + StatefulImage.create_from(evecs_img, simg).save(args.evecs) # save individual e-vectors also for i in range(3): - nib.save(nib.Nifti1Image(evecs[..., i], affine), - add_filename_suffix(args.evecs, '_v'+str(i+1))) + ev_img = nib.Nifti1Image(evecs_data[..., i], affine) + StatefulImage.create_from(ev_img, simg).save( + add_filename_suffix(args.evecs, '_v'+str(i+1))) if args.evals: - evals = tenfit.evals.astype(np.float32) - nib.save(nib.Nifti1Image(evals, affine), args.evals) + evals_data = tenfit.evals.astype(np.float32) + evals_img = nib.Nifti1Image(evals_data, affine) + StatefulImage.create_from(evals_img, simg).save(args.evals) # save individual e-values also for i in range(3): - nib.save(nib.Nifti1Image(evals[..., i], affine), - add_filename_suffix(args.evals, '_e' + str(i+1))) + eval_img = nib.Nifti1Image(evals_data[..., i], affine) + StatefulImage.create_from(eval_img, simg).save( + add_filename_suffix(args.evals, '_e' + str(i+1))) if args.p_i_signal: S0 = np.mean(data[..., gtab.b0s_mask], axis=-1, keepdims=True) @@ -305,8 +327,8 @@ 
def main(): if args.mask is not None: pis_mask *= mask - nib.save(nib.Nifti1Image(pis_mask.astype(np.int16), affine), - args.p_i_signal) + pis_img = nib.Nifti1Image(pis_mask.astype(np.int16), affine) + StatefulImage.create_from(pis_img, simg).save(args.p_i_signal) if args.pulsation: STD = np.std(data[..., ~gtab.b0s_mask], axis=-1) @@ -314,8 +336,9 @@ def main(): if args.mask is not None: STD *= mask - nib.save(nib.Nifti1Image(STD.astype(np.float32), affine), - add_filename_suffix(args.pulsation, '_std_dwi')) + std_img = nib.Nifti1Image(STD.astype(np.float32), affine) + StatefulImage.create_from(std_img, simg).save( + add_filename_suffix(args.pulsation, '_std_dwi')) if np.sum(gtab.b0s_mask) <= 1: logger.info('Not enough b=0 images to output standard ' @@ -330,8 +353,9 @@ def main(): if args.mask is not None: STD *= mask - nib.save(nib.Nifti1Image(STD.astype(np.float32), affine), - add_filename_suffix(args.pulsation, '_std_b0')) + std_b0_img = nib.Nifti1Image(STD.astype(np.float32), affine) + StatefulImage.create_from(std_b0_img, simg).save( + add_filename_suffix(args.pulsation, '_std_b0')) if args.residual: if mask is None: @@ -354,7 +378,8 @@ def main(): R, data_diff = compute_residuals( predicted_data=tenfit2_predict.astype(np.float32), real_data=data, b0s_mask=gtab.b0s_mask, mask=mask) - nib.save(nib.Nifti1Image(R.astype(np.float32), affine), args.residual) + res_img = nib.Nifti1Image(R.astype(np.float32), affine) + StatefulImage.create_from(res_img, simg).save(args.residual) # Each volume's residual statistics R_k, q1, q3, iqr, std = compute_residuals_statistics(data_diff) diff --git a/src/scilpy/cli/scil_fodf_metrics.py b/src/scilpy/cli/scil_fodf_metrics.py index 92df1b578..f7b69d5ee 100755 --- a/src/scilpy/cli/scil_fodf_metrics.py +++ b/src/scilpy/cli/scil_fodf_metrics.py @@ -40,6 +40,7 @@ from dipy.direction.peaks import reshape_peaks_for_visualization from scilpy.io.image import get_data_as_mask +from scilpy.io.stateful_image import StatefulImage from 
scilpy.io.utils import (add_overwrite_arg, add_sh_basis_args, add_processes_arg, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, @@ -138,11 +139,14 @@ def main(): assert_headers_compatible(parser, args.in_fODF, args.mask) # Loading - vol = nib.load(args.in_fODF) - data = vol.get_fdata(dtype=np.float32) - affine = vol.affine - mask = get_data_as_mask(nib.load(args.mask), - dtype=bool) if args.mask else None + simg = StatefulImage.load(args.in_fODF) + data = simg.get_fdata(dtype=np.float32) + affine = simg.affine + mask = None + if args.mask: + mask_simg = StatefulImage.load(args.mask) + mask_simg.reorient(simg.axcodes) + mask = get_data_as_mask(mask_simg, dtype=bool) sphere = get_sphere(name=args.sphere) sh_basis, is_legacy = parse_sh_basis_arg(args) @@ -168,26 +172,26 @@ def main(): # Save result if args.nufo: - nib.save(nib.Nifti1Image(nufo_map.astype(np.float32), affine), - args.nufo) + nufo_img = nib.Nifti1Image(nufo_map.astype(np.float32), affine) + StatefulImage.create_from(nufo_img, simg).save(args.nufo) if args.afd_max: - nib.save(nib.Nifti1Image(afd_max.astype(np.float32), affine), - args.afd_max) + afd_max_img = nib.Nifti1Image(afd_max.astype(np.float32), affine) + StatefulImage.create_from(afd_max_img, simg).save(args.afd_max) if args.afd_total: # this is the analytical afd total afd_tot = data[:, :, :, 0] - nib.save(nib.Nifti1Image(afd_tot.astype(np.float32), affine), - args.afd_total) + afd_tot_img = nib.Nifti1Image(afd_tot.astype(np.float32), affine) + StatefulImage.create_from(afd_tot_img, simg).save(args.afd_total) if args.afd_sum: - nib.save(nib.Nifti1Image(afd_sum.astype(np.float32), affine), - args.afd_sum) + afd_sum_img = nib.Nifti1Image(afd_sum.astype(np.float32), affine) + StatefulImage.create_from(afd_sum_img, simg).save(args.afd_sum) if args.rgb: - nib.save(nib.Nifti1Image(rgb_map.astype('uint8'), affine), - args.rgb) + rgb_img = nib.Nifti1Image(rgb_map.astype('uint8'), affine) + StatefulImage.create_from(rgb_img, 
simg).save(args.rgb) if args.peaks or args.peak_values: if not args.abs_peaks_and_values: @@ -196,15 +200,19 @@ def main(): where=peak_values[..., 0, None] != 0) peak_dirs[...] *= peak_values[..., :, None] if args.peaks: - nib.save(nib.Nifti1Image( + peaks_img = nib.Nifti1Image( reshape_peaks_for_visualization(peak_dirs), - affine), args.peaks) + affine) + StatefulImage.create_from(peaks_img, simg).save(args.peaks) if args.peak_values: - nib.save(nib.Nifti1Image(peak_values, vol.affine), - args.peak_values) + peak_vals_img = nib.Nifti1Image(peak_values, affine) + StatefulImage.create_from(peak_vals_img, simg).save( + args.peak_values) if args.peak_indices: - nib.save(nib.Nifti1Image(peak_indices, vol.affine), args.peak_indices) + peak_indices_img = nib.Nifti1Image(peak_indices, affine) + StatefulImage.create_from(peak_indices_img, simg).save( + args.peak_indices) if __name__ == "__main__": diff --git a/src/scilpy/cli/scil_fodf_msmt.py b/src/scilpy/cli/scil_fodf_msmt.py index cd782b331..9181ab239 100755 --- a/src/scilpy/cli/scil_fodf_msmt.py +++ b/src/scilpy/cli/scil_fodf_msmt.py @@ -22,15 +22,14 @@ from dipy.core.gradients import gradient_table, unique_bvals_tolerance from dipy.data import get_sphere -from dipy.io.gradients import read_bvals_bvecs from dipy.reconst.mcsd import MultiShellDeconvModel, multi_shell_fiber_response -import nibabel as nib import numpy as np from scilpy.gradients.bvec_bval_tools import (check_b0_threshold, normalize_bvecs, is_normalized_bvecs) from scilpy.io.image import get_data_as_mask +from scilpy.io.stateful_image import StatefulImage from scilpy.io.utils import (add_overwrite_arg, add_processes_arg, assert_inputs_exist, assert_outputs_exist, add_sh_basis_args, add_skip_b0_check_arg, @@ -132,9 +131,20 @@ def main(): wm_frf = np.loadtxt(args.in_wm_frf) gm_frf = np.loadtxt(args.in_gm_frf) csf_frf = np.loadtxt(args.in_csf_frf) - vol = nib.load(args.in_dwi) - data = vol.get_fdata(dtype=np.float32) - bvals, bvecs = 
read_bvals_bvecs(args.in_bval, args.in_bvec) + + simg = StatefulImage.load(args.in_dwi) + simg.load_gradients(args.in_bval, args.in_bvec) + + # Orientation standardization? + # Reconstruction logic (dipy/scilpy) often prefers a specific orientation or consistency. + # We reorient secondary inputs to match the primary one. + # If we want to be fully robust, we could force RAS here, but let's see. + # scil_frf_msmt used to_ras(), so let's be consistent. + simg.to_ras() + + data = simg.get_fdata(dtype=np.float32) + bvals = simg.bvals + bvecs = simg.bvecs # Checking data and sh_order wm_frf, gm_frf, csf_frf = verify_frf_files(wm_frf, gm_frf, csf_frf) @@ -142,8 +152,11 @@ def main(): sh_basis, is_legacy = parse_sh_basis_arg(args) # Checking mask - mask = get_data_as_mask(nib.load(args.mask), - dtype=bool) if args.mask else None + mask = None + if args.mask: + mask_simg = StatefulImage.load(args.mask) + mask_simg.reorient(simg.axcodes) + mask = get_data_as_mask(mask_simg, dtype=bool) # Checking bvals, bvecs values and loading gtab if not is_normalized_bvecs(bvecs): @@ -206,8 +219,8 @@ def main(): is_input_legacy=True, is_output_legacy=is_legacy, nbr_processes=args.nbr_processes) - nib.save(nib.Nifti1Image(wm_coeff.astype(np.float32), - vol.affine), args.wm_out_fODF) + res_simg = StatefulImage.from_data(wm_coeff.astype(np.float32), simg) + res_simg.save(args.wm_out_fODF) if args.gm_out_fODF: gm_coeff = shm_coeff[..., 1] @@ -218,8 +231,8 @@ def main(): is_input_legacy=True, is_output_legacy=is_legacy, nbr_processes=args.nbr_processes) - nib.save(nib.Nifti1Image(gm_coeff.astype(np.float32), - vol.affine), args.gm_out_fODF) + res_simg = StatefulImage.from_data(gm_coeff.astype(np.float32), simg) + res_simg.save(args.gm_out_fODF) if args.csf_out_fODF: csf_coeff = shm_coeff[..., 0] @@ -230,18 +243,18 @@ def main(): is_input_legacy=True, is_output_legacy=is_legacy, nbr_processes=args.nbr_processes) - nib.save(nib.Nifti1Image(csf_coeff.astype(np.float32), - vol.affine), 
args.csf_out_fODF) + res_simg = StatefulImage.from_data(csf_coeff.astype(np.float32), simg) + res_simg.save(args.csf_out_fODF) if args.vf: - nib.save(nib.Nifti1Image(vf.astype(np.float32), - vol.affine), args.vf) + res_simg = StatefulImage.from_data(vf.astype(np.float32), simg) + res_simg.save(args.vf) if args.vf_rgb: vf_rgb = vf / np.max(vf) * 255 vf_rgb = np.clip(vf_rgb, 0, 255) - nib.save(nib.Nifti1Image(vf_rgb.astype(np.uint8), - vol.affine), args.vf_rgb) + res_simg = StatefulImage.from_data(vf_rgb.astype(np.uint8), simg) + res_simg.save(args.vf_rgb) if __name__ == "__main__": diff --git a/src/scilpy/cli/scil_fodf_ssst.py b/src/scilpy/cli/scil_fodf_ssst.py index 73a420e08..dcc8cb473 100755 --- a/src/scilpy/cli/scil_fodf_ssst.py +++ b/src/scilpy/cli/scil_fodf_ssst.py @@ -12,7 +12,6 @@ from dipy.core.gradients import gradient_table from dipy.data import get_sphere -from dipy.io.gradients import read_bvals_bvecs from dipy.reconst.csdeconv import ConstrainedSphericalDeconvModel import nibabel as nib import numpy as np @@ -21,6 +20,7 @@ normalize_bvecs, is_normalized_bvecs) from scilpy.io.image import get_data_as_mask +from scilpy.io.stateful_image import StatefulImage from scilpy.io.utils import (add_b0_thresh_arg, add_overwrite_arg, add_processes_arg, add_sh_basis_args, add_skip_b0_check_arg, add_verbose_arg, @@ -77,13 +77,22 @@ def main(): # Loading data full_frf = np.loadtxt(args.frf_file) - vol = nib.load(args.in_dwi) - data = vol.get_fdata(dtype=np.float32) - bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec) + simg = StatefulImage.load(args.in_dwi) + simg.load_gradients(args.in_bval, args.in_bvec) + + # Reorient to RAS for DIPY + simg.to_ras() + + data = simg.get_fdata(dtype=np.float32) + bvals = simg.bvals + bvecs = simg.bvecs # Checking mask - mask = get_data_as_mask(nib.load(args.mask), - dtype=bool) if args.mask else None + mask = None + if args.mask: + mask_simg = StatefulImage.load(args.mask) + mask_simg.reorient(simg.axcodes) + mask = 
get_data_as_mask(mask_simg, dtype=bool) sh_order = args.sh_order sh_basis, is_legacy = parse_sh_basis_arg(args) @@ -134,9 +143,11 @@ def main(): is_input_legacy=True, is_output_legacy=is_legacy, nbr_processes=args.nbr_processes) - nib.save(nib.Nifti1Image(shm_coeff.astype(np.float32), - affine=vol.affine, - header=vol.header), args.out_fODF) + + fodf_img = nib.Nifti1Image(shm_coeff.astype(np.float32), + affine=simg.affine, + header=simg.header) + StatefulImage.create_from(fodf_img, simg).save(args.out_fODF) if __name__ == "__main__": diff --git a/src/scilpy/cli/scil_frf_msmt.py b/src/scilpy/cli/scil_frf_msmt.py index 0b4640cbc..afb518889 100755 --- a/src/scilpy/cli/scil_frf_msmt.py +++ b/src/scilpy/cli/scil_frf_msmt.py @@ -26,13 +26,12 @@ import logging from dipy.core.gradients import unique_bvals_tolerance -from dipy.io.gradients import read_bvals_bvecs -import nibabel as nib import numpy as np from scilpy.dwi.utils import extract_dwi_shell from scilpy.gradients.bvec_bval_tools import check_b0_threshold from scilpy.io.image import get_data_as_mask +from scilpy.io.stateful_image import StatefulImage from scilpy.io.utils import (add_overwrite_arg, add_precision_arg, add_skip_b0_check_arg, add_verbose_arg, assert_inputs_exist, @@ -157,9 +156,15 @@ def main(): roi_radii = assert_roi_radii_format(parser) # Loading - vol = nib.load(args.in_dwi) - data = vol.get_fdata(dtype=np.float32) - bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec) + simg = StatefulImage.load(args.in_dwi) + simg.load_gradients(args.in_bval, args.in_bvec) + + # FRF computation often expects RAS (via dipy) + simg.to_ras() + + data = simg.get_fdata(dtype=np.float32) + bvals = simg.bvals + bvecs = simg.bvecs dti_lim = args.dti_bval_limit @@ -172,7 +177,7 @@ def main(): list_bvals = unique_bvals_tolerance(bvals, tol=args.tolerance) if not np.all(list_bvals <= dti_lim): _, data_dti, bvals_dti, bvecs_dti = extract_dwi_shell( - vol, bvals, bvecs, list_bvals[list_bvals <= dti_lim], + simg, bvals, 
bvecs, list_bvals[list_bvals <= dti_lim], tol=args.tolerance) bvals_dti = np.squeeze(bvals_dti) else: @@ -180,14 +185,29 @@ def main(): bvals_dti = None bvecs_dti = None - mask = get_data_as_mask(nib.load(args.mask), - dtype=bool) if args.mask else None - mask_wm = get_data_as_mask(nib.load(args.mask_wm), - dtype=bool) if args.mask_wm else None - mask_gm = get_data_as_mask(nib.load(args.mask_gm), - dtype=bool) if args.mask_gm else None - mask_csf = get_data_as_mask(nib.load(args.mask_csf), - dtype=bool) if args.mask_csf else None + mask = None + if args.mask: + mask_simg = StatefulImage.load(args.mask) + mask_simg.reorient(simg.axcodes) + mask = get_data_as_mask(mask_simg, dtype=bool) + + mask_wm = None + if args.mask_wm: + mask_wm_simg = StatefulImage.load(args.mask_wm) + mask_wm_simg.reorient(simg.axcodes) + mask_wm = get_data_as_mask(mask_wm_simg, dtype=bool) + + mask_gm = None + if args.mask_gm: + mask_gm_simg = StatefulImage.load(args.mask_gm) + mask_gm_simg.reorient(simg.axcodes) + mask_gm = get_data_as_mask(mask_gm_simg, dtype=bool) + + mask_csf = None + if args.mask_csf: + mask_csf_simg = StatefulImage.load(args.mask_csf) + mask_csf_simg.reorient(simg.axcodes) + mask_csf = get_data_as_mask(mask_csf_simg, dtype=bool) # Processing responses, frf_masks = compute_msmt_frf(data, bvals, bvecs, @@ -208,10 +228,10 @@ def main(): # Saving masks_files = [args.wm_frf_mask, args.gm_frf_mask, args.csf_frf_mask] - for mask, mask_file in zip(frf_masks, masks_files): + for frf_mask, mask_file in zip(frf_masks, masks_files): if mask_file: - nib.save(nib.Nifti1Image(mask.astype(np.uint8), vol.affine), - mask_file) + res_simg = StatefulImage.from_data(frf_mask.astype(np.uint8), simg) + res_simg.save(mask_file) frf_out = [args.out_wm_frf, args.out_gm_frf, args.out_csf_frf] diff --git a/src/scilpy/cli/scil_frf_ssst.py b/src/scilpy/cli/scil_frf_ssst.py index bd027da42..8b707c5af 100755 --- a/src/scilpy/cli/scil_frf_ssst.py +++ b/src/scilpy/cli/scil_frf_ssst.py @@ -16,12 +16,11 
@@ import argparse import logging -from dipy.io.gradients import read_bvals_bvecs -import nibabel as nib import numpy as np from scilpy.gradients.bvec_bval_tools import check_b0_threshold from scilpy.io.image import get_data_as_mask +from scilpy.io.stateful_image import StatefulImage from scilpy.io.utils import (add_b0_thresh_arg, add_overwrite_arg, add_precision_arg, add_skip_b0_check_arg, add_verbose_arg, @@ -103,18 +102,31 @@ def main(): roi_radii = assert_roi_radii_format(parser) - vol = nib.load(args.in_dwi) - data = vol.get_fdata(dtype=np.float32) + simg = StatefulImage.load(args.in_dwi) + simg.load_gradients(args.in_bval, args.in_bvec) + + # FRF computation often expects RAS (via dipy) + simg.to_ras() + + data = simg.get_fdata(dtype=np.float32) + bvals = simg.bvals + bvecs = simg.bvecs - bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec) args.b0_threshold = check_b0_threshold(bvals.min(), b0_thr=args.b0_threshold, skip_b0_check=args.skip_b0_check) - mask = get_data_as_mask(nib.load(args.mask), - dtype=bool) if args.mask else None - mask_wm = get_data_as_mask(nib.load(args.mask_wm), - dtype=bool) if args.mask_wm else None + mask = None + if args.mask: + mask_simg = StatefulImage.load(args.mask) + mask_simg.reorient(simg.axcodes) + mask = get_data_as_mask(mask_simg, dtype=bool) + + mask_wm = None + if args.mask_wm: + mask_wm_simg = StatefulImage.load(args.mask_wm) + mask_wm_simg.reorient(simg.axcodes) + mask_wm = get_data_as_mask(mask_wm_simg, dtype=bool) full_response = compute_ssst_frf( data, bvals, bvecs, args.b0_threshold, mask=mask, diff --git a/src/scilpy/cli/scil_gradients_validate_correct.py b/src/scilpy/cli/scil_gradients_validate_correct.py index a52b1107f..61a1daa16 100755 --- a/src/scilpy/cli/scil_gradients_validate_correct.py +++ b/src/scilpy/cli/scil_gradients_validate_correct.py @@ -2,23 +2,15 @@ # -*- coding: utf-8 -*- """ Detect sign flips and/or axes swaps in the gradients table from a fiber -coherence index [1]. 
The script takes as input the principal direction(s) -at each voxel, the b-vectors and the fractional anisotropy map and outputs -a corrected b-vectors file. +coherence index [1]. The script takes as input the DWI, b-values and b-vectors +and outputs a corrected b-vectors file. A typical pipeline could be: ->>> scil_dti_metrics dwi.nii.gz bval bvec --not_all --fa fa.nii.gz - --evecs peaks.nii.gz ->>> scil_gradients_validate_correct bvec peaks_v1.nii.gz fa.nii.gz bvec_corr +>>> scil_gradients_validate_correct dwi.nii.gz bval bvec bvec_corr -Note that peaks_v1.nii.gz is the file containing the direction associated -to the highest eigenvalue at each voxel. - -It is also possible to use a file containing multiple principal directions per -voxel, given that they are sorted by decreasing amplitude. In that case, the -first direction (with the highest amplitude) will be chosen for validation. -Only 4D data is supported, so the directions must be stored in a single -dimension. For example, peaks.nii.gz from scil_fodf_metrics could be used. +The script refits the DTI model 24 times (once for each possible axis +permutation and flip) and chooses the one that maximizes the fiber coherence +index. For performance, the fit is only performed on voxels with FA > 0.5. 
------------------------------------------------------------------------------ Reference: @@ -30,17 +22,22 @@ """ import argparse +import itertools import logging -from dipy.io.gradients import read_bvals_bvecs +from dipy.core.gradients import gradient_table +from dipy.reconst.dti import TensorModel import numpy as np -import nibabel as nib +from tqdm import tqdm from scilpy.io.utils import (add_overwrite_arg, assert_inputs_exist, assert_outputs_exist, add_verbose_arg, - assert_headers_compatible) + add_b0_thresh_arg, add_skip_b0_check_arg) from scilpy.io.image import get_data_as_mask -from scilpy.reconst.fiber_coherence import compute_coherence_table_for_transforms +from scilpy.io.stateful_image import StatefulImage +from scilpy.gradients.bvec_bval_tools import check_b0_threshold +from scilpy.reconst.fiber_coherence import (compute_fiber_coherence, + NB_FLIPS) from scilpy.version import version_string @@ -49,25 +46,24 @@ def _build_arg_parser(): formatter_class=argparse.RawTextHelpFormatter, epilog=version_string) + p.add_argument('in_dwi', + help='Path to the input DWI file.') + p.add_argument('in_bval', + help='Path to the b-values file.') p.add_argument('in_bvec', - help='Path to bvec file.') - p.add_argument('in_peaks', - help='Path to peaks file.') - p.add_argument('in_FA', - help='Path to the fractional anisotropy file.') + help='Path to the b-vectors file to validate.') p.add_argument('out_bvec', help='Path to corrected bvec file (FSL format).') p.add_argument('--mask', - help='Path to an optional mask. If set, FA and Peaks will ' - 'only be used inside the mask.') - p.add_argument('--fa_threshold', default=0.2, type=float, + help='Path to an optional mask. If set, DTI fit will ' + 'only be performed inside the mask.') + p.add_argument('--fa_threshold', default=0.5, type=float, help='FA threshold. Only voxels with FA higher ' 'than fa_threshold will be considered. 
[%(default)s]') - p.add_argument('--column_wise', action='store_true', - help='Specify if input peaks are column-wise (..., 3, N) ' - 'instead of row-wise (..., N, 3).') + add_b0_thresh_arg(p) + add_skip_b0_check_arg(p, will_overwrite_with_min=True) add_verbose_arg(p) add_overwrite_arg(p) return p @@ -78,45 +74,94 @@ def main(): args = parser.parse_args() logging.getLogger().setLevel(logging.getLevelName(args.verbose)) - assert_inputs_exist(parser, [args.in_bvec, args.in_peaks, args.in_FA], + assert_inputs_exist(parser, [args.in_dwi, args.in_bval, args.in_bvec], optional=args.mask) assert_outputs_exist(parser, args, args.out_bvec) - assert_headers_compatible(parser, [args.in_peaks, args.in_FA], - optional=args.mask) - - _, bvecs = read_bvals_bvecs(None, args.in_bvec) - fa = nib.load(args.in_FA).get_fdata() - peaks = nib.load(args.in_peaks).get_fdata() - - if peaks.shape[-1] > 3: - logging.info('More than one principal direction per voxel was given.') - peaks = peaks[..., 0:3] - logging.info('The first peak is assumed to be the biggest.') - - # convert peaks to a volume of shape (H, W, D, N, 3) - if args.column_wise: - peaks = np.reshape(peaks, peaks.shape[:3] + (3, -1)) - peaks = np.transpose(peaks, axes=(0, 1, 2, 4, 3)) - else: - peaks = np.reshape(peaks, peaks.shape[:3] + (-1, 3)) - peaks = np.squeeze(peaks) - if args.mask: - mask = get_data_as_mask(nib.load(args.mask), ref_shape=peaks.shape) - fa[np.logical_not(mask)] = 0 - peaks[np.logical_not(mask)] = 0 + # Loading data + simg = StatefulImage.load(args.in_dwi) + simg.load_gradients(args.in_bval, args.in_bvec) + simg.to_ras() - peaks[fa < args.fa_threshold] = 0 - coherence, transform = compute_coherence_table_for_transforms(peaks, fa) + data = simg.get_fdata(dtype=np.float32) + bvals = simg.bvals + bvecs = simg.bvecs + + mask = None + if args.mask: + mask_simg = StatefulImage.load(args.mask) + mask_simg.reorient(simg.axcodes) + mask = get_data_as_mask(mask_simg, dtype=bool) + + # Initial DTI fit to get FA and 
identify high-FA voxels + args.b0_threshold = check_b0_threshold(bvals.min(), + b0_thr=args.b0_threshold, + skip_b0_check=args.skip_b0_check) + gtab = gradient_table(bvals, bvecs=bvecs, b0_threshold=args.b0_threshold) + tenmodel = TensorModel(gtab, fit_method='WLS', + min_signal=np.min(data[data > 0])) + tenfit = tenmodel.fit(data, mask=mask) + fa = tenfit.fa + + # Define high-FA mask for coherence calculation + high_fa_mask = fa > args.fa_threshold + if mask is not None: + high_fa_mask &= mask + + if np.sum(high_fa_mask) == 0: + logging.error('No voxels found with FA > {}. Aborting.' + .format(args.fa_threshold)) + return + + # Generate 24 possible permutation/flips of gradient directions + permutations = list(itertools.permutations([0, 1, 2])) + transforms = np.zeros((len(permutations) * NB_FLIPS, 3, 3)) + for i in range(len(permutations)): + transforms[i * NB_FLIPS, np.arange(3), permutations[i]] = 1 + for ii in range(3): + flip = np.eye(3) + flip[ii, ii] = -1 + transforms[ii + i * NB_FLIPS + + 1] = transforms[i * NB_FLIPS].dot(flip) + + # Iterative refit and coherence calculation + best_coherence = -1 + best_t = None + + logging.info('Refitting DTI 24 times for gradient validation...') + for t in tqdm(transforms): + # Transform bvecs + # Note: Dipy expects bvecs as (N, 3). We apply the transform to axes. 
+ # G' = G @ T + bvecs_candidate = bvecs @ t + + gtab_candidate = gradient_table(bvals, bvecs=bvecs_candidate, + b0_threshold=args.b0_threshold) + tenmodel_candidate = TensorModel(gtab_candidate, fit_method='WLS', + min_signal=np.min(data[data > 0])) + + # Fit ONLY on the high-FA mask to save time + tenfit_candidate = tenmodel_candidate.fit(data, mask=high_fa_mask) + + # Extract the principal direction (v1) + # evecs is (H, W, D, 3, 3), evecs[..., 0] is the first eigenvector (peak) + peaks = tenfit_candidate.evecs[..., 0] + + # Compute coherence + coherence = compute_fiber_coherence(peaks, fa) + + if coherence > best_coherence: + best_coherence = coherence + best_t = t - best_t = transform[np.argmax(coherence)] if (best_t == np.eye(3)).all(): - logging.info('b-vectors are already correct.') + logging.info('b-vectors are already correct. Coherence: {:.2f}' + .format(best_coherence)) correct_bvecs = bvecs else: - logging.info('Applying correction to b-vectors. ' - 'Transform is: \n{0}.'.format(best_t)) - correct_bvecs = np.dot(bvecs, best_t) + logging.info('Applying correction to b-vectors. 
Coherence: {:.2f} ' + '\nTransform is: \n{}.'.format(best_coherence, best_t)) + correct_bvecs = bvecs @ best_t logging.info('Saving bvecs to file: {0}.'.format(args.out_bvec)) diff --git a/src/scilpy/cli/scil_mti_maps_MT.py b/src/scilpy/cli/scil_mti_maps_MT.py index b39cb80bd..1944d3838 100755 --- a/src/scilpy/cli/scil_mti_maps_MT.py +++ b/src/scilpy/cli/scil_mti_maps_MT.py @@ -93,6 +93,7 @@ import numpy as np from scilpy.io.mti import add_common_args_mti, load_and_verify_mti +from scilpy.io.image import load_img, get_data_as_mask from scilpy.io.utils import (add_overwrite_arg, assert_inputs_exist, add_verbose_arg, assert_output_dirs_exist_and_empty) @@ -186,7 +187,8 @@ def main(): optional=args.in_mtoff_t1 or [] + [args.mask]) # Define reference image for saving maps - affine = nib.load(input_maps_lists[0][0]).affine + ref_img, _ = load_img(input_maps_lists[0][0]) + affine = ref_img.affine # Other checks, loading, saving contrast_maps. single_echo, flip_angles, rep_times, B1_map, contrast_maps = \ @@ -251,8 +253,13 @@ def main(): img_data_list.append(MTsat) # Apply thresholds on maps + mask_data = None + if args.mask: + mask_img, _ = load_img(args.mask) + mask_data = get_data_as_mask(mask_img) + for i, map in enumerate(img_data_list): - img_data_list[i] = threshold_map(map, args.mask, 0, 100) + img_data_list[i] = threshold_map(map, mask_data, 0, 100) # Save ihMT and MT images if args.filtering: diff --git a/src/scilpy/cli/scil_mti_maps_ihMT.py b/src/scilpy/cli/scil_mti_maps_ihMT.py index 7ab913273..c669aecfa 100755 --- a/src/scilpy/cli/scil_mti_maps_ihMT.py +++ b/src/scilpy/cli/scil_mti_maps_ihMT.py @@ -108,6 +108,7 @@ import numpy as np from scilpy.io.mti import add_common_args_mti, load_and_verify_mti +from scilpy.io.image import load_img, get_data_as_mask from scilpy.io.utils import (add_overwrite_arg, assert_inputs_exist, add_verbose_arg, assert_output_dirs_exist_and_empty) @@ -272,8 +273,14 @@ def main(): # Apply thresholds on maps upper_thresholds = [100, 
100, 10, 10] idx_contrast_lists = [[0, 1, 2, 3, 4], [3, 4], [0, 1, 2, 3], [3, 4]] + + mask_data = None + if args.mask: + mask_img, _ = load_img(args.mask) + mask_data = get_data_as_mask(mask_img) + for i, map in enumerate(img_data): - img_data[i] = threshold_map(map, args.mask, 0, upper_thresholds[i], + img_data[i] = threshold_map(map, mask_data, 0, upper_thresholds[i], idx_contrast_list=idx_contrast_lists[i], contrast_maps=contrast_maps) diff --git a/src/scilpy/cli/scil_tracking_local.py b/src/scilpy/cli/scil_tracking_local.py index fed1ab1d9..95dd4d525 100755 --- a/src/scilpy/cli/scil_tracking_local.py +++ b/src/scilpy/cli/scil_tracking_local.py @@ -67,6 +67,7 @@ from dipy.tracking.local_tracking import LocalTracking from dipy.tracking.stopping_criterion import BinaryStoppingCriterion from scilpy.io.image import get_data_as_mask +from scilpy.io.stateful_image import StatefulImage from scilpy.io.utils import (add_sphere_arg, add_verbose_arg, assert_headers_compatible, assert_inputs_exist, assert_outputs_exist, parse_sh_basis_arg, @@ -187,14 +188,16 @@ def main(): # when providing information to dipy (i.e. working as if in voxel space) # will not yield correct results. Tracking is performed in voxel space # in both the GPU and CPU cases. - odf_sh_img = nib.load(args.in_odf) - if not np.allclose(np.mean(odf_sh_img.header.get_zooms()[:3]), - odf_sh_img.header.get_zooms()[0], atol=1e-03): + odf_sh_simg = StatefulImage.load(args.in_odf) + if not np.allclose(np.mean(odf_sh_simg.header.get_zooms()[:3]), + odf_sh_simg.header.get_zooms()[0], atol=1e-03): parser.error( 'ODF SH file is not isotropic. 
Tracking cannot be ran robustly.') logging.debug("Loading masks and finding seeds.") - mask_data = get_data_as_mask(nib.load(args.in_mask), dtype=bool) + mask_simg = StatefulImage.load(args.in_mask) + mask_simg.reorient(odf_sh_simg.axcodes) + mask_data = get_data_as_mask(mask_simg, dtype=bool) if args.npv: nb_seeds = args.npv @@ -206,13 +209,14 @@ def main(): nb_seeds = 1 seed_per_vox = True - voxel_size = odf_sh_img.header.get_zooms()[0] + voxel_size = odf_sh_simg.header.get_zooms()[0] vox_step_size = args.step_size / voxel_size - seed_img = nib.load(args.in_seed) + seed_simg = StatefulImage.load(args.in_seed) + seed_simg.reorient(odf_sh_simg.axcodes) sh_basis, is_legacy = parse_sh_basis_arg(args) - if np.count_nonzero(seed_img.get_fdata(dtype=np.float32)) == 0: + if np.count_nonzero(seed_simg.get_fdata(dtype=np.float32)) == 0: raise IOError('The image {} is empty. ' 'It can\'t be loaded as ' 'seeding mask.'.format(args.in_seed)) @@ -224,13 +228,16 @@ def main(): seeds = np.squeeze(load_matrix_in_any_format(args.in_custom_seeds)) else: seeds = track_utils.random_seeds_from_mask( - seed_img.get_fdata(dtype=np.float32), + seed_simg.get_fdata(dtype=np.float32), np.eye(4), seeds_count=nb_seeds, seed_count_per_voxel=seed_per_vox, random_seed=args.seed) total_nb_seeds = len(seeds) + # ODF data + odf_sh_data = odf_sh_simg.get_fdata(dtype=np.float32) + if not args.use_gpu: # LocalTracking.maxlen is actually the maximum length # per direction, we need to filter post-tracking. 
@@ -239,7 +246,7 @@ def main(): logging.info("Starting CPU local tracking.") streamlines_generator = LocalTracking( get_direction_getter( - args.in_odf, args.algo, args.sphere, + odf_sh_data, args.algo, args.sphere, args.sub_sphere, args.theta, sh_basis, voxel_size, args.sf_threshold, args.sh_to_pmf, args.probe_length, args.probe_radius, @@ -258,15 +265,12 @@ def main(): # to agree with DIPY's implementation max_strl_len = int(2.0 * args.max_length / args.step_size) + 1 - # data volume - odf_sh = odf_sh_img.get_fdata(dtype=np.float32) - # GPU tracking needs the full sphere sphere = get_sphere(name=args.sphere).subdivide(n=args.sub_sphere) logging.info("Starting GPU local tracking.") streamlines_generator = GPUTacker( - odf_sh, mask_data, seeds, + odf_sh_data, mask_data, seeds, vox_step_size, max_strl_len, theta=get_theta(args.theta, args.algo), sf_threshold=args.sf_threshold, @@ -280,7 +284,7 @@ def main(): # save streamlines on-the-fly to file save_tractogram(streamlines_generator, tracts_format, - odf_sh_img, total_nb_seeds, args.out_tractogram, + odf_sh_simg, total_nb_seeds, args.out_tractogram, args.min_length, args.max_length, args.compress_th, args.save_seeds, args.verbose) # Final logging diff --git a/src/scilpy/cli/scil_tracking_local_dev.py b/src/scilpy/cli/scil_tracking_local_dev.py index 8904efad2..026df0d98 100755 --- a/src/scilpy/cli/scil_tracking_local_dev.py +++ b/src/scilpy/cli/scil_tracking_local_dev.py @@ -50,15 +50,13 @@ import time import dipy.core.geometry as gm +from dipy.io.stateful_tractogram import Space, Origin import nibabel as nib -import numpy as np - -from dipy.io.stateful_tractogram import StatefulTractogram, Space -from dipy.io.stateful_tractogram import Origin -from dipy.io.streamline import save_tractogram from nibabel.streamlines import detect_format, TrkFile +import numpy as np -from scilpy.io.image import assert_same_resolution +from scilpy.io.image import assert_same_resolution, get_data_as_mask +from scilpy.io.stateful_image 
import StatefulImage from scilpy.io.utils import (add_processes_arg, add_sphere_arg, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, @@ -74,7 +72,8 @@ add_tracking_options, get_theta, verify_streamline_length_options, - verify_seed_options) + verify_seed_options, + save_tractogram) from scilpy.version import version_string @@ -213,15 +212,23 @@ def main(): our_space = Space.VOX our_origin = Origin('center') + # ------- INSTANTIATING PROPAGATOR ------- + logging.info("Loading ODF SH data.") + odf_sh_simg = StatefulImage.load(args.in_odf) + odf_sh_data = odf_sh_simg.get_fdata(caching='unchanged', dtype=float) + odf_sh_res = odf_sh_simg.header.get_zooms()[:3] + dataset = DataVolume(odf_sh_data, odf_sh_res, args.sh_interp) + logging.info("Loading seeding mask.") - seed_img = nib.load(args.in_seed) - seed_data = seed_img.get_fdata(caching='unchanged', dtype=float) + seed_simg = StatefulImage.load(args.in_seed) + seed_simg.reorient(odf_sh_simg.axcodes) + seed_data = seed_simg.get_fdata(caching='unchanged', dtype=float) if np.count_nonzero(seed_data) == 0: raise IOError('The image {} is empty. 
' 'It can\'t be loaded as ' 'seeding mask.'.format(args.in_seed)) - seed_res = seed_img.header.get_zooms()[:3] + seed_res = seed_simg.header.get_zooms()[:3] # ------- INSTANTIATING SEED GENERATOR ------- if args.in_custom_seeds: @@ -248,24 +255,18 @@ def main(): ' value > 0.'.format(args.in_seed)) logging.info("Loading tracking mask.") - mask_img = nib.load(args.in_mask) - mask_data = mask_img.get_fdata(caching='unchanged', dtype=float) - mask_res = mask_img.header.get_zooms()[:3] + mask_simg = StatefulImage.load(args.in_mask) + mask_simg.reorient(odf_sh_simg.axcodes) + mask_data = mask_simg.get_fdata(caching='unchanged', dtype=float) + mask_res = mask_simg.header.get_zooms()[:3] mask = DataVolume(mask_data, mask_res, args.mask_interp) - # ------- INSTANTIATING PROPAGATOR ------- - logging.info("Loading ODF SH data.") - odf_sh_img = nib.load(args.in_odf) - odf_sh_data = odf_sh_img.get_fdata(caching='unchanged', dtype=float) - odf_sh_res = odf_sh_img.header.get_zooms()[:3] - dataset = DataVolume(odf_sh_data, odf_sh_res, args.sh_interp) - logging.info("Instantiating propagator.") # Converting step size to vox space # We only support iso vox for now but allow slightly different vox 1e-3. 
assert np.allclose(np.mean(odf_sh_res[:3]), odf_sh_res, atol=1e-03) - voxel_size = odf_sh_img.header.get_zooms()[0] + voxel_size = odf_sh_simg.header.get_zooms()[0] vox_step_size = args.step_size / voxel_size # Using space and origin in the propagator: vox and center, like @@ -281,9 +282,10 @@ def main(): # ------- INSTANTIATING RAP OBJECT ------- if args.rap_mask: logging.info("Loading RAP mask.") - rap_img = nib.load(args.rap_mask) - rap_data = rap_img.get_fdata(caching='unchanged', dtype=float) - rap_res = rap_img.header.get_zooms()[:3] + rap_simg = StatefulImage.load(args.rap_mask) + rap_simg.reorient(odf_sh_simg.axcodes) + rap_data = rap_simg.get_fdata(caching='unchanged', dtype=float) + rap_res = rap_simg.header.get_zooms()[:3] rap_mask = DataVolume(rap_data, rap_res, args.mask_interp) else: rap_mask = None @@ -295,11 +297,13 @@ def main(): rap = None logging.info("Instantiating tracker.") + # We must force save_seeds=True so that Tracker returns (streamlines, seeds) + # as expected by scilpy.tracking.utils.save_tractogram tracker = Tracker(propagator, mask, seed_generator, nbr_seeds, min_nbr_pts, max_nbr_pts, args.max_invalid_nb_points, - compression_th=args.compress_th, + compression_th=None, nbr_processes=args.nbr_processes, - save_seeds=args.save_seeds, + save_seeds=True, mmap_mode='r+', rng_seed=args.rng_seed, track_forward_only=args.forward_only, skip=args.skip, @@ -315,24 +319,11 @@ def main(): "Now saving..." .format(len(streamlines), nbr_seeds, str_time)) - # save seeds if args.save_seeds is given - # We seeded (and tracked) in vox, center, which is what is expected for - # seeds. - if args.save_seeds: - data_per_streamline = {'seeds': seeds} - else: - data_per_streamline = {} - - # Compared with scil_tracking_local, using sft rather than - # LazyTractogram to deal with space. - # Contrary to scilpy or dipy, where space after tracking is vox, here - # space after tracking is voxmm. 
- # Smallest possible streamline coordinate is (0,0,0), equivalent of - # corner origin (TrackVis) - sft = StatefulTractogram(streamlines, mask_img, - space=our_space, origin=our_origin, - data_per_streamline=data_per_streamline) - save_tractogram(sft, args.out_tractogram) + # save streamlines on-the-fly to file + save_tractogram(zip(streamlines, seeds), tracts_format, + odf_sh_simg, nbr_seeds, args.out_tractogram, + args.min_length, args.max_length, args.compress_th, + args.save_seeds, args.verbose) if __name__ == "__main__": diff --git a/src/scilpy/cli/scil_tracking_pft.py b/src/scilpy/cli/scil_tracking_pft.py index 1161c2ca8..8eb2d141d 100755 --- a/src/scilpy/cli/scil_tracking_pft.py +++ b/src/scilpy/cli/scil_tracking_pft.py @@ -38,24 +38,23 @@ from dipy.data import get_sphere, HemiSphere from dipy.direction import (ProbabilisticDirectionGetter, DeterministicMaximumDirectionGetter) -from dipy.io.utils import (get_reference_info, - create_tractogram_header) from dipy.tracking.local_tracking import ParticleFilteringTracking from dipy.tracking.stopping_criterion import (ActStoppingCriterion, CmcStoppingCriterion) from dipy.tracking import utils as track_utils -from dipy.tracking.streamlinespeed import length, compress_streamlines import nibabel as nib -from nibabel.streamlines import LazyTractogram +from nibabel.streamlines import detect_format import numpy as np from scilpy.io.image import get_data_as_mask +from scilpy.io.stateful_image import StatefulImage from scilpy.io.utils import (add_overwrite_arg, add_sh_basis_args, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, parse_sh_basis_arg, assert_headers_compatible, add_compression_arg, verify_compression_th) -from scilpy.tracking.utils import get_theta +from scilpy.tracking.utils import (add_out_options, get_theta, + save_tractogram) from scilpy.version import version_string @@ -130,19 +129,13 @@ def _build_arg_parser(): help='Length of PFT forward tracking (mm). 
' '[%(default)s]') - out_g = p.add_argument_group('Output options') + out_g = add_out_options(p) out_g.add_argument('--all', dest='keep_all', action='store_true', help='If set, keeps "excluded" streamlines.\n' 'NOT RECOMMENDED, except for debugging.') out_g.add_argument('--seed', type=int, help='Random number generator seed.') - add_overwrite_arg(out_g) - out_g.add_argument('--save_seeds', action='store_true', - help='If set, save the seeds used for the tracking \n ' - 'in the data_per_streamline property.') - - add_compression_arg(out_g) add_verbose_arg(p) return p @@ -187,9 +180,9 @@ def main(): if args.nt and args.nt <= 0: parser.error('Total number of seeds must be > 0.') - fodf_sh_img = nib.load(args.in_sh) - if not np.allclose(np.mean(fodf_sh_img.header.get_zooms()[:3]), - fodf_sh_img.header.get_zooms()[0], atol=1e-03): + fodf_sh_simg = StatefulImage.load(args.in_sh) + if not np.allclose(np.mean(fodf_sh_simg.header.get_zooms()[:3]), + fodf_sh_simg.header.get_zooms()[0], atol=1e-03): parser.error( 'SH file is not isotropic. 
Tracking cannot be ran robustly.') @@ -213,7 +206,7 @@ def main(): # relative_peak_threshold is for initial directions filtering # min_separation_angle is the initial separation angle for peak extraction dg = dgklass.from_shcoeff( - fodf_sh_img.get_fdata(dtype=np.float32), + fodf_sh_simg.get_fdata(dtype=np.float32), max_angle=theta, sphere=tracking_sphere, basis_type=sh_basis, @@ -221,20 +214,23 @@ def main(): pmf_threshold=args.sf_threshold, relative_peak_threshold=args.sf_threshold_init) - map_include_img = nib.load(args.in_map_include) - map_exclude_img = nib.load(args.map_exclude_file) - voxel_size = np.average(map_include_img.header['pixdim'][1:4]) + map_include_simg = StatefulImage.load(args.in_map_include) + map_include_simg.reorient(fodf_sh_simg.axcodes) + map_exclude_simg = StatefulImage.load(args.map_exclude_file) + map_exclude_simg.reorient(fodf_sh_simg.axcodes) + + voxel_size = np.average(map_include_simg.header['pixdim'][1:4]) if not args.act: tissue_classifier = CmcStoppingCriterion( - map_include_img.get_fdata(dtype=np.float32), - map_exclude_img.get_fdata(dtype=np.float32), + map_include_simg.get_fdata(dtype=np.float32), + map_exclude_simg.get_fdata(dtype=np.float32), step_size=args.step_size, average_voxel_size=voxel_size) else: tissue_classifier = ActStoppingCriterion( - map_include_img.get_fdata(dtype=np.float32), - map_exclude_img.get_fdata(dtype=np.float32)) + map_include_simg.get_fdata(dtype=np.float32), + map_exclude_simg.get_fdata(dtype=np.float32)) if args.npv: nb_seeds = args.npv @@ -246,20 +242,26 @@ def main(): nb_seeds = 1 seed_per_vox = True - voxel_size = fodf_sh_img.header.get_zooms()[0] + voxel_size = fodf_sh_simg.header.get_zooms()[0] vox_step_size = args.step_size / voxel_size - seed_img = nib.load(args.in_seed) + + seed_simg = StatefulImage.load(args.in_seed) + seed_simg.reorient(fodf_sh_simg.axcodes) + seeds = track_utils.random_seeds_from_mask( - get_data_as_mask(seed_img, dtype=bool), + get_data_as_mask(seed_simg, dtype=bool), 
np.eye(4), seeds_count=nb_seeds, seed_count_per_voxel=seed_per_vox, random_seed=args.seed) + total_nb_seeds = len(seeds) # Note that max steps is used once for the forward pass, and # once for the backwards. This doesn't, in fact, control the real # max length max_steps = int(args.max_length / args.step_size) + 1 + # We must force save_seeds=True so that the generator yields (strl, seed) + # as expected by scilpy.tracking.utils.save_tractogram pft_streamlines = ParticleFilteringTracking( dg, tissue_classifier, @@ -273,37 +275,15 @@ def main(): particle_count=args.particles, return_all=args.keep_all, random_seed=args.seed, - save_seeds=args.save_seeds) + save_seeds=True) - scaled_min_length = args.min_length / voxel_size - scaled_max_length = args.max_length / voxel_size + tracts_format = detect_format(args.out_tractogram) - if args.save_seeds: - filtered_streamlines, seeds = \ - zip(*((s, p) for s, p in pft_streamlines - if scaled_min_length <= length(s) <= scaled_max_length)) - data_per_streamlines = {'seeds': lambda: seeds} - else: - filtered_streamlines = \ - (s for s in pft_streamlines - if scaled_min_length <= length(s) <= scaled_max_length) - data_per_streamlines = {} - - if args.compress_th: - filtered_streamlines = ( - compress_streamlines(s, args.compress_th) - for s in filtered_streamlines) - - tractogram = LazyTractogram(lambda: filtered_streamlines, - data_per_streamlines, - affine_to_rasmm=seed_img.affine) - - filetype = nib.streamlines.detect_format(args.out_tractogram) - reference = get_reference_info(seed_img) - header = create_tractogram_header(filetype, *reference) - - # Use generator to save the streamlines on-the-fly - nib.streamlines.save(tractogram, args.out_tractogram, header=header) + # save streamlines on-the-fly to file + save_tractogram(pft_streamlines, tracts_format, + fodf_sh_simg, total_nb_seeds, args.out_tractogram, + args.min_length, args.max_length, args.compress_th, + args.save_seeds, args.verbose) if __name__ == '__main__': diff 
--git a/src/scilpy/cli/scil_volume_apply_transform.py b/src/scilpy/cli/scil_volume_apply_transform.py index e07c454d5..fd60404ee 100755 --- a/src/scilpy/cli/scil_volume_apply_transform.py +++ b/src/scilpy/cli/scil_volume_apply_transform.py @@ -10,7 +10,6 @@ import argparse import logging -import nibabel as nib import numpy as np from scilpy.image.volume_operations import apply_transform diff --git a/src/scilpy/cli/scil_volume_math.py b/src/scilpy/cli/scil_volume_math.py index e9816ae04..e84771ead 100755 --- a/src/scilpy/cli/scil_volume_math.py +++ b/src/scilpy/cli/scil_volume_math.py @@ -21,6 +21,7 @@ from scilpy.image.volume_math import (get_image_ops, get_operations_doc) from scilpy.io.image import load_img +from scilpy.io.stateful_image import StatefulImage from scilpy.io.utils import (add_overwrite_arg, add_verbose_arg, assert_outputs_exist) @@ -76,9 +77,10 @@ def main(): # Find at least one mask, but prefer a 4D mask if there is any. mask = None found_ref = False + ref_img = None for input_arg in args.in_args: if not is_float(input_arg): - ref_img = nib.load(input_arg) + ref_img, _ = load_img(input_arg) found_ref = True if mask is None: mask = np.zeros(ref_img.shape) @@ -92,19 +94,20 @@ def main(): # Load all input masks. 
input_img = [] for input_arg in args.in_args: + img, dtype = load_img(input_arg) if not is_float(input_arg) and \ - not is_header_compatible(ref_img, input_arg): + not is_header_compatible(ref_img, img): parser.error('Inputs do not have a compatible header.') - img, dtype = load_img(input_arg) if not isinstance(img, float): - args.data_type = img.header.get_data_dtype() if args.data_type is None else args.data_type + args.data_type = img.header.get_data_dtype() \ + if args.data_type is None else args.data_type - if isinstance(img, nib.Nifti1Image) and \ + if isinstance(img, StatefulImage) and \ dtype != ref_img.get_data_dtype() and \ not args.data_type: parser.error('Inputs do not have a compatible data type.\n' 'Use --data_type to specify output datatype.') - if args.operation in binary_op and isinstance(img, nib.Nifti1Image): + if args.operation in binary_op and isinstance(img, StatefulImage): data = img.get_fdata(dtype=np.float64) unique = np.unique(data) if not len(unique) <= 2: @@ -116,7 +119,7 @@ def main(): 'binary arrays, will be converted.\n' 'Non-zeros will be set to ones.') - if isinstance(img, nib.Nifti1Image): + if isinstance(img, StatefulImage): data = img.get_fdata(dtype=np.float64) if data.ndim == 4: mask[np.sum(data, axis=3).astype(bool) > 0] = 1 @@ -144,7 +147,10 @@ def main(): new_img = nib.Nifti1Image(output_data, ref_img.affine, header=ref_img.header) - nib.save(new_img, args.out_image) + + # Use StatefulImage.create_from to ensure original orientation + # ref_img is also a StatefulImage (loaded via load_img earlier) + StatefulImage.create_from(new_img, ref_img).save(args.out_image) if __name__ == "__main__": diff --git a/src/scilpy/cli/scil_volume_modify_voxel_order.py b/src/scilpy/cli/scil_volume_modify_voxel_order.py index 5575f5dd8..6588442e1 100644 --- a/src/scilpy/cli/scil_volume_modify_voxel_order.py +++ b/src/scilpy/cli/scil_volume_modify_voxel_order.py @@ -32,6 +32,7 @@ import argparse import logging import nibabel as nib +import numpy 
as np from scilpy.io.utils import (add_overwrite_arg, add_verbose_arg, @@ -54,6 +55,11 @@ def _build_arg_parser(): p.add_argument('--new_voxel_order', required=True, help='The new voxel order (e.g., "RAS", "1,2,3").') + p.add_argument('--in_bvec', + help='Path of the b-vectors file.') + p.add_argument('--out_bvec', + help='Path of the modified b-vectors file to write.') + add_verbose_arg(p) add_overwrite_arg(p) @@ -65,18 +71,31 @@ def main(): args = parser.parse_args() logging.getLogger().setLevel(logging.getLevelName(args.verbose)) - assert_inputs_exist(parser, args.in_image) - assert_outputs_exist(parser, args, args.out_image) + assert_inputs_exist(parser, args.in_image, args.in_bvec) + assert_outputs_exist(parser, args, args.out_image, args.out_bvec) img = nib.load(args.in_image) simg = StatefulImage.load(args.in_image) + if args.in_bvec: + bvecs = np.loadtxt(args.in_bvec) + if bvecs.shape[0] == 3 and bvecs.shape[1] != 3: + bvecs = bvecs.T + + # Create dummy bvals to satisfy StatefulImage validation + bvals = np.zeros(len(bvecs)) + simg.attach_gradients(bvals, bvecs) + parsed_voxel_order = parse_voxel_order(args.new_voxel_order, dimensions=len(img.shape)) simg.reorient(parsed_voxel_order) - nib.save(simg, args.out_image) + new_simg = StatefulImage.convert_to_simg(simg, simg.bvals, simg.bvecs) + new_simg.save(args.out_image) + + if args.in_bvec and args.out_bvec: + np.savetxt(args.out_bvec, simg.bvecs.T, fmt='%.8f') if __name__ == "__main__": diff --git a/src/scilpy/cli/scil_volume_resample.py b/src/scilpy/cli/scil_volume_resample.py index 761efa0c2..5575ac6f0 100755 --- a/src/scilpy/cli/scil_volume_resample.py +++ b/src/scilpy/cli/scil_volume_resample.py @@ -17,7 +17,6 @@ import argparse import logging -import nibabel as nib import numpy as np from scilpy.io.utils import (add_verbose_arg, add_overwrite_arg, @@ -86,12 +85,12 @@ def main(): if args.enforce_voxel_size and not args.voxel_size: parser.error("Cannot enforce voxel size without a voxel size.") - if 
args.volume_size and (not len(args.volume_size) == 1 and - not len(args.volume_size) == 3): + if args.volume_size and (not len(args.volume_size) == 1 + and not len(args.volume_size) == 3): parser.error('Invalid dimensions for --volume_size.') - if args.voxel_size and (not len(args.voxel_size) == 1 and - not len(args.voxel_size) == 3): + if args.voxel_size and (not len(args.voxel_size) == 1 + and not len(args.voxel_size) == 3): parser.error('Invalid dimensions for --voxel_size.') logging.info('Loading raw data from %s', args.in_image) @@ -100,15 +99,15 @@ def main(): ref_img = None if args.ref: - ref_img = nib.load(args.ref) + ref_img = StatefulImage.load(args.ref) # Must not verify that headers are compatible. But can verify that, at # least, the first columns of their affines are compatible. - img_zoom_invert = [1 / zoom for zoom in ref_img.header.get_zooms()[:3]] + img_zoom_invert = [1 / zoom for zoom in simg.header.get_zooms()[:3]] ref_zoom_invert = [1 / zoom for zoom in ref_img.header.get_zooms()[:3]] - img_affine = np.dot(simg.affine[:3, :3], img_zoom_invert) - ref_affine = np.dot(ref_img.affine[:3, :3], ref_zoom_invert) + img_affine = np.dot(simg.affine[:3, :3], np.diag(img_zoom_invert)) + ref_affine = np.dot(ref_img.affine[:3, :3], np.diag(ref_zoom_invert)) if not np.allclose(img_affine, ref_affine): parser.error("The --ref image should have the same affine as the " diff --git a/src/scilpy/cli/scil_volume_reshape.py b/src/scilpy/cli/scil_volume_reshape.py index 597e40f3f..d9bb6da68 100755 --- a/src/scilpy/cli/scil_volume_reshape.py +++ b/src/scilpy/cli/scil_volume_reshape.py @@ -75,8 +75,8 @@ def main(): assert_inputs_exist(parser, args.in_image, args.ref) assert_outputs_exist(parser, args, args.out_image) - if args.volume_size and (not len(args.volume_size) == 1 and - not len(args.volume_size) == 3): + if args.volume_size and (not len(args.volume_size) == 1 + and not len(args.volume_size) == 3): parser.error('--volume_size takes in either 1 or 3 
arguments.') logging.info('Loading raw data from %s', args.in_image) diff --git a/src/scilpy/cli/tests/test_gradients_validate_correct.py b/src/scilpy/cli/tests/test_gradients_validate_correct.py index 8c79a6f41..e1653155b 100644 --- a/src/scilpy/cli/tests/test_gradients_validate_correct.py +++ b/src/scilpy/cli/tests/test_gradients_validate_correct.py @@ -26,27 +26,8 @@ def test_execution_processing_dti_peaks(script_runner, monkeypatch): in_bvec = os.path.join(SCILPY_HOME, 'processing', '1000.bvec') - # generate the peaks file and fa map we'll use to test our script - script_runner.run(['scil_dti_metrics', in_dwi, in_bval, in_bvec, - '--not_all', '--fa', 'fa.nii.gz', - '--evecs', 'evecs.nii.gz']) # test the actual script - ret = script_runner.run(['scil_gradients_validate_correct', in_bvec, - 'evecs_v1.nii.gz', 'fa.nii.gz', - 'bvec_corr', '-v']) - assert ret.success - - -def test_execution_processing_fodf_peaks(script_runner, monkeypatch): - monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) - in_bvec = os.path.join(SCILPY_HOME, 'processing', - 'dwi.bvec') - in_peaks = os.path.join(SCILPY_HOME, 'processing', - 'peaks.nii.gz') - in_fa = os.path.join(SCILPY_HOME, 'processing', - 'fa.nii.gz') - - # test the actual script - ret = script_runner.run(['scil_gradients_validate_correct', in_bvec, - in_peaks, in_fa, 'bvec_corr_fodf', '-v']) + ret = script_runner.run(['scil_gradients_validate_correct', + in_dwi, in_bval, in_bvec, 'bvec_corr.bvec', + '--fa_thresh', '0.5', '-v']) assert ret.success diff --git a/src/scilpy/cli/tests/test_scil_volume_modify_voxel_order.py b/src/scilpy/cli/tests/test_scil_volume_modify_voxel_order.py index 847f070e6..afffc067c 100644 --- a/src/scilpy/cli/tests/test_scil_volume_modify_voxel_order.py +++ b/src/scilpy/cli/tests/test_scil_volume_modify_voxel_order.py @@ -51,3 +51,77 @@ def test_execution(script_runner, monkeypatch): 'output.nii.gz', '--new_voxel_order=invalid', '-f']) assert not ret.success + + +def 
test_execution_with_gradients(script_runner, monkeypatch): + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + + # 1. Create a 4D dummy NIfTI (RAS) + n_volumes = 2 + in_file = 'input_4d.nii.gz' + data = np.zeros((10, 10, 10, n_volumes)) + img = nib.Nifti1Image(data, np.eye(4)) + nib.save(img, in_file) + + # 2. Create bvecs + bvecs = np.array([[0, 0, 0], [1, 0, 0]]) # X-direction in RAS + + in_bvec = 'input.bvec' + np.savetxt(in_bvec, bvecs.T, fmt='%.8f') + + # 3. Run script to modify voxel order to LPS + out_file = 'output_lps.nii.gz' + out_bvec = 'output_lps.bvec' + ret = script_runner.run(['scil_volume_modify_voxel_order', in_file, + out_file, '--new_voxel_order=LPS', + '--in_bvec', in_bvec, '--out_bvec', out_bvec, '-f']) + assert ret.success + + # 4. Verify image + lps_img = nib.load(out_file) + assert nib.aff2axcodes(lps_img.affine) == ('L', 'P', 'S') + + # 5. Verify gradients (they should be reoriented to match LPS) + assert os.path.exists(out_bvec) + + saved_bvecs = np.loadtxt(out_bvec).T # loadtxt returns (3, N) for FSL + + # RAS to LPS: flip X and Y. + # Original bvec [1, 0, 0] (X) should become [-1, 0, 0] + expected_bvecs = np.array([[0, 0, 0], [-1, 0, 0]]) + assert np.allclose(saved_bvecs, expected_bvecs) + + +def test_execution_with_gradients_numeric(script_runner, monkeypatch): + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + + # 1. Create a 4D dummy NIfTI (RAS) + n_volumes = 2 + in_file = 'input_4d_num.nii.gz' + data = np.zeros((10, 10, 10, n_volumes)) + img = nib.Nifti1Image(data, np.eye(4)) + nib.save(img, in_file) + + # 2. Create bvecs + bvecs = np.array([[0, 0, 0], [1, 0, 0]]) # X-direction in RAS + + in_bvec = 'input_num.bvec' + np.savetxt(in_bvec, bvecs.T, fmt='%.8f') + + # 3. 
Run script to modify voxel order to LPS using numeric: -1,-2,3 + out_file = 'output_lps_num.nii.gz' + out_bvec = 'output_lps_num.bvec' + ret = script_runner.run(['scil_volume_modify_voxel_order', in_file, + out_file, '--new_voxel_order=-1,-2,3', + '--in_bvec', in_bvec, '--out_bvec', out_bvec, '-f']) + assert ret.success + + # 4. Verify image + lps_img = nib.load(out_file) + assert nib.aff2axcodes(lps_img.affine)[:3] == ('L', 'P', 'S') + + # 5. Verify gradients + assert os.path.exists(out_bvec) + saved_bvecs = np.loadtxt(out_bvec).T + expected_bvecs = np.array([[0, 0, 0], [-1, 0, 0]]) + assert np.allclose(saved_bvecs, expected_bvecs) diff --git a/src/scilpy/cli/tests/test_volume_apply_transform.py b/src/scilpy/cli/tests/test_volume_apply_transform.py index d133f2cbf..83e56a6d0 100644 --- a/src/scilpy/cli/tests/test_volume_apply_transform.py +++ b/src/scilpy/cli/tests/test_volume_apply_transform.py @@ -60,3 +60,25 @@ def test_execution_interp_lin(script_runner, monkeypatch): 'template_lin.nii.gz', '--inverse', '--interp', 'linear', '-f']) assert ret.success + + +def test_execution_and_header_compatibility(script_runner, monkeypatch): + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + in_model = os.path.join(SCILPY_HOME, 'bst', 'template', + 'template0.nii.gz') + in_fa = os.path.join(SCILPY_HOME, 'bst', + 'fa.nii.gz') + in_aff = os.path.join(SCILPY_HOME, 'bst', + 'output0GenericAffine.mat') + out_filename = 'template_lin_header_test.nii.gz' + + # Run the transformation + ret = script_runner.run(['scil_volume_apply_transform', + in_model, in_fa, in_aff, + out_filename, '--inverse', '-f']) + assert ret.success + + # Check for header compatibility between the output and the reference + ret = script_runner.run(['scil_header_validate_compatibility', + out_filename, in_fa]) + assert ret.success, "Headers are not compatible!" 
diff --git a/src/scilpy/image/tests/test_volume_operations.py b/src/scilpy/image/tests/test_volume_operations.py index 54b789aa1..8baa6123a 100644 --- a/src/scilpy/image/tests/test_volume_operations.py +++ b/src/scilpy/image/tests/test_volume_operations.py @@ -180,7 +180,7 @@ def test_resample_volume(): # Ref: 2x2x2, voxel size 3x3x3 ref3d = np.ones((2, 2, 2)) - ref_affine = np.eye(4)*3 + ref_affine = np.eye(4) * 3 ref_affine[-1, -1] = 1 # 1) Option volume_shape: expecting an output of 2x2x2, which means @@ -217,7 +217,7 @@ def test_resample_volume(): def test_reshape_volume_pad(): # 3D img simg = StatefulImage( - np.arange(1, (3**3)+1).reshape((3, 3, 3)).astype(float), + np.arange(1, (3**3) + 1).reshape((3, 3, 3)).astype(float), np.eye(4)) # 1) Reshaping to 4x4x4, padding with 0 @@ -237,7 +237,7 @@ def test_reshape_volume_pad(): # 4D img (2 "stacked" 3D volumes) simg = StatefulImage( - np.arange(1, ((3**3) * 2)+1).reshape((3, 3, 3, 2)).astype(float), + np.arange(1, ((3**3) * 2) + 1).reshape((3, 3, 3, 2)).astype(float), np.eye(4)) # 2) Reshaping to 5x5x5, padding with 0 @@ -248,7 +248,7 @@ def test_reshape_volume_pad(): def test_reshape_volume_crop(): # 3D img simg = StatefulImage( - np.arange(1, (3**3)+1).reshape((3, 3, 3)).astype(float), + np.arange(1, (3**3) + 1).reshape((3, 3, 3)).astype(float), np.eye(4)) # 1) Cropping to 1x1x1 @@ -265,7 +265,7 @@ def test_reshape_volume_crop(): # 4D img simg = StatefulImage( - np.arange(1, ((3**3) * 2)+1).reshape((3, 3, 3, 2)).astype(float), + np.arange(1, ((3**3) * 2) + 1).reshape((3, 3, 3, 2)).astype(float), np.eye(4)) # 2) Cropping to 2x2x2 @@ -278,7 +278,7 @@ def test_reshape_volume_crop(): def test_reshape_volume_dtype(): # 3D img simg = StatefulImage( - np.arange(1, (3**3)+1).reshape((3, 3, 3)).astype(np.uint16), + np.arange(1, (3**3) + 1).reshape((3, 3, 3)).astype(np.uint16), np.eye(4)) # 1) Staying in 3x3x3, same dtype diff --git a/src/scilpy/image/volume_operations.py b/src/scilpy/image/volume_operations.py index 
02c52e957..32ec829ff 100644 --- a/src/scilpy/image/volume_operations.py +++ b/src/scilpy/image/volume_operations.py @@ -188,8 +188,11 @@ def apply_transform(transfo, reference, raise ValueError('Does not support this dataset (shape, type, etc)') moved_nib_img = nib.Nifti1Image(resampled.astype(orig_type), grid2world) - return StatefulImage.create_from(moved_nib_img, - StatefulImage.convert_to_simg(reference)) + if isinstance(reference, StatefulImage): + return StatefulImage.create_from(moved_nib_img, reference) + else: + return StatefulImage.create_from( + moved_nib_img, StatefulImage.convert_to_simg(reference)) def transform_dwi(reg_obj, static, dwi, interpolation='linear'): @@ -272,8 +275,8 @@ def register_image(static, static_grid2world, moving, moving_grid2world, level_iters = [250, 100, 50, 25] if fine else [50, 25, 5] # With images too small, dipy fails with no clear warning. - if (np.any(np.asarray(moving.shape) < 8) or - np.any(np.asarray(static.shape) < 8)): + if (np.any(np.asarray(moving.shape) < 8) + or np.any(np.asarray(static.shape) < 8)): raise ValueError("Current implementation of registration was prepared " "with factors up to 8. 
Requires images with at least " "8 voxels in each direction.") @@ -397,7 +400,7 @@ def compute_snr(dwi, bval, bvec, b0_thr, mask, noise_mask=None, noise_map=None, # Add the upper half in order to delete the neck and shoulder # when inverting the mask - noise_mask[..., :noise_mask.shape[-1]//2] = 1 + noise_mask[..., :noise_mask.shape[-1] // 2] = 1 # Reverse the mask to get only noise noise_mask = (~noise_mask).astype(bool) diff --git a/src/scilpy/io/image.py b/src/scilpy/io/image.py index ad87af5db..9495bbd0a 100644 --- a/src/scilpy/io/image.py +++ b/src/scilpy/io/image.py @@ -2,11 +2,12 @@ from dipy.io.utils import is_header_compatible import logging -import nibabel as nib import numpy as np import os from scilpy.utils import is_float +from scilpy.io.stateful_image import StatefulImage + def load_img(arg): """ @@ -22,7 +23,7 @@ def load_img(arg): else: if not os.path.isfile(arg): raise ValueError('Input file {} does not exist.'.format(arg)) - img = nib.load(arg) + img = StatefulImage.load(arg) shape = img.header.get_data_shape() dtype = img.header.get_data_dtype() logging.info('Loaded {} of shape {} and data_type {}.'.format( @@ -87,14 +88,18 @@ def get_data_as_mask(mask_img, dtype=np.uint8): Data (dtype : np.uint8 or bool). """ # Verify that out data type is ok - if not (issubclass(np.dtype(dtype).type, np.uint8) or - issubclass(np.dtype(dtype).type, np.dtype(bool).type)): + if not (issubclass(np.dtype(dtype).type, np.uint8) + or issubclass(np.dtype(dtype).type, np.dtype(bool).type)): raise IOError('Output data type must be uint8 or bool. 
' 'Current data type is {}.'.format(dtype)) # Verify that loaded datatype is ok curr_type = mask_img.get_data_dtype().type - basename = os.path.basename(mask_img.get_filename()) + if hasattr(mask_img, 'get_filename') and mask_img.get_filename(): + basename = os.path.basename(mask_img.get_filename()) + else: + basename = "unnamed" + if np.issubdtype(curr_type, np.signedinteger) or \ np.issubdtype(curr_type, np.unsignedinteger) \ or np.issubdtype(curr_type, np.dtype(bool).type): diff --git a/src/scilpy/io/mti.py b/src/scilpy/io/mti.py index 7d7ec5090..3b998da75 100644 --- a/src/scilpy/io/mti.py +++ b/src/scilpy/io/mti.py @@ -228,7 +228,7 @@ def _prepare_B1_map(args, flip_angles, extended_dir, affine): """ B1_map = None if args.in_B1_map and args.in_mtoff_t1: - B1_img = nib.load(args.in_B1_map) + B1_img, _ = load_img(args.in_B1_map) B1_map = B1_img.get_fdata(dtype=np.float32) B1_map = adjust_B1_map_intensities(B1_map, nominal=args.B1_nominal) B1_map = smooth_B1_map(B1_map, wdims=args.B1_smooth_dims) diff --git a/src/scilpy/io/stateful_image.py b/src/scilpy/io/stateful_image.py index b61e0b834..dd08dfff7 100644 --- a/src/scilpy/io/stateful_image.py +++ b/src/scilpy/io/stateful_image.py @@ -1,6 +1,9 @@ # -*- coding: utf-8 -*- import nibabel as nib +import numpy as np + +from dipy.io.gradients import read_bvals_bvecs from dipy.io.utils import get_reference_info from scilpy.utils.orientation import validate_voxel_order @@ -18,7 +21,8 @@ class StatefulImage(nib.Nifti1Image): def __init__(self, dataobj, affine, header=None, extra=None, file_map=None, original_affine=None, original_dimensions=None, original_voxel_sizes=None, - original_axcodes=None): + original_axcodes=None, bvals=None, bvecs=None, + gradients_original_order=True): """ Initialize a StatefulImage object. 
@@ -32,6 +36,12 @@ def __init__(self, dataobj, affine, header=None, extra=None, self._original_voxel_sizes = original_voxel_sizes self._original_axcodes = original_axcodes + # Store gradient information + self._bvals = None + self._bvecs = None + if bvals is not None and bvecs is not None: + self.attach_gradients(bvals, bvecs, gradients_original_order) + @classmethod def load(cls, filename, to_orientation="RAS"): """ @@ -110,15 +120,65 @@ def create_from(source, reference): A new StatefulImage with the source image's data and the reference image's original orientation information. """ + bvals = None + bvecs = None + if reference.bvals is not None and reference.bvecs is not None: + if source.ndim >= 4 and len(reference.bvals) == source.shape[3]: + bvals = reference.bvals + bvecs = reference.bvecs + + # If reference orientation != source orientation, reorient bvecs + ref_axcodes = reference.axcodes + source_axcodes_3d = nib.orientations.aff2axcodes(source.affine) + + if ref_axcodes[:3] != source_axcodes_3d: + # Strip 'T' etc. 
for nibabel + ref_axcodes_3d = ref_axcodes[:3] + + # Use a temporary StatefulImage logic to reorient bvecs + start_ornt = nib.orientations.axcodes2ornt(ref_axcodes_3d) + target_ornt = nib.orientations.axcodes2ornt( + source_axcodes_3d) + transform = nib.orientations.ornt_transform( + start_ornt, target_ornt) + axis_permutation = transform[:, 0].astype(int) + axis_flips = transform[:, 1] + bvecs = bvecs[:, axis_permutation] * axis_flips + return StatefulImage(source.dataobj, source.affine, header=source.header, original_affine=reference._original_affine, original_dimensions=reference._original_dimensions, original_voxel_sizes=reference._original_voxel_sizes, - original_axcodes=reference._original_axcodes) + original_axcodes=reference._original_axcodes, + bvals=bvals, bvecs=bvecs, + gradients_original_order=False) + + @staticmethod + def from_data(data, reference): + """ + Create a new StatefulImage from a numpy array, preserving the original + orientation information from a reference StatefulImage. + + Parameters + ---------- + data : numpy.ndarray + The image data to use for the new StatefulImage. + reference : StatefulImage + The reference image from which to copy original orientation + information. + + Returns + ------- + StatefulImage + A new StatefulImage with the data and the reference + image's original orientation information. + """ + new_img = nib.Nifti1Image(data, reference.affine, reference.header) + return StatefulImage.create_from(new_img, reference) @staticmethod - def convert_to_simg(img): + def convert_to_simg(img, bvals=None, bvecs=None): """ Initialize a StatefulImage from an existing Nifti1Image. @@ -129,13 +189,139 @@ def convert_to_simg(img): ---------- img : nib.Nifti1Image The Nifti1Image to initialize from. + bvals : array-like, optional + B-values. + bvecs : array-like, optional + B-vectors. 
""" + original_axcodes = nib.orientations.aff2axcodes(img.affine) + if len(img.shape) == 4: + original_axcodes += ('T',) + return StatefulImage(img.dataobj, img.affine, header=img.header, original_affine=img.affine.copy(), original_dimensions=img.header.get_data_shape(), original_voxel_sizes=img.header.get_zooms(), - original_axcodes=nib.orientations.aff2axcodes( - img.affine)) + original_axcodes=original_axcodes, + bvals=bvals, bvecs=bvecs) + + @property + def bvals(self): + """Get the current b-values.""" + return self._bvals + + @property + def bvecs(self): + """Get the current (reoriented) b-vectors.""" + return self._bvecs + + def attach_gradients(self, bvals, bvecs, original_order=True): + """ + Attach b-values and b-vectors to the image. + + Parameters + ---------- + bvals : array-like + B-values. + bvecs : array-like + B-vectors. + original_order : bool, optional + If True, assumes b-vectors are in the original voxel order. + If False, assumes b-vectors match the current in-memory orientation. + Default is True. + """ + self._bvals = np.asanyarray(bvals) + self._bvecs = np.asanyarray(bvecs) + + # Validate shapes + if self._bvals.ndim != 1: + raise ValueError("bvals must be a 1D array.") + if self._bvecs.ndim != 2 or self._bvecs.shape[1] != 3: + raise ValueError("bvecs must be an (N, 3) array.") + if len(self._bvals) != len(self._bvecs): + raise ValueError("bvals and bvecs must have the same length.") + + # Validate against image data + if len(self._bvals) != self.shape[3]: + raise ValueError(f"Number of gradients ({len(self._bvals)}) does " + f"not match number of volumes ({self.shape[3]}).") + + # If current orientation is not original, and we assume original, reorient + if original_order and self.axcodes != self._original_axcodes: + self._reorient_gradients(self._original_axcodes, self.axcodes) + + def load_gradients(self, bval_path, bvec_path): + """ + Load b-values and b-vectors from FSL-formatted files. 
+ + Parameters + ---------- + bval_path : str + Path to the bvals file. + bvec_path : str + Path to the bvecs file. + """ + bvals, bvecs = read_bvals_bvecs(bval_path, bvec_path) + self.attach_gradients(bvals, bvecs) + + def save_gradients(self, bval_path, bvec_path): + """ + Save b-values and b-vectors to FSL-formatted files. + Ensures b-vectors match the original voxel order. + + Parameters + ---------- + bval_path : str + Path to save the bvals file. + bvec_path : str + Path to save the bvecs file. + """ + if self._bvals is None or self._bvecs is None: + raise ValueError("No gradients attached to this StatefulImage.") + + # Reorient back to original for saving + bvecs_to_save = self._bvecs + if self.axcodes != self._original_axcodes: + # Following the simg.save() convention, reorient_to_original() + # mutates the image and its attached gradients in-place. + # After this method returns, the image therefore remains in + # its original on-disk orientation; callers that need the + # previous in-memory orientation must reorient the image + # again explicitly after saving. + self.reorient_to_original() + bvecs_to_save = self._bvecs + + np.savetxt(bvec_path, bvecs_to_save.T, fmt='%.8f') + np.savetxt(bval_path, self._bvals[None, :], fmt='%.3f') + + def _reorient_gradients(self, start_axcodes, target_axcodes): + """ + Internal helper to reorient b-vectors. + + Parameters + ---------- + start_axcodes : tuple + Starting axis codes. + target_axcodes : tuple + Target axis codes. 
+ """ + if self._bvecs is None: + return + + # Strip 'T' if present + start_axcodes_3d = [c for c in start_axcodes if c != 'T'] + target_axcodes_3d = [c for c in target_axcodes if c != 'T'] + + start_ornt = nib.orientations.axcodes2ornt(start_axcodes_3d) + target_ornt = nib.orientations.axcodes2ornt(target_axcodes_3d) + transform = nib.orientations.ornt_transform(start_ornt, target_ornt) + + axis_permutation = transform[:, 0].astype(int) + axis_flips = transform[:, 1] + + # Apply permutation and flips + # bvecs is (N, 3). We permute columns and multiply by flips. + self._bvecs = self._bvecs[:, axis_permutation] * axis_flips def reorient_to_original(self): """ @@ -163,40 +349,56 @@ def reorient(self, target_axcodes): target_axcodes : str or tuple The target orientation axis codes (e.g., "LPS", ("R", "A", "S")). """ - validate_voxel_order(target_axcodes) + if target_axcodes is None: + raise ValueError("Axis codes cannot be None.") + + # Ensure target_axcodes has the same number of dimensions as self.shape + # by padding with unique placeholder codes if necessary. 
+ target_axcodes = list(target_axcodes) + if len(target_axcodes) < len(self.shape): + extra_codes = ['T', 'U', 'V', 'W', 'X', 'Y', 'Z'] + for i in range(len(target_axcodes), len(self.shape)): + target_axcodes.append(extra_codes[i-3]) + elif len(target_axcodes) > len(self.shape): + target_axcodes = target_axcodes[:len(self.shape)] + target_axcodes = tuple(target_axcodes) + + validate_voxel_order(target_axcodes, dimensions=len(self.shape)) - current_axcodes = nib.orientations.aff2axcodes(self.affine) + current_axcodes = self.axcodes if current_axcodes == tuple(target_axcodes): return - # Check unique are only valid axis codes - valid_codes = {'L', 'R', 'A', 'P', 'S', 'I'} - for code in target_axcodes: - if code not in valid_codes: - raise ValueError(f"Invalid axis code '{code}' in target.") - - # Check L/R, A/P, S/I pairs are not both present - pairs = [('L', 'R'), ('A', 'P'), ('S', 'I')] - for pair in pairs: - if pair[0] in target_axcodes and pair[1] in target_axcodes: - raise ValueError(f"Conflicting axis codes '{pair[0]}' and " - f"'{pair[1]}' in target.") - - # Check no repeated axis codes (LL, RR, etc.) - if len(set(target_axcodes)) != 3: - raise ValueError("Target axis codes must be unique.") - - start_ornt = nib.orientations.axcodes2ornt(current_axcodes) - target_ornt = nib.orientations.axcodes2ornt(target_axcodes) + # Nibabel only handles 3D orientations. If 4D, we assume the 4th + # dimension is time/gradients and doesn't need reorientation. 
+ target_axcodes_3d = [c for c in target_axcodes if c != 'T'] + current_axcodes_3d = [c for c in current_axcodes if c != 'T'] + + start_ornt = nib.orientations.axcodes2ornt(current_axcodes_3d) + target_ornt = nib.orientations.axcodes2ornt(target_axcodes_3d) transform = nib.orientations.ornt_transform(start_ornt, target_ornt) reoriented_img = self.as_reoriented(transform) + + # Reorient gradients before re-initializing + if self._bvecs is not None: + self._reorient_gradients(current_axcodes, target_axcodes) + + # Pass current reoriented gradients to __init__ self.__init__(reoriented_img.dataobj, reoriented_img.affine, reoriented_img.header, original_affine=self._original_affine, original_dimensions=self._original_dimensions, original_voxel_sizes=self._original_voxel_sizes, - original_axcodes=self._original_axcodes) + original_axcodes=self._original_axcodes, + bvals=self._bvals, bvecs=self._bvecs, + gradients_original_order=False) + + # The gradients were already reoriented above, so they match the + # target orientation; gradients_original_order=False tells __init__ + # (via attach_gradients) not to reorient them a second time from + # the original voxel order. + 
def to_ras(self): """Convenience method to reorient in-memory data to RAS.""" @@ -227,18 +429,41 @@ def to_reference(self, obj): raise TypeError('Reference object must not be a StatefulImage.') _, _, _, voxel_order = get_reference_info(obj) - self.reorient(voxel_order) + self.reorient(voxel_order[:3]) @property def axcodes(self): """Get the axis codes for the current image orientation.""" - return nib.orientations.aff2axcodes(self.affine) + codes = list(nib.orientations.aff2axcodes(self.affine)) + if len(self.shape) > 3: + extra_codes = ['T', 'U', 'V', 'W', 'X', 'Y', 'Z'] + for i in range(3, len(self.shape)): + codes.append(extra_codes[i-3]) + return tuple(codes) @property def original_axcodes(self): """Get the axis codes for the original image orientation.""" return self._original_axcodes + @property + def original_affine(self): + """Get the original image affine.""" + return self._original_affine + + @property + def original_header(self): + """Get a header matching the original image orientation.""" + # Create a copy of the current header but with original info + header = self.header.copy() + header.set_sform(self._original_affine) + header.set_qform(self._original_affine) + if self._original_voxel_sizes is not None: + header.set_zooms(self._original_voxel_sizes) + if self._original_dimensions is not None: + header.set_data_shape(self._original_dimensions) + return header + def __str__(self): """Return a string representation of the image, including orientation.""" base_str = super().__str__() diff --git a/src/scilpy/io/tests/test_stateful_image.py b/src/scilpy/io/tests/test_stateful_image.py index e0fb840b3..033d6e3e4 100644 --- a/src/scilpy/io/tests/test_stateful_image.py +++ b/src/scilpy/io/tests/test_stateful_image.py @@ -193,7 +193,7 @@ def test_direct_instantiation(): @pytest.mark.parametrize("codes, error_msg", [ (None, "Axis codes cannot be None."), - ("INVALID", "Target axis codes must be of length 3."), + ("INVALID", "Invalid axis code 'N' in target."), 
("RAR", "Target axis codes must be unique."), ("LRR", "Target axis codes must be unique."), ("LRA", "Conflicting axis codes 'L' and 'R' in target."), diff --git a/src/scilpy/io/tests/test_stateful_image_gradients.py b/src/scilpy/io/tests/test_stateful_image_gradients.py new file mode 100644 index 000000000..d5d1b73e4 --- /dev/null +++ b/src/scilpy/io/tests/test_stateful_image_gradients.py @@ -0,0 +1,183 @@ +# -*- coding: utf-8 -*- + +import os +import pytest +import tempfile +from contextlib import contextmanager + +import nibabel as nib +import numpy as np + +from scilpy.io.stateful_image import StatefulImage + + +@contextmanager +def create_dummy_nifti_with_gradients(filename="test.nii.gz", n_volumes=5): + """ + Create a dummy NIfTI file and gradient files for testing. + """ + with tempfile.TemporaryDirectory() as tmpdir: + shape = (10, 10, 10, n_volumes) + affine = np.eye(4) + data = np.random.rand(*shape).astype(np.float32) + img = nib.Nifti1Image(data, affine) + + file_path = os.path.join(tmpdir, filename) + nib.save(img, file_path) + + bvals = np.random.randint(0, 3000, n_volumes) + bvecs = np.random.randn(n_volumes, 3) + bvecs /= (np.linalg.norm(bvecs, axis=1)[:, None] + 1e-8) + + bval_path = os.path.join(tmpdir, "test.bval") + bvec_path = os.path.join(tmpdir, "test.bvec") + + np.savetxt(bval_path, bvals[None, :], fmt='%d') + np.savetxt(bvec_path, bvecs.T, fmt='%.8f') + + yield file_path, bval_path, bvec_path, bvals, bvecs + + +def test_attach_gradients(): + with create_dummy_nifti_with_gradients() as (img_p, bval_p, bvec_p, bvals, bvecs): + simg = StatefulImage.load(img_p) + simg.attach_gradients(bvals, bvecs) + + assert np.allclose(simg.bvals, bvals) + assert np.allclose(simg.bvecs, bvecs) + + +def test_load_gradients(): + with create_dummy_nifti_with_gradients() as (img_p, bval_p, bvec_p, bvals, bvecs): + simg = StatefulImage.load(img_p) + simg.load_gradients(bval_p, bvec_p) + + assert np.allclose(simg.bvals, bvals) + assert np.allclose(simg.bvecs, bvecs, 
atol=1e-5) + + +def test_reorient_gradients(): + with create_dummy_nifti_with_gradients() as (img_p, bval_p, bvec_p, bvals, bvecs): + simg = StatefulImage.load(img_p) + simg.attach_gradients(bvals, bvecs) + + # LPS reorientation: flip x and y + simg.to_lps() + assert simg.axcodes == ("L", "P", "S", "T") + + expected_bvecs = bvecs.copy() + expected_bvecs[:, 0] *= -1 + expected_bvecs[:, 1] *= -1 + + assert np.allclose(simg.bvecs, expected_bvecs) + + # Reorient back to RAS + simg.to_ras() + assert simg.axcodes == ("R", "A", "S", "T") + assert np.allclose(simg.bvecs, bvecs) + + +def test_save_gradients(): + with create_dummy_nifti_with_gradients() as (img_p, bval_p, bvec_p, bvals, bvecs): + simg = StatefulImage.load(img_p) + simg.attach_gradients(bvals, bvecs) + simg.to_lps() + + tmp_dir = os.path.dirname(img_p) + out_bval = os.path.join(tmp_dir, "out.bval") + out_bvec = os.path.join(tmp_dir, "out.bvec") + + simg.save_gradients(out_bval, out_bvec) + + # Saved gradients should be back in RAS (original) + saved_bvals = np.loadtxt(out_bval) + saved_bvecs = np.loadtxt(out_bvec).T + + assert np.allclose(saved_bvals, bvals) + assert np.allclose(saved_bvecs, bvecs) + + # StatefulImage itself should now be in RAS + assert simg.axcodes == ("R", "A", "S", "T") + + +def test_create_from_with_gradients(): + with create_dummy_nifti_with_gradients() as (img_p, bval_p, bvec_p, bvals, bvecs): + simg = StatefulImage.load(img_p) + simg.attach_gradients(bvals, bvecs) + simg.to_lps() + + # Create new simg from source (RAS) but with same reference (LPS) + source_nii = nib.load(img_p) + new_simg = StatefulImage.create_from(source_nii, simg) + + # new_simg matches source_nii (RAS) + assert new_simg.axcodes == ("R", "A", "S", "T") + # bvecs should have been reoriented back to RAS to match source_nii + assert np.allclose(new_simg.bvecs, bvecs) + assert np.allclose(new_simg.bvals, bvals) + + +def test_validation_errors(): + with create_dummy_nifti_with_gradients(n_volumes=5) as \ + (img_p, 
bval_p, bvec_p, bvals, bvecs): + simg = StatefulImage.load(img_p) + + # Wrong number of volumes + with pytest.raises(ValueError, + match="Number of gradients.*does not match number of volumes"): + simg.attach_gradients(bvals[:3], bvecs[:3]) + + # Wrong shape + with pytest.raises(ValueError, match="bvals must be a 1D array"): + simg.attach_gradients(bvals[:, None], bvecs) + + +def test_gradient_consistency_across_orientations(): + """ + Comprehensive test: + 1. Create RAS image + gradients. + 2. Reorient to LAS, LPS, LPI. + 3. Save in those orientations. + 4. Load back and verify they all return to the same RAS state. + """ + n_volumes = 4 + with create_dummy_nifti_with_gradients(n_volumes=n_volumes) as \ + (img_p, bval_p, bvec_p, bvals, bvecs): + simg_ras = StatefulImage.load(img_p) + simg_ras.attach_gradients(bvals, bvecs) + + # Original bvecs are in RAS (matching simg_ras.axcodes) + original_bvecs = simg_ras.bvecs.copy() + + for target_ornt in ["LAS", "LPS", "LPI"]: + with tempfile.TemporaryDirectory() as tmpdir: + # 1. Reorient + simg_ras.reorient(target_ornt) + + # 2. Create a "new" original at this orientation so we can save it AS is + # convert_to_simg sets the current state as the "original" + simg_target = StatefulImage.convert_to_simg( + simg_ras, simg_ras.bvals, simg_ras.bvecs) + + # 3. Save + target_img_p = os.path.join(tmpdir, "target.nii.gz") + target_bval_p = os.path.join(tmpdir, "target.bval") + target_bvec_p = os.path.join(tmpdir, "target.bvec") + + simg_target.save(target_img_p) + simg_target.save_gradients(target_bval_p, target_bvec_p) + + # 4. Load back (defaults to RAS) + simg_verify = StatefulImage.load( + target_img_p, to_orientation="RAS") + simg_verify.load_gradients(target_bval_p, target_bvec_p) + + # 5. 
Verify + assert simg_verify.axcodes == ("R", "A", "S", "T") + # Threshold for float precision after multiple transforms + assert np.allclose(simg_verify.bvecs, + original_bvecs, atol=1e-5) + assert np.allclose(simg_verify.bvals, bvals) + + # Go back to RAS for next iteration + simg_ras.to_ras() diff --git a/src/scilpy/reconst/mti.py b/src/scilpy/reconst/mti.py index 976c53ad7..93b43c4b4 100644 --- a/src/scilpy/reconst/mti.py +++ b/src/scilpy/reconst/mti.py @@ -5,8 +5,6 @@ import scipy.io import scipy.ndimage -from scilpy.io.image import get_data_as_mask - def py_fspecial_gauss(shape, sigma): """ @@ -151,7 +149,7 @@ def compute_ratio_map(mt_on_single, mt_off, mt_on_dual=None): return MTR -def threshold_map(computed_map, in_mask, +def threshold_map(computed_map, mask_data, lower_threshold, upper_threshold, idx_contrast_list=None, contrast_maps=None): """ @@ -167,7 +165,8 @@ def threshold_map(computed_map, in_mask, ---------- computed_map: 3D-Array data. Myelin map (ihMT or non-ihMT maps) - in_mask: Path to binary T1 mask from T1 segmentation. + mask_data: Numpy array. + Binary T1 mask from T1 segmentation. Must be the sum of GM+WM+CSF. 
lower_threshold: Value for low thresold upper_thresold: Value for up thresold @@ -188,10 +187,8 @@ def threshold_map(computed_map, in_mask, computed_map[computed_map < lower_threshold] = 0 computed_map[computed_map > upper_threshold] = 0 - # Load and apply sum of T1 probability maps on myelin maps - if in_mask is not None: - mask_image = nib.load(in_mask) - mask_data = get_data_as_mask(mask_image) + # Apply T1 mask on myelin maps + if mask_data is not None: computed_map[np.where(mask_data == 0)] = 0 # Apply threshold based on combination of specific contrast maps diff --git a/src/scilpy/tests/test_tracking_io_alignment.py b/src/scilpy/tests/test_tracking_io_alignment.py new file mode 100644 index 000000000..541a57cb6 --- /dev/null +++ b/src/scilpy/tests/test_tracking_io_alignment.py @@ -0,0 +1,129 @@ +import os +import numpy as np +import nibabel as nib +import pytest +from dipy.io.stateful_tractogram import StatefulTractogram, Space +from dipy.io.streamline import load_tractogram, save_tractogram +from scilpy.tracking.utils import save_tractogram as scil_save_tractogram + +def create_fake_header(affine, shape=(10, 10, 10)): + data = np.zeros(shape) + img = nib.Nifti1Image(data, affine) + return img + +@pytest.mark.parametrize("affine_type", ["iso_1mm", "iso_2mm", "aniso", "complex"]) +@pytest.mark.parametrize("ext", [".trk", ".tck"]) +def test_tracking_io_alignment(tmp_path, affine_type, ext): + if affine_type == "iso_1mm": + affine = np.diag([1, 1, 1, 1]) + elif affine_type == "iso_2mm": + affine = np.diag([2, 2, 2, 1]) + elif affine_type == "aniso": + affine = np.diag([1, 1, 2, 1]) + elif affine_type == "complex": + # Rotation 30 deg around Z, scaling, translation + theta = np.radians(30) + c, s = np.cos(theta), np.sin(theta) + R = np.array([ + [c, -s, 0], + [s, c, 0], + [0, 0, 1] + ]) + S = np.diag([1.1, 0.9, 1.2]) + T = np.array([10, -20, 30]) + affine = np.eye(4) + affine[:3, :3] = R @ S + affine[:3, 3] = T + + img = create_fake_header(affine) + img_path = 
str(tmp_path / "ref.nii.gz") + nib.save(img, img_path) + + # Create streamlines in VOXEL space, origin CENTER + # (0,0,0) to (5,5,5) + vox_streamlines = [np.array([ + [0, 0, 0], + [1, 1, 1], + [2, 2, 2], + [5, 5, 5] + ], dtype=float)] + + # Convert to RASMM for StatefulTractogram + # StatefulTractogram expects streamlines in RASMM if space is Space.RASMM + sft = StatefulTractogram(vox_streamlines, img, Space.VOX) + + output_path = str(tmp_path / f"tracto{ext}") + + # Method 1: Use DIPY save_tractogram (standard) + save_tractogram(sft, output_path) + + # Reload and check + sft_loaded = load_tractogram(output_path, img_path) + + # Check streamlines in VOX space + sft_loaded.to_vox() + loaded_vox = sft_loaded.streamlines + + assert len(loaded_vox) == len(vox_streamlines) + for orig, loaded in zip(vox_streamlines, loaded_vox): + assert np.allclose(orig, loaded, atol=1e-3) + + # Check streamlines in RASMM space + sft.to_rasmm() + sft_loaded.to_rasmm() + for orig, loaded in zip(sft.streamlines, sft_loaded.streamlines): + assert np.allclose(orig, loaded, atol=1e-3) + +@pytest.mark.parametrize("affine_type", ["iso_1mm", "iso_2mm", "aniso", "complex"]) +@pytest.mark.parametrize("ext", [".trk", ".tck"]) +def test_scil_save_tractogram_alignment(tmp_path, affine_type, ext): + if affine_type == "iso_1mm": + affine = np.diag([1, 1, 1, 1]) + elif affine_type == "iso_2mm": + affine = np.diag([2, 2, 2, 1]) + elif affine_type == "aniso": + affine = np.diag([1, 1, 2, 1]) + elif affine_type == "complex": + theta = np.radians(30) + c, s = np.cos(theta), np.sin(theta) + R = np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]]) + S = np.diag([1.1, 0.9, 1.2]) + T = np.array([10, -20, 30]) + affine = np.eye(4) + affine[:3, :3] = R @ S + affine[:3, 3] = T + + img = create_fake_header(affine) + img_path = str(tmp_path / "ref.nii.gz") + nib.save(img, img_path) + + # Create streamlines in VOXEL space, origin CENTER + vox_streamlines = [np.array([ + [0, 0, 0], + [1, 1, 1], + [2, 2, 2], + [5, 5, 5] + ], 
dtype=float)] + + # Generator for scil_save_tractogram + # it yields (streamline, seed) + # We make it a list so it's re-iterable if needed + stream_gen_list = [(s.copy(), s[0].copy()) for s in vox_streamlines] + + output_path = str(tmp_path / f"scil_tracto{ext}") + tracts_format = nib.streamlines.detect_format(output_path) + + # scil_save_tractogram(streamlines_generator, tracts_format, ref_img, total_nb_seeds, + # out_tractogram, min_length, max_length, compress, save_seeds, verbose) + scil_save_tractogram(stream_gen_list, tracts_format, img, len(vox_streamlines), + output_path, 0, 1000, None, False, False) + + # Reload and check + sft_loaded = load_tractogram(output_path, img_path) + sft_loaded.to_vox() + loaded_vox = sft_loaded.streamlines + + assert len(loaded_vox) == len(vox_streamlines) + for orig, loaded in zip(vox_streamlines, loaded_vox): + # Using a slightly larger tolerance because TRK/TCK might have some precision loss or 0.5 offset handling differences + assert np.allclose(orig, loaded, atol=1e-3) diff --git a/src/scilpy/tracking/tests/test_tracking_utils.py b/src/scilpy/tracking/tests/test_tracking_utils.py new file mode 100644 index 000000000..6bedecb02 --- /dev/null +++ b/src/scilpy/tracking/tests/test_tracking_utils.py @@ -0,0 +1,67 @@ +import numpy as np +import nibabel as nib +import pytest +from dipy.io.stateful_tractogram import StatefulTractogram, Space +from dipy.io.streamline import load_tractogram +from scilpy.tracking.utils import save_tractogram as scil_save_tractogram + +def create_fake_header(affine, shape=(10, 10, 10)): + data = np.zeros(shape) + img = nib.Nifti1Image(data, affine) + return img + +@pytest.mark.parametrize("affine_type", ["iso_1mm", "iso_2mm", "aniso", "complex"]) +@pytest.mark.parametrize("ext", [".trk", ".tck"]) +def test_scil_save_tractogram_alignment(tmp_path, affine_type, ext): + if affine_type == "iso_1mm": + affine = np.diag([1, 1, 1, 1]) + elif affine_type == "iso_2mm": + affine = np.diag([2, 2, 2, 1]) + elif 
affine_type == "aniso": + affine = np.diag([1, 1, 2, 1]) + elif affine_type == "complex": + # Rotation 30 deg around Z, scaling, translation + theta = np.radians(30) + c, s = np.cos(theta), np.sin(theta) + R = np.array([ + [c, -s, 0], + [s, c, 0], + [0, 0, 1] + ]) + S = np.diag([1.1, 0.9, 1.2]) + T = np.array([10, -20, 30]) + affine = np.eye(4) + affine[:3, :3] = R @ S + affine[:3, 3] = T + + img = create_fake_header(affine) + img_path = str(tmp_path / "ref.nii.gz") + nib.save(img, img_path) + + # Create streamlines in VOXEL space, origin CENTER + vox_streamlines = [np.array([ + [1, 1, 1], + [2, 2, 2], + [5, 5, 5] + ], dtype=float)] + + # Generator for scil_save_tractogram + # it yields (streamline, seed) + stream_gen_list = [(s.copy(), s[0].copy()) for s in vox_streamlines] + + output_path = str(tmp_path / f"scil_tracto{ext}") + tracts_format = nib.streamlines.detect_format(output_path) + + # scil_save_tractogram(streamlines_generator, tracts_format, ref_img, total_nb_seeds, + # out_tractogram, min_length, max_length, compress, save_seeds, verbose) + scil_save_tractogram(stream_gen_list, tracts_format, img, len(vox_streamlines), + output_path, 0, 1000, None, False, False) + + # Reload and check + sft_loaded = load_tractogram(output_path, img_path) + sft_loaded.to_vox() + loaded_vox = sft_loaded.streamlines + + assert len(loaded_vox) == len(vox_streamlines) + for orig, loaded in zip(vox_streamlines, loaded_vox): + assert np.allclose(orig, loaded, atol=1e-3) diff --git a/src/scilpy/tracking/utils.py b/src/scilpy/tracking/utils.py index 551d80959..79dee6305 100644 --- a/src/scilpy/tracking/utils.py +++ b/src/scilpy/tracking/utils.py @@ -192,7 +192,7 @@ def save_tractogram( Streamlines generator. tracts_format : TrkFile or TckFile Tractogram format. - ref_img : nibabel.Nifti1Image + ref_img : nibabel.Nifti1Image or scilpy.io.stateful_image.StatefulImage Image used as reference. total_nb_seeds : int Total number of seeds. 
@@ -211,11 +211,16 @@ def save_tractogram( If True, display progression bar. """ + from scilpy.io.stateful_image import StatefulImage - voxel_size = ref_img.header.get_zooms()[0] + # If ref_img is a StatefulImage, we want to save relative to its + # original on-disk orientation, not the internal (likely RAS) one. + is_stateful = isinstance(ref_img, StatefulImage) + if is_stateful: + original_axcodes = ref_img.axcodes + ref_img.reorient_to_original() - scaled_min_length = min_length / voxel_size - scaled_max_length = max_length / voxel_size + voxel_size = np.array(ref_img.header.get_zooms()[:3]) # Tracking is expected to be returned in voxel space, origin `center`. def tracks_generator_wrapper(): @@ -224,7 +229,11 @@ def tracks_generator_wrapper(): total=total_nb_seeds, miniters=int(total_nb_seeds / 100), leave=False): - if (scaled_min_length <= length(strl) <= scaled_max_length): + # Compute length in mm space for filtering + # length() is euclidean distance, so we must be in mm + strl_mm = strl * voxel_size + strl_len = length(strl_mm) + if (min_length <= strl_len <= max_length): # Seeds are saved with origin `center` by our own convention. # Other scripts (e.g. scil_tractogram_seed_density_map) expect # so. @@ -233,29 +242,28 @@ def tracks_generator_wrapper(): dps['seeds'] = seed if compress: - # compression threshold is given in mm, but we - # are in voxel space - strl = compress_streamlines( - strl, compress / voxel_size) - - # TODO: Use nibabel utilities for dealing with spaces + # compression threshold is given in mm, so we + # must be in mm space to compress + strl_mm = compress_streamlines(strl_mm, compress) + if tracts_format is TrkFile: - # Streamlines are dumped in mm space with - # origin `corner`. This is what is expected by - # LazyTractogram for .trk files (although this is not - # specified anywhere in the doc) - strl += 0.5 - strl *= voxel_size # in mm. + # Streamlines are dumped in mm space with origin `corner`. + # (TrackVis space). 
+ # Note: We use the already computed strl_mm (center origin) + # and shift it by 0.5 * voxel_size to get corner origin. + strl_to_save = strl_mm + 0.5 * voxel_size else: # Streamlines are dumped in true world space with # origin center as expected by .tck files. - strl = np.dot(strl, ref_img.affine[:3, :3]) + \ - ref_img.affine[:3, 3] + strl_to_save = nib.affines.apply_affine(ref_img.affine, strl) - yield TractogramItem(strl, dps, {}) + yield TractogramItem(strl_to_save, dps, {}) tractogram = LazyTractogram.from_data_func(tracks_generator_wrapper) - tractogram.affine_to_rasmm = ref_img.affine + # Since the generator yields coordinates already in their final format-space + # (TrackVis for .trk, RASMM for .tck), we set the affine_to_rasmm to identity + # to prevent nibabel from applying any further transformation. + tractogram.affine_to_rasmm = np.eye(4) filetype = nib.streamlines.detect_format(out_tractogram) reference = get_reference_info(ref_img) @@ -264,8 +272,12 @@ def tracks_generator_wrapper(): # Use generator to save the streamlines on-the-fly nib.streamlines.save(tractogram, out_tractogram, header=header) + # Revert ref_img to its previous orientation + if is_stateful: + ref_img.reorient(original_axcodes) + -def get_direction_getter(in_img, algo, sphere, sub_sphere, theta, sh_basis, +def get_direction_getter(img_data, algo, sphere, sub_sphere, theta, sh_basis, voxel_size, sf_threshold, sh_to_pmf, probe_length, probe_radius, probe_quality, probe_count, support_exponent, is_legacy=True): @@ -273,8 +285,8 @@ def get_direction_getter(in_img, algo, sphere, sub_sphere, theta, sh_basis, Parameters ---------- - in_img: str - Path to the input odf file. + img_data: ndarray + The input odf data. algo: str Algorithm to use for tracking. Can be 'det', 'prob', 'ptt' or 'eudx'. sphere: str @@ -319,8 +331,6 @@ def get_direction_getter(in_img, algo, sphere, sub_sphere, theta, sh_basis, dg: dipy.direction.DirectionGetter The direction getter object. 
""" - img_data = nib.load(in_img).get_fdata(dtype=np.float32) - sphere = HemiSphere.from_sphere( get_sphere(name=sphere)).subdivide(n=sub_sphere) diff --git a/src/scilpy/utils/orientation.py b/src/scilpy/utils/orientation.py index 1c2b70365..0b4837ea1 100644 --- a/src/scilpy/utils/orientation.py +++ b/src/scilpy/utils/orientation.py @@ -6,16 +6,19 @@ def validate_voxel_order(axcodes, dimensions=3): """ Validate a set of axis codes. + Parameters ---------- axcodes : str or tuple or list The axis codes to validate (e.g., "LPS", ("R", "A", "S")). dimensions : int The number of dimensions of the image. + Returns ------- tuple A tuple of validated axis codes. + Raises ------ ValueError @@ -26,12 +29,13 @@ def validate_voxel_order(axcodes, dimensions=3): axcodes = tuple(axcodes) if len(axcodes) != dimensions: - raise ValueError(f"Target axis codes must be of length {dimensions}.") + raise ValueError(f"Target axis codes must be of length {dimensions}. " + f"Got {len(axcodes)}.") # Check unique are only valid axis codes valid_codes = {"L", "R", "A", "P", "S", "I"} - if dimensions == 4: - valid_codes.add("T") + if dimensions >= 4: + valid_codes.update(["T", "U", "V", "W", "X", "Y", "Z"]) for code in axcodes: if code not in valid_codes: raise ValueError(f"Invalid axis code '{code}' in target.") @@ -53,16 +57,18 @@ def parse_voxel_order(order_str, dimensions=3): """ Parse the voxel order string into a tuple of axis codes. """ - order_str_cleaned = order_str.replace(',', '').replace(' ', '') + order_str_cleaned = order_str.replace(',', '').replace(' ', '').upper() - if dimensions == 4 and order_str_cleaned.isalpha(): - raise ValueError("Alphabetical voxel order is not supported for 4D " - "images. 
Please use numeric format.") + if dimensions == 4 and order_str_cleaned.isalpha() and \ + len(order_str_cleaned) == 3: + order_str_cleaned += 'T' if order_str_cleaned.isalpha(): - if len(order_str_cleaned) != 3: - raise ValueError("Voxel order string must have 3 characters.") - return validate_voxel_order(tuple(order_str_cleaned.upper())) + if len(order_str_cleaned) != dimensions: + raise ValueError(f"Voxel order string must have {dimensions} " + f"characters.") + return validate_voxel_order(tuple(order_str_cleaned), + dimensions=dimensions) if order_str_cleaned.replace('-', '').isdigit(): numeric_parts = re.findall(r'-?\d', order_str_cleaned) @@ -89,16 +95,19 @@ def parse_voxel_order(order_str, dimensions=3): axis = flip_map[axis] order.append(axis) + if dimensions == 4 and len(order) == 3: + order.append('T') + # Check for duplicate axes - if len(set(order)) != len(numeric_parts): + if len(set(order)) != len(order): # Handle swapped axes from numeric input (e.g., '231') axis_vals = [ras_map[abs(int(p))] for p in numeric_parts] if len(set(axis_vals)) == len(numeric_parts): - return validate_voxel_order(tuple(order), dimensions=len(numeric_parts)) + return validate_voxel_order(tuple(order), dimensions=dimensions) else: raise ValueError("Invalid numeric voxel order. 
" "Axes cannot be repeated.") - return validate_voxel_order(tuple(order), dimensions=len(numeric_parts)) - + return validate_voxel_order(tuple(order), dimensions=dimensions) + raise ValueError(f"Invalid voxel order format: {order_str}") diff --git a/src/scilpy/utils/scilpy_bot.py b/src/scilpy/utils/scilpy_bot.py index b3792776e..cce8b04ad 100644 --- a/src/scilpy/utils/scilpy_bot.py +++ b/src/scilpy/utils/scilpy_bot.py @@ -57,7 +57,7 @@ def _make_title(text): Returns a formatted title string with centered text and spacing """ return f'{Fore.LIGHTBLUE_EX}{Style.BRIGHT}{text.center(SPACING_LEN, "=")}' \ - f'{Style.RESET_ALL}' + f'{Style.RESET_ALL}' def _get_docstring_from_script_path(script): @@ -273,7 +273,7 @@ def _highlight_keywords(text, all_expressions): # Function to apply highlighting to the matched word def apply_highlight(match): return f'{Fore.LIGHTYELLOW_EX}{Style.BRIGHT}{match.group(0)}' \ - f'{Style.RESET_ALL}' + f'{Style.RESET_ALL}' # Replace the matched word with its highlighted version text = pattern.sub(apply_highlight, text) diff --git a/src/scilpy/utils/tests/test_orientation.py b/src/scilpy/utils/tests/test_orientation.py index 815794550..b8ec3ae01 100644 --- a/src/scilpy/utils/tests/test_orientation.py +++ b/src/scilpy/utils/tests/test_orientation.py @@ -84,19 +84,17 @@ def test_parse_voxel_order_invalid_format(): match="Voxel order string must have 3 or 4 numbers."): parse_voxel_order("1,2,3,4,5", dimensions=4) + def test_parse_voxel_order_4d_valid_numeric(): """Test parsing of valid 4D numeric voxel order strings.""" assert parse_voxel_order("1,2,3,4", dimensions=4) == ("R", "A", "S", "T") assert parse_voxel_order("-1,2,-3,4", dimensions=4) == ("L", "A", "I", "T") - assert parse_voxel_order("2,3,1", dimensions=4) == ("A", "S", "R") + assert parse_voxel_order("2,3,1", dimensions=4) == ("A", "S", "R", "T") -def test_parse_voxel_order_4d_invalid_alpha(): - """Test that 4D alphabetical voxel order strings raise an error.""" - with 
pytest.raises(ValueError, - match="Alphabetical voxel order is not supported for 4D " - "images. Please use numeric format."): - parse_voxel_order("RAS", dimensions=4) +def test_parse_voxel_order_4d_alpha(): + """Test that 4D alphabetical voxel order strings are now supported.""" + assert parse_voxel_order("RAS", dimensions=4) == ("R", "A", "S", "T") def test_parse_voxel_order_4d_invalid_numeric():