Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ def __init__(
pickle_protocol: int = DEFAULT_PROTOCOL,
hash_transform: Callable[..., bytes] | None = None,
reset_ops_id: bool = True,
track_meta: bool = False,
weights_only: bool = True,
) -> None:
"""
Args:
Expand Down Expand Up @@ -264,7 +266,17 @@ def __init__(
When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.
This is useful for skipping the transform instance checks when inverting applied operations
using the cached content and with re-created transform instances.

track_meta: whether to track the meta information, if `True`, will convert to `MetaTensor`.
default to `False`. Cannot be used with `weights_only=True`.
weights_only: keyword argument passed to `torch.load` when reading cached files.
default to `True`. When set to `True`, `torch.load` restricts loading to tensors and
other safe objects. Setting this to `False` is required for loading `MetaTensor`
objects saved with `track_meta=True`, however this creates the possibility of remote
code execution through `torch.load` so be aware of the security implications of doing so.

Raises:
ValueError: When both `track_meta=True` and `weights_only=True`, since this combination
prevents cached MetaTensors from being reloaded and causes perpetual cache regeneration.
"""
super().__init__(data=data, transform=transform)
self.cache_dir = Path(cache_dir) if cache_dir is not None else None
Expand All @@ -280,6 +292,13 @@ def __init__(
if hash_transform is not None:
self.set_transform_hash(hash_transform)
self.reset_ops_id = reset_ops_id
if track_meta and weights_only:
raise ValueError(
"Invalid argument combination: `track_meta=True` cannot be used with `weights_only=True`. "
"To cache and reload MetaTensors, set `track_meta=True` and `weights_only=False`."
)
self.track_meta = track_meta
self.weights_only = weights_only

def set_transform_hash(self, hash_xform_func: Callable[..., bytes]):
"""Get hashable transforms, and then hash them. Hashable transforms
Expand Down Expand Up @@ -377,7 +396,7 @@ def _cachecheck(self, item_transformed):

if hashfile is not None and hashfile.is_file(): # cache hit
try:
return torch.load(hashfile, weights_only=True)
return torch.load(hashfile, weights_only=self.weights_only)
except PermissionError as e:
if sys.platform != "win32":
raise e
Expand All @@ -398,7 +417,7 @@ def _cachecheck(self, item_transformed):
with tempfile.TemporaryDirectory() as tmpdirname:
temp_hash_file = Path(tmpdirname) / hashfile.name
torch.save(
obj=convert_to_tensor(_item_transformed, convert_numeric=False),
obj=convert_to_tensor(_item_transformed, convert_numeric=False, track_meta=self.track_meta),
f=temp_hash_file,
pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
pickle_protocol=self.pickle_protocol,
Expand Down
38 changes: 35 additions & 3 deletions tests/data/test_persistentdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from __future__ import annotations

import contextlib
import os
import tempfile
import unittest
Expand All @@ -20,7 +21,7 @@
import torch
from parameterized import parameterized

from monai.data import PersistentDataset, json_hashing
from monai.data import MetaTensor, PersistentDataset, json_hashing
from monai.transforms import Compose, Flip, Identity, LoadImaged, SimulateDelayd, Transform

TEST_CASE_1 = [
Expand All @@ -43,9 +44,16 @@

TEST_CASE_3 = [None, (128, 128, 128)]

# Parameter tuples for test_track_meta_and_weights_only:
#   [track_meta, weights_only, expected_error, expected_type]
# track_meta=True, weights_only=False: valid; cached items reload as MetaTensor.
TEST_CASE_4 = [True, False, False, MetaTensor]

# track_meta=True, weights_only=True: invalid combination; PersistentDataset
# construction must raise ValueError, so no item type applies (None).
TEST_CASE_5 = [True, True, True, None]

# track_meta=False, weights_only=False: valid; items reload as plain torch.Tensor.
TEST_CASE_6 = [False, False, False, torch.Tensor]

# track_meta=False, weights_only=True (the defaults): valid; plain torch.Tensor.
TEST_CASE_7 = [False, True, False, torch.Tensor]

class _InplaceXform(Transform):

class _InplaceXform(Transform):
def __call__(self, data):
if data:
data[0] = data[0] + np.pi
Expand All @@ -55,7 +63,6 @@ def __call__(self, data):


class TestDataset(unittest.TestCase):

def test_cache(self):
"""testing no inplace change to the hashed item"""
items = [[list(range(i))] for i in range(5)]
Expand Down Expand Up @@ -168,6 +175,31 @@ def test_different_transforms(self):
l2 = ((im1 - im2) ** 2).sum() ** 0.5
self.assertGreater(l2, 1)

@parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])
def test_track_meta_and_weights_only(self, track_meta, weights_only, expected_error, expected_type):
    """
    Ensure expected behavior for all combinations of `track_meta` and `weights_only`.

    Args:
        track_meta: forwarded to ``PersistentDataset``; when True, cached items
            are expected to reload as ``MetaTensor``.
        weights_only: forwarded to ``PersistentDataset``; when True, ``torch.load``
            is restricted to safe objects when reading the cache.
        expected_error: when True, constructing the dataset must raise ``ValueError``
            (the invalid ``track_meta=True`` + ``weights_only=True`` combination).
        expected_type: expected type of the loaded image when no error is expected.
    """
    # Create a small random NIfTI volume on disk so LoadImaged has a real file to read.
    test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))
    with tempfile.TemporaryDirectory() as tempdir:
        nib.save(test_image, os.path.join(tempdir, "test_image.nii.gz"))
        test_data = [{"image": os.path.join(tempdir, "test_image.nii.gz")}]
        transform = Compose([LoadImaged(keys=["image"])])
        cache_dir = os.path.join(os.path.join(tempdir, "cache"), "data")

        # Expect ValueError for the invalid argument combination; otherwise run
        # the body under a no-op context manager so both paths share one code path.
        cm = self.assertRaises(ValueError) if expected_error else contextlib.nullcontext()
        with cm:
            test_dataset = PersistentDataset(
                data=test_data,
                transform=transform,
                cache_dir=cache_dir,
                track_meta=track_meta,
                weights_only=weights_only,
            )

            # Skipped for the error case: assertRaises swallows the ValueError
            # raised by the constructor above, so these lines never execute then.
            im = test_dataset[0]["image"]
            self.assertIsInstance(im, expected_type)


# Standard unittest entry point when this test module is run as a script.
if __name__ == "__main__":
    unittest.main()
Loading