|
14 | 14 |
|
15 | 15 | from nibabel.loadsave import load as _nbload |
16 | 16 |
|
17 | | -from .base import ( |
| 17 | +from nitransforms.base import ( |
18 | 18 | ImageGrid, |
19 | 19 | TransformBase, |
20 | 20 | SpatialReference, |
21 | 21 | _as_homogeneous, |
22 | 22 | EQUALITY_TOL, |
23 | 23 | ) |
24 | | -from . import io |
| 24 | +from nitransforms.io import get_linear_factory, TransformFileError |
25 | 25 |
|
26 | 26 |
|
27 | 27 | class Affine(TransformBase): |
@@ -183,51 +183,40 @@ def _to_hdf5(self, x5_root): |
183 | 183 | self.reference._to_hdf5(x5_root.create_group("Reference")) |
184 | 184 |
|
185 | 185 | def to_filename(self, filename, fmt="X5", moving=None): |
186 | | - """Store the transform in BIDS-Transforms HDF5 file format (.x5).""" |
187 | | - if fmt.lower() in ["itk", "ants", "elastix"]: |
188 | | - itkobj = io.itk.ITKLinearTransform.from_ras(self.matrix) |
189 | | - itkobj.to_filename(filename) |
190 | | - return filename |
191 | | - |
192 | | - # Rest of the formats peek into moving and reference image grids |
193 | | - moving = ImageGrid(moving) if moving is not None else self.reference |
194 | | - |
195 | | - _factory = { |
196 | | - "afni": io.afni.AFNILinearTransform, |
197 | | - "fsl": io.fsl.FSLLinearTransform, |
198 | | - "lta": io.lta.FSLinearTransform, |
199 | | - "fs": io.lta.FSLinearTransform, |
200 | | - } |
201 | | - |
202 | | - if fmt not in _factory: |
203 | | - raise NotImplementedError(f"Unsupported format <{fmt}>") |
204 | | - |
205 | | - _factory[fmt].from_ras( |
206 | | - self.matrix, moving=moving, reference=self.reference |
207 | | - ).to_filename(filename) |
208 | | - return filename |
| 186 | + """Store the transform in the requested output format.""" |
| 187 | + writer = get_linear_factory(fmt, is_array=False) |
209 | 188 |
|
210 | | - @classmethod |
211 | | - def from_filename(cls, filename, fmt="X5", reference=None, moving=None): |
212 | | - """Create an affine from a transform file.""" |
213 | 189 | if fmt.lower() in ("itk", "ants", "elastix"): |
214 | | - _factory = io.itk.ITKLinearTransformArray |
215 | | - elif fmt.lower() in ("lta", "fs"): |
216 | | - _factory = io.lta.FSLinearTransformArray |
217 | | - elif fmt.lower() == "fsl": |
218 | | - _factory = io.fsl.FSLLinearTransformArray |
219 | | - elif fmt.lower() == "afni": |
220 | | - _factory = io.afni.AFNILinearTransformArray |
| 190 | + writer.from_ras(self.matrix).to_filename(filename) |
221 | 191 | else: |
222 | | - raise NotImplementedError |
| 192 | + # Rest of the formats peek into moving and reference image grids |
| 193 | + writer.from_ras( |
| 194 | + self.matrix, |
| 195 | + reference=self.reference, |
| 196 | + moving=ImageGrid(moving) if moving is not None else self.reference, |
| 197 | + ).to_filename(filename) |
| 198 | + return filename |
223 | 199 |
|
224 | | - struct = _factory.from_filename(filename) |
225 | | - matrix = struct.to_ras(reference=reference, moving=moving) |
226 | | - if cls == Affine: |
227 | | - if np.shape(matrix)[0] != 1: |
228 | | - raise TypeError("Cannot load transform array '%s'" % filename) |
229 | | - matrix = matrix[0] |
230 | | - return cls(matrix, reference=reference) |
| 200 | + @classmethod |
| 201 | + def from_filename(cls, filename, fmt=None, reference=None, moving=None): |
| 202 | + """Create an affine from a transform file.""" |
| 203 | + fmtlist = [fmt] if fmt is not None else ("itk", "lta", "afni", "fsl") |
| 204 | + |
| 205 | + for potential_fmt in fmtlist: |
| 206 | + try: |
| 207 | + struct = get_linear_factory(potential_fmt).from_filename(filename) |
| 208 | + matrix = struct.to_ras(reference=reference, moving=moving) |
| 209 | + if cls == Affine: |
| 210 | + if np.shape(matrix)[0] != 1: |
| 211 | + raise TypeError("Cannot load transform array '%s'" % filename) |
| 212 | + matrix = matrix[0] |
| 213 | + return cls(matrix, reference=reference) |
| 214 | + except (TransformFileError, FileNotFoundError): |
| 215 | + continue |
| 216 | + |
| 217 | + raise TransformFileError( |
| 218 | + f"Could not open <{filename}> (formats tried: {', '.join(fmtlist)})." |
| 219 | + ) |
231 | 220 |
|
232 | 221 | def __repr__(self): |
233 | 222 | """ |
@@ -353,31 +342,18 @@ def map(self, x, inverse=False): |
353 | 342 | return np.swapaxes(affine.dot(coords), 1, 2) |
354 | 343 |
|
355 | 344 | def to_filename(self, filename, fmt="X5", moving=None): |
356 | | - """Store the transform in BIDS-Transforms HDF5 file format (.x5).""" |
357 | | - if fmt.lower() in ("itk", "ants", "elastix"): |
358 | | - itkobj = io.itk.ITKLinearTransformArray.from_ras(self.matrix) |
359 | | - itkobj.to_filename(filename) |
360 | | - return filename |
| 345 | + """Store the transform in the requested output format.""" |
| 346 | + writer = get_linear_factory(fmt, is_array=True) |
361 | 347 |
|
362 | | - # Rest of the formats peek into moving and reference image grids |
363 | | - if moving is not None: |
364 | | - moving = ImageGrid(moving) |
| 348 | + if fmt.lower() in ("itk", "ants", "elastix"): |
| 349 | + writer.from_ras(self.matrix).to_filename(filename) |
365 | 350 | else: |
366 | | - moving = self.reference |
367 | | - |
368 | | - _factory = { |
369 | | - "afni": io.afni.AFNILinearTransformArray, |
370 | | - "fsl": io.fsl.FSLLinearTransformArray, |
371 | | - "lta": io.lta.FSLinearTransformArray, |
372 | | - "fs": io.lta.FSLinearTransformArray, |
373 | | - } |
374 | | - |
375 | | - if fmt not in _factory: |
376 | | - raise NotImplementedError(f"Unsupported format <{fmt}>") |
377 | | - |
378 | | - _factory[fmt].from_ras( |
379 | | - self.matrix, moving=moving, reference=self.reference |
380 | | - ).to_filename(filename) |
| 351 | + # Rest of the formats peek into moving and reference image grids |
| 352 | + writer.from_ras( |
| 353 | + self.matrix, |
| 354 | + reference=self.reference, |
| 355 | + moving=ImageGrid(moving) if moving is not None else self.reference, |
| 356 | + ).to_filename(filename) |
381 | 357 | return filename |
382 | 358 |
|
383 | 359 | def apply( |
@@ -486,17 +462,17 @@ def apply( |
486 | 462 | return resampled |
487 | 463 |
|
488 | 464 |
|
489 | | -def load(filename, fmt="X5", reference=None, moving=None): |
| 465 | +def load(filename, fmt=None, reference=None, moving=None): |
490 | 466 | """ |
491 | 467 | Load a linear transform file. |
492 | 468 |
|
493 | 469 | Examples |
494 | 470 | -------- |
495 | | - >>> xfm = load(regress_dir / "affine-LAS.itk.tfm", fmt="itk") |
| 471 | + >>> xfm = load(regress_dir / "affine-LAS.itk.tfm") |
496 | 472 | >>> isinstance(xfm, Affine) |
497 | 473 | True |
498 | 474 |
|
499 | | - >>> xfm = load(regress_dir / "itktflist.tfm", fmt="itk") |
| 475 | + >>> xfm = load(regress_dir / "itktflist.tfm") |
500 | 476 | >>> isinstance(xfm, LinearTransformsMapping) |
501 | 477 | True |
502 | 478 |
|
|
0 commit comments