diff --git a/README.md b/README.md
index d295fbf7..2823d34b 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,11 @@
+
+This is a fork of the GitHub repository https://github.com/CompVis/taming-transformers. Due to recent changes in PyTorch, the original CompVis/taming-transformers repository requires small updates to remain compatible. I use this repository in Chapter 11 of my book *Build a Text-to-Image Generator from Scratch* (Manning Publications). Rather than asking every reader to edit the source code by hand, I have created this fork with the compatibility fixes already applied.
+
+In the file taming/data/utils.py, I have changed `string_classes` to `str` on line 152 and commented out line 11 (`from torch._six import string_classes`), because the `torch._six` module has been removed from recent versions of PyTorch. A short sanity check for this change is sketched after the diff below.
+
+
+
+
 # Taming Transformers for High-Resolution Image Synthesis
 ##### CVPR 2021 (Oral)
 ![teaser](assets/mountain.jpeg)
diff --git a/taming/data/utils.py b/taming/data/utils.py
index 2b3c3d53..6aa55c5b 100644
--- a/taming/data/utils.py
+++ b/taming/data/utils.py
@@ -1,169 +1,169 @@
-import collections
-import os
-import tarfile
-import urllib
-import zipfile
-from pathlib import Path
-
-import numpy as np
-import torch
-from taming.data.helper_types import Annotation
-from torch._six import string_classes
-from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
-from tqdm import tqdm
-
-
-def unpack(path):
-    if path.endswith("tar.gz"):
-        with tarfile.open(path, "r:gz") as tar:
-            tar.extractall(path=os.path.split(path)[0])
-    elif path.endswith("tar"):
-        with tarfile.open(path, "r:") as tar:
-            tar.extractall(path=os.path.split(path)[0])
-    elif path.endswith("zip"):
-        with zipfile.ZipFile(path, "r") as f:
-            f.extractall(path=os.path.split(path)[0])
-    else:
-        raise NotImplementedError(
-            "Unknown file extension: {}".format(os.path.splitext(path)[1])
-        )
-
-
-def reporthook(bar):
-    """tqdm progress bar for downloads."""
-
-    def hook(b=1, bsize=1, tsize=None):
-        if tsize is not None:
-            bar.total = tsize
-        bar.update(b * bsize - bar.n)
-
-    return hook
-
-
-def get_root(name):
-    base = "data/"
-    root = os.path.join(base, name)
-    os.makedirs(root, exist_ok=True)
-    return root
-
-
-def is_prepared(root):
-    return Path(root).joinpath(".ready").exists()
-
-
-def mark_prepared(root):
-    Path(root).joinpath(".ready").touch()
-
-
-def prompt_download(file_, source, target_dir, content_dir=None):
-    targetpath = os.path.join(target_dir, file_)
-    while not os.path.exists(targetpath):
-        if content_dir is not None and os.path.exists(
-            os.path.join(target_dir, content_dir)
-        ):
-            break
-        print(
-            "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath)
-        )
-        if content_dir is not None:
-            print(
-                "Or place its content into '{}'.".format(
-                    os.path.join(target_dir, content_dir)
-                )
-            )
-        input("Press Enter when done...")
-    return targetpath
-
-
-def download_url(file_, url, target_dir):
-    targetpath = os.path.join(target_dir, file_)
-    os.makedirs(target_dir, exist_ok=True)
-    with tqdm(
-        unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_
-    ) as bar:
-        urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar))
-    return targetpath
-
-
-def download_urls(urls, target_dir):
-    paths = dict()
-    for fname, url in urls.items():
-        outpath = download_url(fname, url, target_dir)
-        paths[fname] = outpath
-    return paths
-
-
-def quadratic_crop(x, bbox, alpha=1.0):
-    """bbox is xmin, ymin, xmax, ymax"""
-    im_h, im_w = x.shape[:2]
-    bbox = np.array(bbox, dtype=np.float32)
-    bbox = np.clip(bbox, 0, max(im_h, im_w))
-    center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3])
-    w = bbox[2] - bbox[0]
-    h = bbox[3] - bbox[1]
-    l = int(alpha * max(w, h))
-    l = max(l, 2)
-
-    required_padding = -1 * min(
-        center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l)
-    )
-    required_padding = int(np.ceil(required_padding))
-    if required_padding > 0:
-        padding = [
-            [required_padding, required_padding],
-            [required_padding, required_padding],
-        ]
-        padding += [[0, 0]] * (len(x.shape) - 2)
-        x = np.pad(x, padding, "reflect")
-        center = center[0] + required_padding, center[1] + required_padding
-    xmin = int(center[0] - l / 2)
-    ymin = int(center[1] - l / 2)
-    return np.array(x[ymin : ymin + l, xmin : xmin + l, ...])
-
-
-def custom_collate(batch):
-    r"""source: pytorch 1.9.0, only one modification to original code """
-
-    elem = batch[0]
-    elem_type = type(elem)
-    if isinstance(elem, torch.Tensor):
-        out = None
-        if torch.utils.data.get_worker_info() is not None:
-            # If we're in a background process, concatenate directly into a
-            # shared memory tensor to avoid an extra copy
-            numel = sum([x.numel() for x in batch])
-            storage = elem.storage()._new_shared(numel)
-            out = elem.new(storage)
-        return torch.stack(batch, 0, out=out)
-    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
-            and elem_type.__name__ != 'string_':
-        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
-            # array of string classes and object
-            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
-                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
-
-            return custom_collate([torch.as_tensor(b) for b in batch])
-        elif elem.shape == ():  # scalars
-            return torch.as_tensor(batch)
-    elif isinstance(elem, float):
-        return torch.tensor(batch, dtype=torch.float64)
-    elif isinstance(elem, int):
-        return torch.tensor(batch)
-    elif isinstance(elem, string_classes):
-        return batch
-    elif isinstance(elem, collections.abc.Mapping):
-        return {key: custom_collate([d[key] for d in batch]) for key in elem}
-    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
-        return elem_type(*(custom_collate(samples) for samples in zip(*batch)))
-    if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation):  # added
-        return batch  # added
-    elif isinstance(elem, collections.abc.Sequence):
-        # check to make sure that the elements in batch have consistent size
-        it = iter(batch)
-        elem_size = len(next(it))
-        if not all(len(elem) == elem_size for elem in it):
-            raise RuntimeError('each element in list of batch should be of equal size')
-        transposed = zip(*batch)
-        return [custom_collate(samples) for samples in transposed]
-
-    raise TypeError(default_collate_err_msg_format.format(elem_type))
+import collections
+import os
+import tarfile
+import urllib
+import zipfile
+from pathlib import Path
+
+import numpy as np
+import torch
+from taming.data.helper_types import Annotation
+# from torch._six import string_classes  # removed: torch._six is not available in recent PyTorch
+from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
+from tqdm import tqdm
+
+
+def unpack(path):
+    if path.endswith("tar.gz"):
+        with tarfile.open(path, "r:gz") as tar:
+            tar.extractall(path=os.path.split(path)[0])
+    elif path.endswith("tar"):
+        with tarfile.open(path, "r:") as tar:
+            tar.extractall(path=os.path.split(path)[0])
+    elif path.endswith("zip"):
+        with zipfile.ZipFile(path, "r") as f:
+            f.extractall(path=os.path.split(path)[0])
+    else:
+        raise NotImplementedError(
+            "Unknown file extension: {}".format(os.path.splitext(path)[1])
+        )
+
+
+def reporthook(bar):
+    """tqdm progress bar for downloads."""
+
+    def hook(b=1, bsize=1, tsize=None):
+        if tsize is not None:
+            bar.total = tsize
+        bar.update(b * bsize - bar.n)
+
+    return hook
+
+
+def get_root(name):
+    base = "data/"
+    root = os.path.join(base, name)
+    os.makedirs(root, exist_ok=True)
+    return root
+
+
+def is_prepared(root):
+    return Path(root).joinpath(".ready").exists()
+
+
+def mark_prepared(root):
+    Path(root).joinpath(".ready").touch()
+
+
+def prompt_download(file_, source, target_dir, content_dir=None):
+    targetpath = os.path.join(target_dir, file_)
+    while not os.path.exists(targetpath):
+        if content_dir is not None and os.path.exists(
+            os.path.join(target_dir, content_dir)
+        ):
+            break
+        print(
+            "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath)
+        )
+        if content_dir is not None:
+            print(
+                "Or place its content into '{}'.".format(
+                    os.path.join(target_dir, content_dir)
+                )
+            )
+        input("Press Enter when done...")
+    return targetpath
+
+
+def download_url(file_, url, target_dir):
+    targetpath = os.path.join(target_dir, file_)
+    os.makedirs(target_dir, exist_ok=True)
+    with tqdm(
+        unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_
+    ) as bar:
+        urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar))
+    return targetpath
+
+
+def download_urls(urls, target_dir):
+    paths = dict()
+    for fname, url in urls.items():
+        outpath = download_url(fname, url, target_dir)
+        paths[fname] = outpath
+    return paths
+
+
+def quadratic_crop(x, bbox, alpha=1.0):
+    """bbox is xmin, ymin, xmax, ymax"""
+    im_h, im_w = x.shape[:2]
+    bbox = np.array(bbox, dtype=np.float32)
+    bbox = np.clip(bbox, 0, max(im_h, im_w))
+    center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3])
+    w = bbox[2] - bbox[0]
+    h = bbox[3] - bbox[1]
+    l = int(alpha * max(w, h))
+    l = max(l, 2)
+
+    required_padding = -1 * min(
+        center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l)
+    )
+    required_padding = int(np.ceil(required_padding))
+    if required_padding > 0:
+        padding = [
+            [required_padding, required_padding],
+            [required_padding, required_padding],
+        ]
+        padding += [[0, 0]] * (len(x.shape) - 2)
+        x = np.pad(x, padding, "reflect")
+        center = center[0] + required_padding, center[1] + required_padding
+    xmin = int(center[0] - l / 2)
+    ymin = int(center[1] - l / 2)
+    return np.array(x[ymin : ymin + l, xmin : xmin + l, ...])
+
+
+def custom_collate(batch):
+    r"""source: pytorch 1.9.0, only one modification to original code """
+
+    elem = batch[0]
+    elem_type = type(elem)
+    if isinstance(elem, torch.Tensor):
+        out = None
+        if torch.utils.data.get_worker_info() is not None:
+            # If we're in a background process, concatenate directly into a
+            # shared memory tensor to avoid an extra copy
+            numel = sum([x.numel() for x in batch])
+            storage = elem.storage()._new_shared(numel)
+            out = elem.new(storage)
+        return torch.stack(batch, 0, out=out)
+    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
+            and elem_type.__name__ != 'string_':
+        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
+            # array of string classes and object
+            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
+                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
+
+            return custom_collate([torch.as_tensor(b) for b in batch])
+        elif elem.shape == ():  # scalars
+            return torch.as_tensor(batch)
+    elif isinstance(elem, float):
+        return torch.tensor(batch, dtype=torch.float64)
+    elif isinstance(elem, int):
+        return torch.tensor(batch)
+    elif isinstance(elem, str):  # was: isinstance(elem, string_classes)
+        return batch
+    elif isinstance(elem, collections.abc.Mapping):
+        return {key: custom_collate([d[key] for d in batch]) for key in elem}
+    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
+        return elem_type(*(custom_collate(samples) for samples in zip(*batch)))
+    if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation):  # added
+        return batch  # added
+    elif isinstance(elem, collections.abc.Sequence):
+        # check to make sure that the elements in batch have consistent size
+        it = iter(batch)
+        elem_size = len(next(it))
+        if not all(len(elem) == elem_size for elem in it):
+            raise RuntimeError('each element in list of batch should be of equal size')
+        transposed = zip(*batch)
+        return [custom_collate(samples) for samples in transposed]
+
+    raise TypeError(default_collate_err_msg_format.format(elem_type))
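The only functional change in this diff is inside `custom_collate`: string elements are now detected with the built-in `str` instead of the removed `torch._six.string_classes`. Below is a minimal sanity-check sketch for that branch, assuming the fork is installed (or on `PYTHONPATH`) alongside a recent PyTorch; the sample batch and assertions are illustrative and not part of the repository.

```python
# Minimal sketch: confirm custom_collate still handles string batches after
# the string_classes -> str replacement. Sample data below is illustrative.
import torch
from taming.data.utils import custom_collate

batch = [
    {"caption": "a mountain at sunset", "label": 1},
    {"caption": "a castle on a hill", "label": 2},
]

out = custom_collate(batch)

# Strings are returned as a plain Python list; integers are collated into a tensor.
assert out["caption"] == ["a mountain at sunset", "a castle on a hill"]
assert torch.equal(out["label"], torch.tensor([1, 2]))
print(out)
```

If this script runs without the `No module named 'torch._six'` import error raised by the unmodified file on recent PyTorch, the compatibility fix is in effect.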