Improve DlupDataModule #90
Description
I have come across a few things that could improve the code quality and readability of the DlupDataModule code.
`self._already_called` is never set. While constructing the dataloaders, we already check whether `self._{stage}_data_iterator` exists (i.e. whether it is not `None`). That check is sufficient, so I don't see the added value of keeping `self._already_called` (see the sketch below).

Suggestion: remove `self._already_called`
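For illustration, a minimal, self-contained sketch of the guard pattern that makes the flag redundant (the `fit_dataloader` method and the stand-in `setup` body are assumptions, modeled on the `test_dataloader` quoted below):

```python
from typing import Iterator, Optional

class DataModuleSketch:
    """Illustrative only: the None check alone prevents repeated setup."""

    def __init__(self) -> None:
        self._fit_data_iterator: Optional[Iterator] = None

    def setup(self, stage: str) -> None:
        # Stand-in for the real setup that populates the iterator.
        self._fit_data_iterator = iter([])

    def fit_dataloader(self) -> Optional[Iterator]:
        # No _already_called flag needed: the iterator being None is
        # the single source of truth for "setup has not run yet".
        if self._fit_data_iterator is None:
            self.setup("fit")
        return self._fit_data_iterator
```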
- (bug) `test_dataloader` currently reads:
```python
def test_dataloader(self) -> Optional[DataLoader[DlupDatasetSample]]:
    if not self._test_data_iterator:
        self.setup("test")
    batch_size = self._validate_batch_size if self._validate_batch_size else self._batch_size
    assert self._validate_data_iterator
    return self._construct_concatenated_dataloader(
        self._validate_data_iterator, batch_size=batch_size, stage="test"
    )
```
It should use `self._test_data_iterator`, not `self._validate_data_iterator`.

Suggestion: replace `self._validate_data_iterator` with `self._test_data_iterator`.
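For reference, a sketch of the corrected method; only the iterator is swapped, everything else is kept exactly as quoted above (whether the batch size should also come from a test-specific attribute is left open):

```python
def test_dataloader(self) -> Optional[DataLoader[DlupDatasetSample]]:
    if not self._test_data_iterator:
        self.setup("test")
    batch_size = self._validate_batch_size if self._validate_batch_size else self._batch_size
    # Use the test iterator rather than the validation iterator.
    assert self._test_data_iterator
    return self._construct_concatenated_dataloader(
        self._test_data_iterator, batch_size=batch_size, stage="test"
    )
```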
The data iterator in the `_construct_concatenated_dataloader` method is annotated as `data_iterator: Iterator[_DlupDataset]`. However, inside the method the `data_iterator` is allowed to be `None`:

```python
if not data_iterator:
    return None
```

Suggestions (both sketched below): a) change the annotation to `data_iterator: Iterator[_DlupDataset] | None`, or b) if `data_iterator` is falsy, raise a `ValueError`.
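A sketch of both options; the `batch_size` and `stage` parameters are taken from the call site quoted above, the rest of the signature is an assumption:

```python
# Option a): make the annotation match the actual contract.
def _construct_concatenated_dataloader(
    self, data_iterator: Iterator[_DlupDataset] | None, batch_size: int, stage: str
) -> Optional[DataLoader[DlupDatasetSample]]:
    if not data_iterator:
        return None
    ...

# Option b): keep the strict annotation and fail loudly instead.
def _construct_concatenated_dataloader(
    self, data_iterator: Iterator[_DlupDataset], batch_size: int, stage: str
) -> DataLoader[DlupDatasetSample]:
    if not data_iterator:
        raise ValueError(f"No data iterator provided for stage {stage}.")
    ...
```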
This is logged in `_construct_concatenated_dataloader`:

```python
lengths = np.asarray([len(ds) for ds in dataset.datasets])
self._logger.info(
    f"Dataset for stage {stage} has {len(dataset)} samples and the following statistics:\n"
    f" - Mean: {lengths.mean():.2f}\n"
    f" - Std: {lengths.std():.2f}\n"
    f" - Min: {lengths.min():.2f}\n"
    f" - Max: {lengths.max():.2f}"
)
```
Suggestion: add a `log_stats` method to the `ConcatDataset` class, e.g.

```python
def log_stats(self) -> None:
    lengths = np.asarray([len(ds) for ds in self.datasets])
    logger.info(
        f"Total number of samples: {len(self)}\n"
        f" - Mean: {lengths.mean():.2f}\n"
        f" - Std: {lengths.std():.2f}\n"
        f" - Min: {lengths.min():.2f}\n"
        f" - Max: {lengths.max():.2f}"
    )
```

and in `_construct_concatenated_dataloader`:

```python
self._logger.info(f"Dataset for stage {stage} has the following statistics:")
dataset.log_stats()
```
The variable `stage` is currently a string throughout the code.

Suggestion: use an `Enum` to keep track of the stage. Especially within `datasets_from_data_description` (`manifest.py`), it would be nice to use the `CategoryEnum` from `database_models.py`. A minimal sketch follows.
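As an illustration (the member names are assumptions; in practice `CategoryEnum` could simply be reused):

```python
from enum import Enum

class Stage(str, Enum):
    """Hypothetical stage enum; ideally CategoryEnum from
    database_models.py would serve this role."""
    FIT = "fit"
    VALIDATE = "validate"
    TEST = "test"
    PREDICT = "predict"

# Subclassing str keeps existing call sites like setup("fit") working:
assert Stage.FIT == "fit"
```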
`self._limit_{stage}_samples` is protected, but as far as I can see it is never set within `DlupDataModule`.

Suggestions (option a) is sketched below):
a) add it as an argument to the `__init__` function, e.g. `limit_samples: dict[str, int]`, and then e.g. `self._limit_validate_samples = limit_samples.get("limit_validate_samples", None)`
b) use kwargs in the `__init__` function, e.g. `self._limit_validate_samples = kwargs.get("limit_validate_samples", None)`
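A minimal sketch of option a); the dict keys and the set of stages are assumptions:

```python
from typing import Optional

class DlupDataModuleSketch:
    """Illustrative only: wires the limits through the constructor."""

    def __init__(self, limit_samples: Optional[dict[str, int]] = None) -> None:
        limit_samples = limit_samples or {}
        self._limit_fit_samples = limit_samples.get("limit_fit_samples", None)
        self._limit_validate_samples = limit_samples.get("limit_validate_samples", None)
        self._limit_test_samples = limit_samples.get("limit_test_samples", None)
```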
`self._limit_{stage}_samples` limits the number of datasets that are concatenated (i.e. the number of WSIs), not the actual number of samples (which is the number of tiles):

```python
data_iterator: Iterator[_DlupDataset]

def construct_dataset() -> ConcatDataset:
    datasets = []
    for idx, ds in enumerate(data_iterator):
        datasets.append(ds)
        if limit_samples and idx >= limit_samples:
            break
    return ConcatDataset(datasets=datasets)
```

Suggestion: rename `self._limit_{stage}_samples` to `self._limit_{stage}_slides`
`self._limit_{stage}_samples` can lead to problems when loading datasets from the cache: we check whether a cached dataset exists based on its UUID, but this UUID is generated from `self.data_description` alone. So if you run the datamodule first with e.g. `self._limit_fit_samples=10` and then run it a second time with `self._limit_fit_samples=None`, I expect it will load the incomplete dataset (with 10 slides) from the cache instead of constructing a new dataset.

Suggestions (option a) is sketched below): a) instead of using only the UUID as the pkl filename, add e.g. a `_{limit}` suffix when the limit is not `None`, or b) add the limit somewhere in the path in `_load_from_cache`.
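A hypothetical sketch of option a); `cache_dir` and `dataset_uuid` are placeholder names, not the actual variables used in `_load_from_cache`:

```python
from pathlib import Path
from typing import Optional

def cache_filename(cache_dir: Path, dataset_uuid: str, limit: Optional[int]) -> Path:
    # Encode the limit in the filename so a limited run and a full run
    # can never resolve to the same cached pickle.
    suffix = f"_limit{limit}" if limit is not None else ""
    return cache_dir / f"{dataset_uuid}{suffix}.pkl"
```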