Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 61 additions & 60 deletions xfuse/data/utility/misc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools as it
from typing import Any, Dict
import platform

import numpy as np
import torch
Expand Down Expand Up @@ -91,68 +92,68 @@ def _compute_size(x):
)
}

def _worker_init(n):
np.random.seed(np.random.get_state()[1][0] + get("training_data").step)
np.random.seed(np.random.randint(np.iinfo(np.int32).max) + n)

def _collate(xs):
def _remove_key(v):
v.pop("data_type")
return v

def _sort_key(x):
return x["data_type"]

def _collate(ys):
collated_data = {}

# we can't collate the count data as a tensor since its dimension
# will differ between samples. therefore, we return it as a list
# instead.
try:
collated_data.update({"data": [y.pop("data") for y in ys]})
except KeyError:
pass

# Collate any other non-tensor as list
collated_data.update(
{
k: [y.pop(k) for y in ys]
for k in set(
k
for y in ys
for k, v in y.items()
if not torch.is_tensor(v)
)
}
)

def make_dataloader(dataset: Dataset, **kwargs: Any) -> DataLoader:
r"""Creates a :class:`~torch.utils.data.DataLoader` for `dataset`"""

def _collate(xs):
def _remove_key(v):
v.pop("data_type")
return v

def _sort_key(x):
return x["data_type"]

def _collate(ys):
collated_data = {}

# we can't collate the count data as a tensor since its dimension
# will differ between samples. therefore, we return it as a list
# instead.
try:
collated_data.update({"data": [y.pop("data") for y in ys]})
except KeyError:
pass

# Collate any other non-tensor as list
collated_data.update(
{
k: [y.pop(k) for y in ys]
for k in set(
k
for y in ys
for k, v in y.items()
if not torch.is_tensor(v)
# Crop image sizes to the minimum size over the batch
min_size = {}
for y in ys:
for k, v in y.items():
if k in min_size:
min_size[k] = torch.min(
min_size[k], torch.as_tensor(v.shape)
)
}
)
else:
min_size[k] = torch.as_tensor(v.shape)
for y in ys:
for k, v in min_size.items():
y[k] = center_crop(y[k], v.numpy().tolist())
collated_data.update(default_collate(ys))

# Crop image sizes to the minimum size over the batch
min_size = {}
for y in ys:
for k, v in y.items():
if k in min_size:
min_size[k] = torch.min(
min_size[k], torch.as_tensor(v.shape)
)
else:
min_size[k] = torch.as_tensor(v.shape)
for y in ys:
for k, v in min_size.items():
y[k] = center_crop(y[k], v.numpy().tolist())
collated_data.update(default_collate(ys))

return collated_data

return {
k: _collate([_remove_key(v) for v in vs])
for k, vs in it.groupby(sorted(xs, key=_sort_key), key=_sort_key)
}

def _worker_init(n):
np.random.seed(np.random.get_state()[1][0] + get("training_data").step)
np.random.seed(np.random.randint(np.iinfo(np.int32).max) + n)
return collated_data

return {
k: _collate([_remove_key(v) for v in vs])
for k, vs in it.groupby(sorted(xs, key=_sort_key), key=_sort_key)
}

def make_dataloader(dataset: Dataset, **kwargs: Any) -> DataLoader:
r"""Creates a :class:`~torch.utils.data.DataLoader` for `dataset`"""
return DataLoader(
dataset, collate_fn=_collate, worker_init_fn=_worker_init, **kwargs
)
dataset = dataset,
collate_fn = _collate,
worker_init_fn = _worker_init,
**kwargs )
6 changes: 3 additions & 3 deletions xfuse/messengers/stats/stats_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from io import BytesIO
from typing import Callable, List, Optional

import matplotlib
import matplotlib.pyplot as plt
import torch
from imageio import imread
from pyro.poutine.messenger import Messenger
Expand Down Expand Up @@ -58,9 +58,9 @@ def _postprocess_message(self, msg):
self._handle(**msg)


def log_figure(tag: str, figure: matplotlib.figure.Figure, **kwargs,) -> None:
def log_figure(tag: str, figure: plt.Figure, **kwargs,) -> None:
r"""
Converts :class:`~matplotlib.figure.Figure`` to image data and logs it
Converts :class:`~plt.Figure`` to image data and logs it
using :func:`log_image`
"""
if "format" not in kwargs:
Expand Down
3 changes: 2 additions & 1 deletion xfuse/run.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import re
import warnings
import multiprocessing
from functools import partial, reduce
from operator import add
from typing import Any, Dict, Optional, Tuple
Expand Down Expand Up @@ -61,7 +62,7 @@ def run(
if slide_options is None:
slide_options = {}

if (available_cores := len(os.sched_getaffinity(0))) < num_data_workers:
if (available_cores := multiprocessing.cpu_count()) < num_data_workers:
warnings.warn(
" ".join(
[
Expand Down