Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion monai/data/image_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
]

SUPPORTED_WRITERS: dict = {}
WRITER_DEPENDENCY_HINTS: dict[type, tuple[str, str]] = {}


def register_writer(ext_name, *im_writers):
Expand Down Expand Up @@ -106,17 +107,34 @@ def resolve_writer(ext_name, error_if_not_found=True) -> Sequence:
if fmt.startswith("."):
fmt = fmt[1:]
avail_writers = []
dependency_hints: set[tuple[str, str]] = set()
default_writers = SUPPORTED_WRITERS.get(EXT_WILDCARD, ())
for _writer in look_up_option(fmt, SUPPORTED_WRITERS, default=default_writers):
try:
_writer() # this triggers `monai.utils.module.require_pkg` to check the system availability
avail_writers.append(_writer)
except OptionalImportError:
hint = WRITER_DEPENDENCY_HINTS.get(_writer)
if hint:
dependency_hints.add(hint)
continue
except Exception: # other writer init errors indicating it exists
avail_writers.append(_writer)
if not avail_writers and error_if_not_found:
raise OptionalImportError(f"No ImageWriter backend found for {fmt}.")
hint_msg = ""
if dependency_hints:
sorted_hints = sorted(dependency_hints, key=lambda item: item[0].lower())
if len(sorted_hints) == 1:
pkg, cmd = sorted_hints[0]
hint_msg = f" Install `{pkg}` (e.g. `{cmd}`) to enable writing {fmt} images."
else:
pkg_names = ", ".join(f"`{pkg}`" for pkg, _ in sorted_hints)
commands = ", ".join(f"`{cmd}`" for _, cmd in sorted_hints)
hint_msg = (
f" Install one of the supported dependencies {pkg_names} "
f"(for example: {commands}) to enable writing {fmt} images."
)
raise OptionalImportError(f"No ImageWriter backend found for {fmt}.{hint_msg}")
writer_tuple = ensure_tuple(avail_writers)
SUPPORTED_WRITERS[fmt] = writer_tuple
return writer_tuple
Expand Down Expand Up @@ -862,6 +880,15 @@ def create_backend_obj(
return PILImage.fromarray(data, mode=kwargs.pop("image_mode", None))


WRITER_DEPENDENCY_HINTS.update(
{
ITKWriter: ("ITK", "pip install itk"),
NibabelWriter: ("Nibabel", "pip install nibabel"),
PILWriter: ("Pillow", "pip install pillow"),
}
)


def init():
"""
Initialize the image writer modules according to the filename extension.
Expand Down
28 changes: 27 additions & 1 deletion tests/data/test_image_rw.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@
import torch
from parameterized import parameterized

import monai.data.image_writer as image_writer
from monai.data.image_reader import ITKReader, NibabelReader, NrrdReader, PILReader
from monai.data.image_writer import ITKWriter, NibabelWriter, PILWriter, register_writer, resolve_writer
from monai.data.meta_tensor import MetaTensor
from monai.transforms import LoadImage, SaveImage, moveaxis
from monai.utils import MetaKeys, OptionalImportError, optional_import
from monai.utils import MetaKeys, OptionalImportError, optional_import, require_pkg
from tests.test_utils import TEST_NDARRAYS, assert_allclose

_, has_itk = optional_import("itk", allow_namespace_pkg=True)
Expand Down Expand Up @@ -150,6 +151,31 @@ def test_1_new(self):
register_writer("new2", lambda x: x + 1)
self.assertEqual(resolve_writer("new")[0](0), 1)

def test_missing_dependency_hint(self):
ext = ".needshint"
fmt_key = ext.lstrip(".").lower()
previous = image_writer.SUPPORTED_WRITERS.get(fmt_key)

@require_pkg(pkg_name="__monai_missing_test_pkg__")
class MissingHintWriter(image_writer.ImageWriter):
pass

image_writer.WRITER_DEPENDENCY_HINTS[MissingHintWriter] = ("FakePkg", "pip install fakepkg")

try:
register_writer(ext, MissingHintWriter)
with self.assertRaises(OptionalImportError) as ctx:
resolve_writer(ext)
err_msg = str(ctx.exception)
self.assertIn("FakePkg", err_msg)
self.assertIn("pip install fakepkg", err_msg)
finally:
image_writer.WRITER_DEPENDENCY_HINTS.pop(MissingHintWriter, None)
if previous is None:
image_writer.SUPPORTED_WRITERS.pop(fmt_key, None)
else:
image_writer.SUPPORTED_WRITERS[fmt_key] = previous


@unittest.skipUnless(has_itk, "itk not installed")
class TestLoadSaveNrrd(unittest.TestCase):
Expand Down
Loading