diff --git a/xfuse/data/utility/misc.py b/xfuse/data/utility/misc.py index 342984f4..dd5c2bc7 100644 --- a/xfuse/data/utility/misc.py +++ b/xfuse/data/utility/misc.py @@ -1,5 +1,6 @@ import itertools as it from typing import Any, Dict +import platform import numpy as np import torch @@ -91,68 +92,68 @@ def _compute_size(x): ) } +def _worker_init(n): + np.random.seed(np.random.get_state()[1][0] + get("training_data").step) + np.random.seed(np.random.randint(np.iinfo(np.int32).max) + n) + +def _collate(xs): + def _remove_key(v): + v.pop("data_type") + return v + + def _sort_key(x): + return x["data_type"] + + def _collate(ys): + collated_data = {} + + # we can't collate the count data as a tensor since its dimension + # will differ between samples. therefore, we return it as a list + # instead. + try: + collated_data.update({"data": [y.pop("data") for y in ys]}) + except KeyError: + pass + + # Collate any other non-tensor as list + collated_data.update( + { + k: [y.pop(k) for y in ys] + for k in set( + k + for y in ys + for k, v in y.items() + if not torch.is_tensor(v) + ) + } + ) -def make_dataloader(dataset: Dataset, **kwargs: Any) -> DataLoader: - r"""Creates a :class:`~torch.utils.data.DataLoader` for `dataset`""" - - def _collate(xs): - def _remove_key(v): - v.pop("data_type") - return v - - def _sort_key(x): - return x["data_type"] - - def _collate(ys): - collated_data = {} - - # we can't collate the count data as a tensor since its dimension - # will differ between samples. therefore, we return it as a list - # instead. - try: - collated_data.update({"data": [y.pop("data") for y in ys]}) - except KeyError: - pass - - # Collate any other non-tensor as list - collated_data.update( - { - k: [y.pop(k) for y in ys] - for k in set( - k - for y in ys - for k, v in y.items() - if not torch.is_tensor(v) + # Crop image sizes to the minimum size over the batch + min_size = {} + for y in ys: + for k, v in y.items(): + if k in min_size: + min_size[k] = torch.min( + min_size[k], torch.as_tensor(v.shape) ) - } - ) + else: + min_size[k] = torch.as_tensor(v.shape) + for y in ys: + for k, v in min_size.items(): + y[k] = center_crop(y[k], v.numpy().tolist()) + collated_data.update(default_collate(ys)) - # Crop image sizes to the minimum size over the batch - min_size = {} - for y in ys: - for k, v in y.items(): - if k in min_size: - min_size[k] = torch.min( - min_size[k], torch.as_tensor(v.shape) - ) - else: - min_size[k] = torch.as_tensor(v.shape) - for y in ys: - for k, v in min_size.items(): - y[k] = center_crop(y[k], v.numpy().tolist()) - collated_data.update(default_collate(ys)) - - return collated_data - - return { - k: _collate([_remove_key(v) for v in vs]) - for k, vs in it.groupby(sorted(xs, key=_sort_key), key=_sort_key) - } - - def _worker_init(n): - np.random.seed(np.random.get_state()[1][0] + get("training_data").step) - np.random.seed(np.random.randint(np.iinfo(np.int32).max) + n) + return collated_data + return { + k: _collate([_remove_key(v) for v in vs]) + for k, vs in it.groupby(sorted(xs, key=_sort_key), key=_sort_key) + } + +def make_dataloader(dataset: Dataset, **kwargs: Any) -> DataLoader: + r"""Creates a :class:`~torch.utils.data.DataLoader` for `dataset`""" return DataLoader( - dataset, collate_fn=_collate, worker_init_fn=_worker_init, **kwargs - ) + dataset = dataset, + collate_fn = _collate, + worker_init_fn = _worker_init, + **kwargs ) diff --git a/xfuse/messengers/stats/stats_handler.py b/xfuse/messengers/stats/stats_handler.py index 9be9b47e..65a04b08 100644 --- a/xfuse/messengers/stats/stats_handler.py +++ b/xfuse/messengers/stats/stats_handler.py @@ -2,7 +2,7 @@ from io import BytesIO from typing import Callable, List, Optional -import matplotlib +import matplotlib.pyplot as plt import torch from imageio import imread from pyro.poutine.messenger import Messenger @@ -58,9 +58,9 @@ def _postprocess_message(self, msg): self._handle(**msg) -def log_figure(tag: str, figure: matplotlib.figure.Figure, **kwargs,) -> None: +def log_figure(tag: str, figure: plt.Figure, **kwargs,) -> None: r""" - Converts :class:`~matplotlib.figure.Figure`` to image data and logs it + Converts :class:`~plt.Figure`` to image data and logs it using :func:`log_image` """ if "format" not in kwargs: diff --git a/xfuse/run.py b/xfuse/run.py index 01a0f134..3277a30d 100644 --- a/xfuse/run.py +++ b/xfuse/run.py @@ -1,6 +1,7 @@ import os import re import warnings +import multiprocessing from functools import partial, reduce from operator import add from typing import Any, Dict, Optional, Tuple @@ -61,7 +62,7 @@ def run( if slide_options is None: slide_options = {} - if (available_cores := len(os.sched_getaffinity(0))) < num_data_workers: + if (available_cores := multiprocessing.cpu_count()) < num_data_workers: warnings.warn( " ".join( [