2323)
2424
2525
26- class DisplacementsFieldTransform (TransformBase ):
27- """Represents a dense field of displacements (one vector per voxel) ."""
26+ class DenseFieldTransform (TransformBase ):
27+ """Represents dense field ( voxel-wise) transforms ."""
2828
29- __slots__ = [ "_field" ]
29+ __slots__ = ( "_field" , "_deltas" )
3030
31- def __init__ (self , field , reference = None ):
31+ def __init__ (self , field = None , is_deltas = True , reference = None ):
3232 """
33- Create a dense deformation field transform.
33+ Create a dense field transform.
34+
35+ Converting to a field of deformations is straightforward by just adding the corresponding
36+ displacement to the :math:`(x, y, z)` coordinates of each voxel.
37+ Numerically, deformation fields are less susceptible to rounding errors
38+ than displacements fields.
39+ SPM generally prefers deformations for that reason.
40+
41+ Parameters
42+ ----------
43+ field : :obj:`numpy.array_like` or :obj:`nibabel.SpatialImage`
44+ The field of deformations or displacements (*deltas*). If given as a data array,
45+ then the reference **must** be given.
46+ is_deltas : :obj:`bool`
47+ Whether this is a displacements (deltas) field (default), or deformations.
48+ reference : :obj:`ImageGrid`
49+ Defines the domain of the transform. If not provided, the domain is defined from
50+ the ``field`` input.
3451
3552 Example
3653 -------
37- >>> DisplacementsFieldTransform (test_dir / "someones_displacement_field.nii.gz")
38- <DisplacementFieldTransform [3D] (57, 67, 56)>
54+ >>> DenseFieldTransform (test_dir / "someones_displacement_field.nii.gz")
55+ <DenseFieldTransform [3D] (57, 67, 56)>
3956
4057 """
58+ if field is None and reference is None :
59+ raise TransformError ("DenseFieldTransforms require a spatial reference" )
60+
4161 super ().__init__ ()
4262
43- field = _ensure_image (field )
44- self ._field = np .squeeze (
45- np .asanyarray (field .dataobj ) if hasattr (field , "dataobj" ) else field
46- )
63+ if field is not None :
64+ field = _ensure_image (field )
65+ self ._field = np .squeeze (
66+ np .asanyarray (field .dataobj ) if hasattr (field , "dataobj" ) else field
67+ )
68+ else :
69+ self ._field = np .zeros ((* reference .shape , reference .ndim ), dtype = "float32" )
70+ is_deltas = True
4771
4872 try :
4973 self .reference = ImageGrid (
@@ -59,45 +83,61 @@ def __init__(self, field, reference=None):
5983 ndim = self ._field .ndim - 1
6084 if self ._field .shape [- 1 ] != ndim :
6185 raise TransformError (
62- "The number of components of the displacements (%d) does not "
86+ "The number of components of the field (%d) does not match "
6387 "the number of dimensions (%d)" % (self ._field .shape [- 1 ], ndim )
6488 )
6589
90+ if is_deltas :
91+ self ._deltas = self ._field
92+ # Convert from displacements (deltas) to deformations fields
93+ # (just add its origin to each delta vector)
94+ self ._field += self .reference .ndcoords .T .reshape (self ._field .shape )
95+
6696 def __repr__ (self ):
6797 """Beautify the python representation."""
68- return f"<DisplacementFieldTransform [{ self ._field .shape [- 1 ]} D] { self ._field .shape [:3 ]} >"
98+ return f"<{ self . __class__ . __name__ } [{ self ._field .shape [- 1 ]} D] { self ._field .shape [:3 ]} >"
6999
70100 def map (self , x , inverse = False ):
71101 r"""
72102 Apply the transformation to a list of physical coordinate points.
73103
74104 .. math::
75- \mathbf{y} = \mathbf{x} + D (\mathbf{x}),
105+ \mathbf{y} = \mathbf{x} + \Delta (\mathbf{x}),
76106 \label{eq:2}\tag{2}
77107
78- where :math:`D (\mathbf{x})` is the value of the discrete field of displacements
79- :math:`D ` interpolated at the location :math:`\mathbf{x}`.
108+ where :math:`\Delta (\mathbf{x})` is the value of the discrete field of displacements
109+ :math:`\Delta ` interpolated at the location :math:`\mathbf{x}`.
80110
81111 Parameters
82112 ----------
83- x : N x D numpy.ndarray
113+ x : N x D :obj:` numpy.array_like`
84114 Input RAS+ coordinates (i.e., physical coordinates).
85- inverse : bool
115+ inverse : :obj:` bool`
86116 If ``True``, apply the inverse transform :math:`x = f^{-1}(y)`.
87117
88118 Returns
89119 -------
90- y : N x D numpy.ndarray
120+ y : N x D :obj:` numpy.array_like`
91121 Transformed (mapped) RAS+ coordinates (i.e., physical coordinates).
92122
93123 Examples
94124 --------
95- >>> xfm = DisplacementsFieldTransform(test_dir / "someones_displacement_field.nii.gz")
125+ >>> xfm = DenseFieldTransform(
126+ ... test_dir / "someones_displacement_field.nii.gz",
127+ ... is_deltas=False,
128+ ... )
96129 >>> xfm.map([-6.5, -36., -19.5]).tolist()
97- [[-6.5 , -36.475167989730835, -19.5 ]]
130+ [[0.0 , -0.47516798973083496, 0.0 ]]
98131
99132 >>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
100- [[-6.5, -36.475167989730835, -19.5], [-1.0, -42.038356602191925, -11.25]]
133+ [[0.0, -0.47516798973083496, 0.0], [0.0, -0.538356602191925, 0.0]]
134+
135+ >>> xfm = DenseFieldTransform(
136+ ... test_dir / "someones_displacement_field.nii.gz",
137+ ... is_deltas=True,
138+ ... )
139+ >>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
140+ [[-6.5, -36.47516632080078, -19.5], [-1.0, -42.03835678100586, -11.25]]
101141
102142 """
103143
@@ -106,9 +146,51 @@ def map(self, x, inverse=False):
106146 ijk = self .reference .index (x )
107147 indexes = np .round (ijk ).astype ("int" )
108148 if np .any (np .abs (ijk - indexes ) > 0.05 ):
109- warnings .warn ("Some coordinates are off-grid of the displacements field." )
149+ warnings .warn ("Some coordinates are off-grid of the field." )
110150 indexes = tuple (tuple (i ) for i in indexes .T )
111- return x + self ._field [indexes ]
151+ return self ._field [indexes ]
152+
153+ def __matmul__ (self , b ):
154+ """
155+ Compose with a transform on the right.
156+
157+ Examples
158+ --------
159+ >>> deff = DenseFieldTransform(
160+ ... test_dir / "someones_displacement_field.nii.gz",
161+ ... is_deltas=False,
162+ ... )
163+ >>> deff2 = deff @ TransformBase()
164+ >>> deff == deff2
165+ True
166+
167+ >>> disp = DenseFieldTransform(test_dir / "someones_displacement_field.nii.gz")
168+ >>> disp2 = disp @ TransformBase()
169+ >>> disp == disp2
170+ True
171+
172+ """
173+ retval = b .map (
174+ self ._field .reshape ((- 1 , self ._field .shape [- 1 ]))
175+ ).reshape (self ._field .shape )
176+ return DenseFieldTransform (retval , is_deltas = False , reference = self .reference )
177+
178+ def __eq__ (self , other ):
179+ """
180+ Overload equals operator.
181+
182+ Examples
183+ --------
184+ >>> xfm1 = DenseFieldTransform(test_dir / "someones_displacement_field.nii.gz")
185+ >>> xfm2 = DenseFieldTransform(test_dir / "someones_displacement_field.nii.gz")
186+ >>> xfm1 == xfm2
187+ True
188+
189+ """
190+ _eq = np .array_equal (self ._field , other ._field )
191+ if _eq and self ._reference != other ._reference :
192+ warnings .warn ("Fields are equal, but references do not match." )
193+ return _eq
112194
113195 @classmethod
114196 def from_filename (cls , filename , fmt = "X5" ):
@@ -123,7 +205,7 @@ def from_filename(cls, filename, fmt="X5"):
123205 return cls (_factory [fmt ].from_filename (filename ))
124206
125207
126- load = DisplacementsFieldTransform .from_filename
208+ load = DenseFieldTransform .from_filename
127209
128210
129211class BSplineFieldTransform (TransformBase ):
@@ -169,8 +251,9 @@ def to_field(self, reference=None, dtype="float32"):
169251 # 1 x Nvox : (1 x K) @ (K x Nvox)
170252 field [:, d ] = self ._coeffs [..., d ].reshape (- 1 ) @ self ._weights
171253
172- return DisplacementsFieldTransform (
173- field .astype (dtype ).reshape (* _ref .shape , - 1 ), reference = _ref )
254+ return DenseFieldTransform (
255+ field .astype (dtype ).reshape (* _ref .shape , - 1 ), reference = _ref
256+ )
174257
175258 def apply (
176259 self ,
0 commit comments