diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..57e7007 --- /dev/null +++ b/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2023 Corti.ai + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 56126c9..9c3e5c0 100644 --- a/README.md +++ b/README.md @@ -153,8 +153,12 @@ PyTorch models are typically trained and evaluated on batches of data. However, - Combining batch or length dimensions with one or more other dimensions into a single dimension using e.g. `torch.reshape`, `torch.flatten` or masked indexing. - Options: - Fail outright. - - Fallback to a regular `torch.Tensor`. + - Fallback to a regular `torch.Tensor`. <-- Chose this one. - Fallback to a different tensor subclass that is identical in behaviour to `torch.Tensor` but carries the frozen `StreamMetadata` along. +- Support loading/saving of named tensors by custom `__reduce__` or `__reduce_ex__`. +- How do we deal with + - Special tokens concatenated to the input? E.g. "translate" and "language" tokens in Whisper? + - Learnable tokens concatenated to the input sequence before an MHSA layer? ## Can we use DreamStream for training? @@ -185,6 +189,11 @@ pip install -r requirements.txt ``` +## Run tests +```bash +pytest -sv --cov=dreamstream --cov-report=term -p no:pytest_wampy tests/test_tensor.py::TestUnbind +``` + ```python diff --git a/doc_scrape/lists/default-valid-pointwise-ops-2023_06_19-16_13_42.txt b/doc_scrape/lists/default-valid-pointwise-ops-2023_06_19-16_13_42.txt new file mode 100644 index 0000000..69d5b6c --- /dev/null +++ b/doc_scrape/lists/default-valid-pointwise-ops-2023_06_19-16_13_42.txt @@ -0,0 +1,166 @@ +abs +Tensor.abs +absolute +Tensor.absolute +acos +Tensor.acos +arccos +Tensor.arccos +acosh +Tensor.acosh +arccosh +Tensor.arccosh +add +Tensor.add +addcdiv +Tensor.addcdiv +addcmul +Tensor.addcmul +angle +Tensor.angle +asin +Tensor.asin +arcsin +Tensor.arcsin +asinh +Tensor.asinh +arcsinh +Tensor.arcsinh +atan +Tensor.atan +arctan +Tensor.arctan +atanh +Tensor.atanh +arctanh +Tensor.arctanh +atan2 +Tensor.atan2 +arctan2 +Tensor.arctan2 +bitwise_not +Tensor.bitwise_not +ceil +Tensor.ceil +clamp +Tensor.clamp +clip +Tensor.clip +conj_physical +Tensor.conj_physical +cos +Tensor.cos +cosh +Tensor.cosh +deg2rad +Tensor.deg2rad +div +Tensor.div +divide +Tensor.divide +digamma +Tensor.digamma +erf +Tensor.erf +erfc +Tensor.erfc +erfinv +Tensor.erfinv +exp +Tensor.exp +expm1 +Tensor.expm1 +fix +Tensor.fix +float_power +Tensor.float_power +floor +Tensor.floor +floor_divide +Tensor.floor_divide +frac +Tensor.frac +ldexp +Tensor.ldexp +lgamma +Tensor.lgamma +log +Tensor.log +log10 +Tensor.log10 +log1p +Tensor.log1p +log2 +Tensor.log2 +logaddexp +Tensor.logaddexp +logaddexp2 +Tensor.logaddexp2 +logical_and +Tensor.logical_and +logical_not +Tensor.logical_not +logical_or +Tensor.logical_or +logical_xor +Tensor.logical_xor +hypot +Tensor.hypot +i0 +Tensor.i0 +igamma +Tensor.igamma +igammac +Tensor.igammac +mul +Tensor.mul +multiply +Tensor.multiply +neg +Tensor.neg +negative +Tensor.negative +nextafter +Tensor.nextafter +polygamma +Tensor.polygamma +positive +Tensor.positive +pow +Tensor.pow +rad2deg +Tensor.rad2deg +reciprocal +Tensor.reciprocal +round +Tensor.round +rsqrt +Tensor.rsqrt +sigmoid +Tensor.sigmoid +sign +Tensor.sign +signbit +Tensor.signbit +sin +Tensor.sin +sinh +Tensor.sinh +softmax +Tensor.softmax +sqrt +Tensor.sqrt +square +Tensor.square +sub +Tensor.sub +subtract +Tensor.subtract +tan +Tensor.tan +tanh +Tensor.tanh +true_divide +Tensor.true_divide +trunc +Tensor.trunc \ No newline at end of file diff --git a/doc_scrape/lists/inplace-recouple-pointwise-ops-2023_06_19-16_13_42.txt b/doc_scrape/lists/inplace-recouple-pointwise-ops-2023_06_19-16_13_42.txt new file mode 100644 index 0000000..24b6836 --- /dev/null +++ b/doc_scrape/lists/inplace-recouple-pointwise-ops-2023_06_19-16_13_42.txt @@ -0,0 +1,23 @@ +Tensor.bitwise_and_ +Tensor.bitwise_or_ +Tensor.bitwise_xor_ +Tensor.bitwise_left_shift_ +Tensor.bitwise_right_shift_ +conj_physical_ +Tensor.conj_physical_ +Tensor.copysign_ +exp2_ +Tensor.exp2_ +Tensor.fmod_ +Tensor.lerp_ +logit_ +Tensor.logit_ +Tensor.mvlgamma_ +nan_to_num_ +Tensor.nan_to_num_ +Tensor.remainder_ +Tensor.sgn_ +sinc_ +Tensor.sinc_ +xlogy_ +Tensor.xlogy_ \ No newline at end of file diff --git a/doc_scrape/lists/recouple-pointwise-ops-2023_06_19-16_13_42.txt b/doc_scrape/lists/recouple-pointwise-ops-2023_06_19-16_13_42.txt new file mode 100644 index 0000000..a2fef71 --- /dev/null +++ b/doc_scrape/lists/recouple-pointwise-ops-2023_06_19-16_13_42.txt @@ -0,0 +1,32 @@ +bitwise_and +Tensor.bitwise_and +bitwise_or +Tensor.bitwise_or +bitwise_xor +Tensor.bitwise_xor +bitwise_left_shift +Tensor.bitwise_left_shift +bitwise_right_shift +Tensor.bitwise_right_shift +copysign +Tensor.copysign +exp2 +Tensor.exp2 +fmod +Tensor.fmod +lerp +Tensor.lerp +logit +Tensor.logit +mvlgamma +Tensor.mvlgamma +nan_to_num +Tensor.nan_to_num +remainder +Tensor.remainder +sgn +Tensor.sgn +sinc +Tensor.sinc +xlogy +Tensor.xlogy \ No newline at end of file diff --git a/doc_scrape/lists/valid-pointwise-ops-2023_06_19-16_13_42.txt b/doc_scrape/lists/valid-pointwise-ops-2023_06_19-16_13_42.txt new file mode 100644 index 0000000..6fff35d --- /dev/null +++ b/doc_scrape/lists/valid-pointwise-ops-2023_06_19-16_13_42.txt @@ -0,0 +1,122 @@ +abs_ +Tensor.abs_ +Tensor.absolute_ +acos_ +Tensor.acos_ +arccos_ +Tensor.arccos_ +acosh_ +Tensor.acosh_ +arccosh_ +Tensor.arccosh_ +Tensor.add_ +Tensor.addcdiv_ +Tensor.addcmul_ +asin_ +Tensor.asin_ +arcsin_ +Tensor.arcsin_ +asinh_ +Tensor.asinh_ +arcsinh_ +Tensor.arcsinh_ +atan_ +Tensor.atan_ +arctan_ +Tensor.arctan_ +atanh_ +Tensor.atanh_ +arctanh_ +Tensor.arctanh_ +Tensor.atan2_ +Tensor.arctan2_ +Tensor.bitwise_not_ +ceil_ +Tensor.ceil_ +clamp_ +Tensor.clamp_ +clip_ +Tensor.clip_ +cos_ +Tensor.cos_ +cosh_ +Tensor.cosh_ +deg2rad_ +Tensor.deg2rad_ +Tensor.div_ +Tensor.divide_ +Tensor.digamma_ +erf_ +Tensor.erf_ +erfc_ +Tensor.erfc_ +Tensor.erfinv_ +exp_ +Tensor.exp_ +expm1_ +Tensor.expm1_ +fix_ +Tensor.fix_ +Tensor.float_power_ +floor_ +Tensor.floor_ +Tensor.floor_divide_ +frac_ +Tensor.frac_ +ldexp_ +Tensor.ldexp_ +Tensor.lgamma_ +log_ +Tensor.log_ +log10_ +Tensor.log10_ +log1p_ +Tensor.log1p_ +log2_ +Tensor.log2_ +Tensor.logical_and_ +Tensor.logical_not_ +Tensor.logical_or_ +Tensor.logical_xor_ +Tensor.hypot_ +i0_ +Tensor.i0_ +Tensor.igamma_ +Tensor.igammac_ +Tensor.mul_ +Tensor.multiply_ +neg_ +Tensor.neg_ +negative_ +Tensor.negative_ +Tensor.nextafter_ +Tensor.polygamma_ +Tensor.pow_ +rad2deg_ +Tensor.rad2deg_ +reciprocal_ +Tensor.reciprocal_ +round_ +Tensor.round_ +rsqrt_ +Tensor.rsqrt_ +sigmoid_ +Tensor.sigmoid_ +Tensor.sign_ +sin_ +Tensor.sin_ +sinh_ +Tensor.sinh_ +sqrt_ +Tensor.sqrt_ +square_ +Tensor.square_ +Tensor.sub_ +Tensor.subtract_ +tan_ +Tensor.tan_ +tanh_ +Tensor.tanh_ +Tensor.true_divide_ +trunc_ +Tensor.trunc_ \ No newline at end of file diff --git a/doc_scrape/pointwise-ops.py b/doc_scrape/pointwise-ops.py new file mode 100644 index 0000000..45f4456 --- /dev/null +++ b/doc_scrape/pointwise-ops.py @@ -0,0 +1,198 @@ +from glob import glob +import urllib.request +from datetime import datetime +from copy import deepcopy + +import torch +from bs4 import BeautifulSoup +from tqdm import tqdm + +# from dreamstream.utils.dummies import TestTensor +from dreamstream.utils.listloaders import get_tensor_attr +from dreamstream.tensor import recouple, inplace_recouple, TestTensor + +# from ..tests import test_dict + + +tensor = torch.tensor( + [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]], [[13, 14, 15], [16, 17, 18]]], dtype=torch.float32 +) +# tensor = torch.tensor([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]], [[12, 13, 14], [15, 16, 17]]]) +tensors = [tensor.clone() for i in range(3)] + + +def to_test(tensor): + return TestTensor(tensor.rename("A", "B", "C"), meta="test") + + +class Inputs: + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + def __iter__(self): + return iter((self.args, self.kwargs)) + + +def valid(func, *args, **kwargs): + return func(*args, **kwargs) + + +def default_valid(func, *args, **kwargs): + out = func(*args, **kwargs) + + metas = [x.meta for x in [*args, *kwargs.values()] if isinstance(x, TestTensor)] + if not all(s == metas[0] for s in metas[1:]): + msg = ( + f"Called a torch function ({func.__name__}) which was not handled by " + f"StreamTensor.__torch_function__ with {len(metas)} StreamTensors in the input." + f"In this case the function can only be handled if the StreamTensors have equal metadata," + f"but they were not equal." + ) + raise RuntimeError(msg) + + if isinstance(out, TestTensor): + out.meta = metas[0] + return out + elif isinstance(out, torch.Tensor): + return TestTensor(out, meta=metas[0]) + + return out + + +def compare_tensors(test_out, out): + nan_filter = torch.isnan(out) + out = out[~nan_filter] + + assert out.numel() > 0 + assert isinstance(test_out, TestTensor) + assert hasattr(test_out, "meta") + assert test_out.meta == "test" + assert torch.allclose(torch.Tensor(test_out).rename(None)[~nan_filter], out) + + +inplace_func_hierarchy = [valid, inplace_recouple] +outofplace_func_hierarchy = [valid, recouple] + +fp = urllib.request.urlopen("https://pytorch.org/docs/stable/torch.html") +html_doc = fp.read().decode("utf8") +fp.close() + +soup = BeautifulSoup(html_doc, "html.parser") + +# scrape and validate pointwise ops +pwo_section = soup.find("section", {"id": "pointwise-ops"}) +pwo_rows = pwo_section.find_all("tr") +pwo_names = [r.find("td").text for r in pwo_rows] + +func_names = [] +for n in pwo_names: + if hasattr(torch, n): + func_names.append(f"{n}") + if hasattr(torch.Tensor, n): + func_names.append(f"Tensor.{n}") + if hasattr(torch, n + "_"): + func_names.append(f"{n}_") + if hasattr(torch.Tensor, n + "_"): + func_names.append(f"Tensor.{n}_") + +no_valid_input_found = [] +no_valid_output = [] +valid_funcs = [] +default_valid_funcs = [] +recouple_funcs = [] +recouple_inplace_funcs = [] + +for func_name in tqdm(func_names): + # these are handled manually (either decouple or customized) + if "quantize" in func_name: + continue + if func_name.endswith("real") or func_name.endswith("imag"): + continue + if func_name.endswith("frexp"): # or func_name.endswith("ldexp"): + continue + if func_name.endswith("gradient"): + continue + + is_inplace = func_name.endswith("_") + func = get_tensor_attr(func_name) + input_valid = False + + for i in range(1, 4): + try: + inputs = tensors[:i] + + if "bitwise" in func_name: + inputs = [x.to(torch.int64) for x in inputs] + if func_name.endswith("softmax"): + inputs += [-1] + if ("mvlgamma" in func_name) or ("Tensor.polygamma" in func_name): + inputs += [1] + if func_name == "polygamma": + inputs = [1] + inputs + if func_name.endswith("float_power_"): + inputs = [x.to(torch.float64) for x in inputs] + + target_inputs = [x.clone() if isinstance(x, torch.Tensor) else x for x in inputs] if is_inplace else inputs + out = func(*target_inputs) + input_valid = True + break + + except Exception: + if i == 3: + no_valid_input_found.append(func_name) + + if input_valid: + try: + if is_inplace: + test_inputs = [to_test(x.clone()) if isinstance(x, torch.Tensor) else deepcopy(x) for x in inputs] + test_out = valid(func, *test_inputs) + compare_tensors(test_out, out) + valid_funcs.append(func_name) + else: + test_inputs = [to_test(x) if isinstance(x, torch.Tensor) else x for x in inputs] + test_out = default_valid(func, *test_inputs) + compare_tensors(test_out, out) + default_valid_funcs.append(func_name) + continue + except Exception: + pass + + try: + if is_inplace: + test_inputs = [to_test(x.clone()) if isinstance(x, torch.Tensor) else deepcopy(x) for x in inputs] + test_out = inplace_recouple(func, *test_inputs, _tensor_type=TestTensor) + compare_tensors(test_out, out) + recouple_inplace_funcs.append(func_name) + else: + test_inputs = [to_test(x) if isinstance(x, torch.Tensor) else x for x in inputs] + test_out = recouple(func, *test_inputs, _tensor_type=TestTensor) + compare_tensors(test_out, out) + recouple_funcs.append(func_name) + continue + except Exception: + pass + + no_valid_output.append(func_name) + +print(f"\n\nNumber of valid funcs: {len(valid_funcs)}") +print(f"Number of default valid funcs: {len(default_valid_funcs)}") +print(f"Number of recouple funcs: {len(recouple_funcs)}") +print(f"Number of inplace recouple funcs: {len(recouple_inplace_funcs)}") +print(f"Number of funcs wo/ valid input: {len(no_valid_input_found)}") +print(f"Number of funcs wo/ valid output: {len(no_valid_output)}") + + +if len(glob("lists/*pointwise-ops*.txt")) > 0: + raise Exception( + "lists/*pointwise-ops*.txt already exists. Please delete this/these file(s) before running this script." + ) +scrape_date = datetime.now().strftime("%Y_%m_%d-%H_%M_%S") +with open(f"lists/default-valid-pointwise-ops-{scrape_date}.txt", "w") as file_buffer: + file_buffer.write("\n".join(default_valid_funcs)) +with open(f"lists/valid-pointwise-ops-{scrape_date}.txt", "w") as file_buffer: + file_buffer.write("\n".join(valid_funcs)) +with open(f"lists/recouple-pointwise-ops-{scrape_date}.txt", "w") as file_buffer: + file_buffer.write("\n".join(recouple_funcs)) +with open(f"lists/inplace-recouple-pointwise-ops-{scrape_date}.txt", "w") as file_buffer: + file_buffer.write("\n".join(recouple_inplace_funcs)) diff --git a/dreamstream/__init__.py b/dreamstream/__init__.py index dacc0a9..2a1ded7 100644 --- a/dreamstream/__init__.py +++ b/dreamstream/__init__.py @@ -1,4 +1,6 @@ import dreamstream.overrides # noqa: F401 import dreamstream.random # noqa: F401 -from dreamstream.tensor import StreamMetadata, StreamTensor, as_stream_tensor, stream_tensor # noqa: F401 +from dreamstream.patches import patch, patch_module, add_streaming_modes, patch_conv_1d # noqa: F401 +from dreamstream.tensor import StreamTensor, StreamMetadata, stream_tensor, as_stream_tensor # noqa: F401 +from dreamstream.warnings import suppress_warnings # noqa: F401 diff --git a/dreamstream/data/data_objects.py b/dreamstream/data/data_objects.py index fea76e0..4509199 100644 --- a/dreamstream/data/data_objects.py +++ b/dreamstream/data/data_objects.py @@ -100,8 +100,11 @@ def num_chunks(self): class OutputCollector(dict): - def __init__(self, *stream_tensor: StreamTensor): + def __init__(self, *stream_tensor: StreamTensor, collection: str = "cat"): super().__init__() + + update_unaries = dict(cat=self._update_unary_tensor_cat, append=self._update_unary_list_append) + self._update_unary = update_unaries[collection] self.closed_entries = set() self.update(*stream_tensor) @@ -109,13 +112,13 @@ def update(self, *stream_tensors: Tuple[StreamTensor]): for t in stream_tensors: if BATCH in t.names: for x in t.unpad_sequence(): - if x.meta.max_length > 0: + if x.meta.max_length > 0: # NOTE (JDH): Will error out when the first chunk output has length == 0. self._update_unary(x) else: assert t.meta.size() == 1, "The tensor has no batch dimension, but metadata has multiple elements." self._update_unary(t) - def _update_unary(self, stream_tensor: StreamTensor): + def _update_unary_tensor_cat(self, stream_tensor: StreamTensor): _id = stream_tensor.meta.ids[0] if _id in self.closed_entries: @@ -124,9 +127,32 @@ def _update_unary(self, stream_tensor: StreamTensor): self.closed_entries.add(_id) if _id in self: - assert not stream_tensor.meta.sos.item(), "The tensor is the first chunk." + if stream_tensor.meta.sos.item(): + raise ValueError(f"Tried to collect a chunk on a known id {_id} but got a chunk claiming to be first.") + length_dim = stream_tensor.names.index(LENGTH) self[_id] = torch.cat([self[_id], stream_tensor], dim=length_dim) else: - assert stream_tensor.meta.sos.item(), "The tensor is not the first chunk." + if not stream_tensor.meta.sos.item(): + raise ValueError(f"Tried to start a new id `{_id}` but the received chunk was not the first chunk.") + self[_id] = stream_tensor + + def _update_unary_list_append(self, stream_tensor: StreamTensor): + _id = stream_tensor.meta.ids[0] + + if _id in self.closed_entries: + raise ValueError(f"The entry for {_id} has already been closed.") + if stream_tensor.meta.eos.item(): + self.closed_entries.add(_id) + + if _id in self: + if stream_tensor.meta.sos.item(): + raise ValueError(f"Tried to collect a chunk on a known id {_id} but got a chunk claiming to be first.") + + self[_id].append(stream_tensor) + else: + if not stream_tensor.meta.sos.item(): + raise ValueError(f"Tried to start a new id `{_id}` but the received chunk was not the first chunk.") + + self[_id] = [stream_tensor] diff --git a/dreamstream/data/stream_dataset.py b/dreamstream/data/stream_dataset.py index 4be2295..fe02786 100644 --- a/dreamstream/data/stream_dataset.py +++ b/dreamstream/data/stream_dataset.py @@ -323,7 +323,7 @@ def continue_buffering( num_current_buffered = sum(1 for id in active_ids if id in buffers and buffers[id] and buffers[id][-1].eos) num_next_buffered = sum(1 for id in buffers.keys() if id not in active_ids and buffers[id] and buffers[id][-1].eos) - return num_next_buffered < (batch_size - num_current_buffered) or num_current_buffered < num_current_running + return num_current_buffered < num_current_running or num_next_buffered < (batch_size - num_current_running) def make_streams_synchronous( @@ -382,12 +382,13 @@ def __init__( batch_size: int = 1, shuffle: bool = False, num_workers: int = 0, - non_overlapping_batches: bool = False, + overlapping_batches: bool = False, collate_fn: Optional[Callable] = None, pin_memory: bool = False, drop_last: bool = False, timeout: float = 0, worker_init_fn: Optional[Callable] = None, + multiprocessing_context: Any | None = None, prefetch_factor: Optional[int] = None, persistent_workers: bool = False, pin_memory_device: Optional[torch.device] = "", @@ -402,12 +403,15 @@ def __init__( Args: dataset (Union[IterableDataset, List[IterableDataset]]): A single or a list of IterableDataset instances. batch_size (int): The batch size. + shuffle (bool): Whether to shuffle the dataset between epochs. num_workers (int, optional): When a single dataset is given, it is split into `num_workers` subsets and - each to be processed in parallel by a single worker. When a list of datasets is given, `num_workers` - has no effect. Defaults to 0. - non_overlapping_batches (bool, optional): If True, the batches are non-overlapping. In this case, new files - are started only once every file in the batch has ended. If False, a new file is started as soon as a - file in the previous batch ended. Defaults to False. + each to be processed in parallel by a DataLoader using a single worker. When a list of datasets is + given, `num_workers` has no effect. Defaults to 0. + overlapping_batches (bool, optional): If True, each new batch will introduce new files to replace those that + ended in the previous batch, keeping the batch size constant (except for the last batch if + `drop_last=False`). If False, new files will only start once all files from the previous batch have + ended. This is usually more memory and compute efficient for state management in models in online mode, + but also leads to more batches. Defaults to False. drop_last (bool, optional): Drop the last batch(es) if smaller than `batch_size`. Defaults to False. collate_fn (Callable, optional): A function that takes a list of batch parts and collates them into a batch. Defaults to None. @@ -416,14 +420,17 @@ def __init__( self.batch_size = batch_size self.shuffle = shuffle self.num_workers = num_workers - self.non_overlapping_batches = non_overlapping_batches + self.overlapping_batches = overlapping_batches self.drop_last = drop_last self.worker_init_fn = worker_init_fn + self.collate_fn = collate_fn - if drop_last and non_overlapping_batches: - raise ValueError("Only one of `drop_last` and `non_overlapping_batches` can be True.") + if drop_last and not overlapping_batches: + raise ValueError("`drop_last` can only be used when overlapping batches.") if isinstance(dataset, list): + if num_workers > 0: + raise ValueError("When a list of datasets is given, `num_workers` has no effect.") self.num_workers = len(dataset) if collate_fn is None: @@ -432,13 +439,13 @@ def __init__( else: collate_fn = torch.utils.data._utils.collate.default_collate - self.collate_fn = collate_fn - self.dataloader_kwargs = dict( batch_size=None, + collate_fn=None, num_workers=(0 if self.num_workers == 0 else 1), pin_memory=pin_memory, timeout=timeout, + multiprocessing_context=multiprocessing_context, prefetch_factor=(prefetch_factor if num_workers > 0 else None), persistent_workers=persistent_workers, pin_memory_device=pin_memory_device, @@ -446,7 +453,7 @@ def __init__( def _get_worker_init_fn(self, actual_worker_id: int): """Wrap worker_init_fn to change the input `worker_id` for each worker. This is necessary because we - use `num_workers` independent DataLoaders instead of one DataLoader with `num_workers` workers.""" + use `num_workers` independent DataLoaders with one worker each instead of one with `num_workers` workers.""" if self.worker_init_fn is None: return None @@ -458,6 +465,9 @@ def wrapped_worker_init_fn(dataloader_worker_id: int): return wrapped_worker_init_fn def get_stream_loaders(self) -> Generator[Any, None, None]: + # TODO (JDH): Wrap this method in a dedicated process to offload all data preparation from main process. + # Currently this happens in the main proces and may take a bit of time, especially the collation. + if isinstance(self.dataset, IterableDataset): num_workers = max(1, self.num_workers) datasets = self.dataset.split( @@ -488,7 +498,7 @@ def get_stream_loaders(self) -> Generator[Any, None, None]: stream_loader = (list(itertools.chain.from_iterable(batch_parts)) for batch_parts in stream_loaders) # Wait for all files being streamed to finish before starting the next set of files. - if self.non_overlapping_batches: + if not self.overlapping_batches: stream_loader = make_streams_synchronous(stream_loader) # Collate batches @@ -500,10 +510,6 @@ def __iter__(self) -> Generator[Any, None, None]: yield batch -# TODO (JDH): Wrap collation in a dedicated process to offload all data preparation from main process. Currently -# collation happens in the main proces. - - class MultiStreamOneProcessDataLoader: def __init__( self, @@ -511,7 +517,7 @@ def __init__( batch_size: int = 1, shuffle: bool = False, num_workers: int = 0, - non_overlapping_batches: bool = False, + overlapping_batches: bool = False, collate_fn: Optional[Callable] = None, pin_memory: bool = False, drop_last: bool = False, @@ -537,7 +543,7 @@ def __init__( num_workers (int, optional): When a single dataset is given, it is split into `num_workers` subsets and each to be processed in parallel by a single worker. When a list of datasets is given, `num_workers` has no effect. Defaults to 0. - non_overlapping_batches (bool, optional): If True, the batches are non-overlapping. In this case, new files + overlapping_batches (bool, optional): If True, the batches are non-overlapping. In this case, new files are started only once every file in the batch has ended. If False, a new file is started as soon as a file in the previous batch ended. Defaults to False. drop_last (bool, optional): Drop the last batch(es) if smaller than `batch_size`. Defaults to False. @@ -552,7 +558,7 @@ def __init__( batch_size=batch_size, shuffle=shuffle, num_workers=0, - non_overlapping_batches=non_overlapping_batches, + overlapping_batches=overlapping_batches, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, diff --git a/dreamstream/func_coverage.py b/dreamstream/func_coverage.py index 9bbe44c..904e973 100644 --- a/dreamstream/func_coverage.py +++ b/dreamstream/func_coverage.py @@ -1,25 +1,95 @@ import torch +from dreamstream.utils.listloaders import ( + load_default_valid_pointwise_ops, + load_valid_pointwise_ops, + load_recouple_pointwise_ops, + load_inplace_recouple_pointwise_ops, +) + FLAT_OVERRIDABLE_FUNCTIONS = {f for k, fs in torch.overrides.get_overridable_functions().items() for f in fs} # Functions that must be overridden to handle StreamMetadata and limit use cases. -OVERRIDDEN_FUNCTIONS = dict() +CUSTOMIZED_FUNCTIONS = dict() + +GET_METHODS = {f for f in FLAT_OVERRIDABLE_FUNCTIONS if f.__name__ == "__get__"} + +DEFAULT_VALID_POITWISE_OPS = load_default_valid_pointwise_ops() +VALID_POITWISE_OPS = load_valid_pointwise_ops() +RECOUPLE_POITWISE_OPS = load_recouple_pointwise_ops() +INPLACE_RECOUPLE_POITWISE_OPS = load_inplace_recouple_pointwise_ops() # Functions that work for StreamTensors using the super().__torch_function__. VALID_FUNCTIONS = { torch.Tensor.__repr__, torch.Tensor.__str__, + torch.Tensor.__dir__, + torch.Tensor.size, + torch.equal, + torch.Tensor.rename_, + torch.Tensor.zero_, + torch.Tensor.fill_, + torch.Tensor.bernoulli_, + torch.Tensor.cauchy_, + torch.Tensor.exponential_, + torch.Tensor.geometric_, + torch.Tensor.log_normal_, + torch.Tensor.normal_, + torch.Tensor.random_, + torch.Tensor.uniform_, + # torch.is_tensor and torch.is_storage are not overridable + torch.is_complex, + torch.Tensor.is_complex, + torch.is_conj, + torch.Tensor.is_conj, + torch.is_floating_point, + torch.Tensor.is_floating_point, + torch.is_nonzero, + torch.Tensor.is_nonzero, + torch.numel, + torch.Tensor.numel, + torch.Tensor.dim, + torch.Tensor.imag, + torch.Tensor.real, +} +VALID_FUNCTIONS.update(GET_METHODS) +VALID_FUNCTIONS.update(VALID_POITWISE_OPS) + +# Functions that work for StreamTensors using the super().__torch_function__, but .meta is not preserved. +DEFAULT_VALID_FUNCTIONS = { + torch.Tensor.align_to, + torch.transpose, + torch.Tensor.transpose, + torch.Tensor.rename, + torch.clone, # meta is not preserved + torch.Tensor.clone, # meta is not preserved + torch.rand_like, + torch.randn_like, + torch.zeros_like, + torch.ones_like, + torch.empty_like, + torch.full_like, + torch.Tensor.__abs__, + torch.Tensor.__neg__, + torch.Tensor.__add__, } +DEFAULT_VALID_FUNCTIONS.update(DEFAULT_VALID_POITWISE_OPS) -# Functions that must be wrapped to avoid returning a StreamTensor. +# Functions that must be wrapped to avoid returning a StreamTensor (and may not). DECOUPLE_FUNCTIONS = { torch.argmax, torch.argmin, torch.argsort, + torch.allclose, + torch.Tensor.allclose, + torch.quantize_per_channel, # output "type" is not supported with names + torch.fake_quantize_per_channel_affine, # output "type" is not supported with names + torch.fake_quantize_per_tensor_affine, # output "type" is not supported with names + torch.gradient, # TODO: double check if decouple is the right category for torch.gradient } -# Functions that must be wrapped to avoid failures related to named tensors and to maintain StreamMetadata. +# Functions that must be wrapped to avoid failures related to named tensors (and to maintain StreamMetadata). RECOUPLE_FUNCTIONS = { torch.adjoint, torch.sigmoid, @@ -36,10 +106,35 @@ torch.Tensor.select_scatter, torch.slice_scatter, torch.Tensor.slice_scatter, + torch.isclose, + torch.Tensor.isclose, + torch.randint_like, + torch.heaviside, + torch.Tensor.heaviside, + torch.dequantize, + torch.Tensor.dequantize, + torch.polar, + torch.complex, + torch.real, + torch.imag, +} +RECOUPLE_FUNCTIONS.update(RECOUPLE_POITWISE_OPS) + +INPLACE_RECOUPLE_FUNCTIONS = { + torch.Tensor.heaviside_, } +INPLACE_RECOUPLE_FUNCTIONS.update(INPLACE_RECOUPLE_POITWISE_OPS) # The full set of functions that are explicitly supported for StreamTensors. -SUPPORTED_FUNCTIONS = VALID_FUNCTIONS | DECOUPLE_FUNCTIONS | RECOUPLE_FUNCTIONS | OVERRIDDEN_FUNCTIONS.keys() +SUPPORTED_FUNCTIONS = ( + VALID_FUNCTIONS + | DEFAULT_VALID_FUNCTIONS + | DECOUPLE_FUNCTIONS + | RECOUPLE_FUNCTIONS + | INPLACE_RECOUPLE_FUNCTIONS + | CUSTOMIZED_FUNCTIONS.keys() +) # The set of functions that are overrideable but not explicitly supported. These functions may still work correctly. UNSUPPORTED_FUNCTIONS = {f for f in FLAT_OVERRIDABLE_FUNCTIONS if f not in SUPPORTED_FUNCTIONS} +SUPPORTED_NON_OVERRIDEABLE_FUNCTIONS = {f for f in SUPPORTED_FUNCTIONS if f not in FLAT_OVERRIDABLE_FUNCTIONS} diff --git a/dreamstream/nn/utils/__init__.py b/dreamstream/nn/utils/__init__.py index bf297a1..8083b8f 100644 --- a/dreamstream/nn/utils/__init__.py +++ b/dreamstream/nn/utils/__init__.py @@ -1 +1 @@ -from .pad_sequence import * # noqa: F403 +from dreamstream.nn.utils.pad_sequence import pad_chunks, pad_full_sequence, pad_stream_tensor # noqa: F401 diff --git a/dreamstream/nn/utils/pad_sequence.py b/dreamstream/nn/utils/pad_sequence.py index 1e80c05..14a451b 100644 --- a/dreamstream/nn/utils/pad_sequence.py +++ b/dreamstream/nn/utils/pad_sequence.py @@ -36,14 +36,12 @@ def pad_chunks( tensor = tensor.rename(*names) meta = StreamMetadata(ids=ids, sos=sos, eos=eos, lengths=lengths) - return StreamTensor(data=tensor, meta=meta) def pad_full_sequence(sequences, names: List[str], ids: List[str], batch_first: bool = False) -> StreamTensor: sos = torch.full((len(sequences),), True, dtype=torch.bool) eos = torch.full((len(sequences),), True, dtype=torch.bool) - return pad_chunks(sequences=sequences, names=names, ids=ids, sos=sos, eos=eos, batch_first=batch_first) diff --git a/dreamstream/overrides.py b/dreamstream/overrides.py index 7c8d7f0..c0f2fae 100644 --- a/dreamstream/overrides.py +++ b/dreamstream/overrides.py @@ -1,6 +1,6 @@ import functools from copy import deepcopy -from typing import Any, Callable, NamedTuple, Optional, List, Sequence, Tuple, Union +from typing import Any, Callable, Dict, NamedTuple, Optional, List, Sequence, Tuple, Union from torch.types import Number import numpy as np @@ -9,8 +9,9 @@ from torch import Tensor from dreamstream.tensor import StreamTensor, StreamMetadata, decouple_recursive -from dreamstream.func_coverage import OVERRIDDEN_FUNCTIONS +from dreamstream.func_coverage import CUSTOMIZED_FUNCTIONS from dreamstream.utils.flags import BATCH, LENGTH +from dreamstream.utils.operations import sequence_mask from dreamstream.warnings import fallback_operation_warning @@ -42,7 +43,7 @@ def implements(torch_function): def decorator(func): functools.update_wrapper(func, torch_function, assigned=WRAPPER_ASSIGNMENTS) func.__doc__ = augment_documentation(func.__doc__, torch_function.__doc__) - OVERRIDDEN_FUNCTIONS[torch_function] = func + CUSTOMIZED_FUNCTIONS[torch_function] = func return func return decorator @@ -62,8 +63,8 @@ def cat(tensors: List[Union[StreamTensor, Tensor]], dim=0, *, out=None): is_batch_dim = [t.is_batch_dim(dim) for t in tensors if isinstance(t, StreamTensor)] # TODO (JDH): Speed this up if any(is_batch_dim): require_all_stream_tensors(tensors, "Cannot concatenate StreamTensor and torch.Tensor along batch dimension.") - tensors = [t.named_tensor() for t in tensors] - tensor = torch.cat(tensors, dim=dim, out=out) + torch_tensors = [t.named_tensor() for t in tensors] + tensor = torch.cat(torch_tensors, dim=dim, out=out) meta = StreamMetadata.cat_batch([t.meta for t in tensors]) # TODO (JDH): Make this lazily evaluated. return StreamTensor(tensor, meta) @@ -72,11 +73,11 @@ def cat(tensors: List[Union[StreamTensor, Tensor]], dim=0, *, out=None): if any(is_length_dim): for t in tensors[:-1]: if isinstance(t, StreamTensor) and t.meta.lengths.min() < t.max_length(): - raise NotImplementedError("Only the last tensor can be padded when concatenating along length.") - tensors = [t.named_tensor() if isinstance(t, StreamTensor) else t for t in tensors] - tensor = torch.cat(tensors, dim=dim, out=out) + raise NotImplementedError("Only the right-most input can be padded when concatenating along length.") + torch_tensors = [t.named_tensor() if isinstance(t, StreamTensor) else t for t in tensors] + tensor = torch.cat(torch_tensors, dim=dim, out=out) meta = StreamMetadata.cat_length([t.meta for t in tensors if isinstance(t, StreamTensor)]) - meta.lengths += sum([t.size(dim) for t in tensors if not isinstance(t, StreamTensor)]) + meta.lengths += sum([t.size(dim) for t in tensors if not isinstance(t, StreamTensor)]) # Add torch.Tensors return StreamTensor(tensor, meta) # Concatenation along a dimension that is neither batch nor length. @@ -89,6 +90,44 @@ def cat(tensors: List[Union[StreamTensor, Tensor]], dim=0, *, out=None): return StreamTensor(tensor, meta) +@implements(torch.stack) +def stack(tensors: List[Union[StreamTensor, Tensor]], dim=0, *, out=None): + """If dim is the batch dimension of any StreamTensor, assert all are StreamTensors and stack the stream states as + well. Else, call torch.stack. + """ + if len(tensors) == 1: + return tensors[0].unsqueeze(dim) + + # If all StreamTensors have the same metas, then just stack the tensors. + metas = [t.meta for t in tensors if isinstance(t, StreamTensor)] + if all(m == metas[0] for m in metas): + torch_tensors = [t.tensor() if isinstance(t, StreamTensor) else t for t in tensors] + tensor = torch.stack(torch_tensors, dim=dim, out=out) + names = tensors[0].names + names = names[:dim] + (None,) + names[dim:] + return StreamTensor(tensor.rename_(*names), metas[0]) + + # If not all StreamTensors have the same metas, then we can only stack them if + # 1, They all have exactly one id + # 2. All StreamTensors have different ids. + # 3. All tensors are StreamTensors. + if not all(isinstance(t, StreamTensor) for t in tensors): + raise ValueError("Cannot stack StreamTensors with different stream states together with torch.Tensors.") + + if not all(len(t.meta.ids) == 1 for t in tensors): + raise ValueError("Can only stack StreamTensors with different ids together when they each have one id.") + + if len(set(t.meta.ids[0] for t in tensors)) != len(tensors): + raise ValueError("Can only stack StreamTensors with different ids together when they each have different ids.") + + # If we get here, then each StreamTensor has one id and all have different ids. + torch_tensors = [t.tensor() if isinstance(t, StreamTensor) else t for t in tensors] + tensor = torch.stack(torch_tensors, dim=dim, out=out) + names = tensors[0].names + names = names[:dim] + (None,) + names[dim:] + return StreamTensor(tensor.rename_(*names), metas[0]) + + @implements(torch.permute) def permute(tensor: StreamTensor, dims: List[int]): out = tensor.tensor().permute(*dims) @@ -153,7 +192,6 @@ def unbind(tensor: StreamTensor, dim=0) -> List[StreamTensor]: if tensor.names[dim] == BATCH: states = meta.unbind_batch() tensors = tensor.unbind(dim=dim) - assert len(tensors) == len(states) tensors = [StreamTensor(t, meta) for t, meta in zip(tensors, states)] else: tensors = tensor.unbind(dim=dim) @@ -162,13 +200,31 @@ def unbind(tensor: StreamTensor, dim=0) -> List[StreamTensor]: return tensors -# @implements(torch.nn.functional.pad) -# def pad(input: StreamTensor, pad: List[int], mode: str = "constant", value: float = None): -# raise NotImplementedError("pad is not currently supported for StreamTensors.") +@implements(torch.quantize_per_tensor) +def quantize_per_tensor(input: Tuple[StreamTensor], scale: float, zero_point: int, dtype: torch.dtype): + input = input.tensor() if isinstance(input, StreamTensor) else tuple([t.tensor() for t in input]) + return torch.quantize_per_tensor(input, scale, zero_point, dtype) -def _compute_conv_output_lengths(input_lengths: Tensor, kernel_width: int, stride: int): - return +@implements(torch.quantize_per_tensor_dynamic) +def quantize_per_tensor_dynamic(input: Tuple[StreamTensor], dtype: torch.dtype, reduce_range: bool): + input = input.tensor() if isinstance(input, StreamTensor) else tuple([t.tensor() for t in input]) + return torch.quantize_per_tensor_dynamic(input, dtype, reduce_range) + + +@implements(torch.fake_quantize_per_tensor_affine) +def fake_quantize_per_tensor_affine(input: StreamTensor, scale: float, zero_point: int, quant_min: int, quant_max: int): + return torch.fake_quantize_per_tensor_affine(input.tensor(), scale, zero_point, quant_min, quant_max) + + +@implements(torch.frexp) +@implements(torch.Tensor.frexp) +def frexp(input: StreamTensor): + tensor, meta, names = input.decouple() + out = torch.frexp(input.tensor()) + mantissa = StreamTensor(out.mantissa.rename(*names), meta=meta) + exponent = StreamTensor(out.exponent.rename(*names), meta=meta) + return torch.return_types.frexp((mantissa, exponent)) @implements(torch.conv1d) @@ -217,6 +273,7 @@ def conv1d( padding = 0 # Create buffer. + # TODO (JDH): Default to storing the batched input buffer. output_lengths = ((meta.lengths - kernel_width) // stride[0] + 1).clip(min=0) next_start = output_lengths * stride[0] buffer = {} @@ -224,13 +281,17 @@ def conv1d( for i, (start, end, _id, eos) in enumerate(zip(next_start, meta.lengths, meta.ids, meta.eos)): if not eos: buffer[_id] = input[i, ..., start:end] + meta._temp_buffer = buffer # Convolve input and revert to StreamTensor. output = torch.conv1d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) output.rename_(*names) meta.lengths = output_lengths - # TODO: Consider whether to zero out the padding. - return StreamTensor(output, meta), buffer + + # Zero out the padding. + mask = sequence_mask(output_lengths, max_len=output.size(-1), device=output.device) + output *= mask.unsqueeze(1) + return StreamTensor(output, meta) IntegerTensorType = Union[torch.ByteTensor, torch.CharTensor, torch.ShortTensor, torch.IntTensor, torch.LongTensor] @@ -496,10 +557,19 @@ def __getitem__(self: StreamTensor, indices: Union[IndexingType, Sequence[Indexi affected_dims = [affected_dims] indices = [indices] - batch_dim = names.index(BATCH) - length_dim = names.index(LENGTH) - is_batch_dim_affected = any([dim == batch_dim for dim in dims_affected_flat]) - is_length_dim_affected = any([dim == length_dim for dim in dims_affected_flat]) + try: + batch_dim = names.index(BATCH) + is_batch_dim_affected = any([dim == batch_dim for dim in dims_affected_flat]) + except ValueError: + batch_dim = None + is_batch_dim_affected = False + + try: + length_dim = names.index(LENGTH) + is_length_dim_affected = any([dim == length_dim for dim in dims_affected_flat]) + except ValueError: + length_dim = None + is_length_dim_affected = False if not (is_batch_dim_affected or is_length_dim_affected): # Indexing operation does not affect the batch or length dimensions, return the indexed tensor with same meta. @@ -690,9 +760,9 @@ def index_select(input: StreamTensor, dim: int, index: Tensor, *, out: Optional[ tensor, meta, names = input.decouple() out = torch.index_select(tensor, dim, index, out=out) - if dim == names.index(BATCH): + if BATCH in names and dim == names.index(BATCH): meta = meta[index] - elif dim == names.index(LENGTH): + elif LENGTH in names and dim == names.index(LENGTH): meta = meta[:, index] out.rename_(*names) @@ -894,3 +964,116 @@ def unqsqueeze(input: StreamTensor, dim: int) -> StreamTensor: # moving dimensions # X @implements(torch.transpose) # X @implements(torch.permute) + + +@implements(torch.Tensor.__reduce_ex__) +def __reduce_ex__(self: StreamTensor, proto): + print("OHI!") + self.rename_(None) + return torch.Tensor.__reduce_ex__(self, proto) + + +@implements(torch._VF._pack_padded_sequence) +def _pack_padded_sequence( + input: StreamTensor, + lengths: Tensor, + batch_first: bool = False, +) -> torch.nn.utils.rnn.PackedSequence: + """Decouple the StreamTensor input and remove the batch dimension before calling the original function.""" + tensor, meta, names = input.decouple() + data, batch_sizes = torch._VF._pack_padded_sequence(tensor, lengths, batch_first=batch_first) + names = names[1:] if batch_first else names[0] + names[2:] + data = StreamTensor(data, meta) + data.rename_(*names) + return data, batch_sizes + + +@implements(torch._VF._pad_packed_sequence) +def _pad_packed_sequence( + input_data: StreamTensor, + batch_sizes: torch.Tensor, + batch_first: bool = False, + padding_value: float = 0.0, + total_length: Optional[int] = None, +) -> Tuple[StreamTensor, Tensor]: + """Decouple the StreamTensor input before calling the original function then add back batch dimension and names.""" + # import IPython + # IPython.embed(using=False, header="pad_packed_sequence") + tensor, meta, names = input_data.decouple() + tensor, lengths = torch._VF._pad_packed_sequence(tensor, batch_sizes, batch_first, padding_value, total_length) + if BATCH in names: + names = (None,) + names if batch_first else names[0] + (None,) + names[1:] + else: + names = (BATCH,) + names if batch_first else names[0] + (BATCH,) + names[1:] + tensor.rename_(*names) + return StreamTensor(tensor, meta), lengths + + +@implements(torch._VF.rnn_tanh) +def rnn_tanh( + input: StreamTensor, + batch_sizes: Optional[torch.Tensor], + hx: torch.Tensor, + weights: torch.Tensor, + bias: torch.Tensor, + num_layers: int, + dropout: float, + training: bool, + bidirectional: bool, +) -> Tuple[StreamTensor, torch.Tensor]: + return rnn(input, batch_sizes, hx, weights, bias, num_layers, dropout, training, bidirectional, torch._VF.rnn_tanh) + + +@implements(torch._VF.rnn_relu) +def rnn_relu( + input: StreamTensor, + batch_sizes: Optional[torch.Tensor], + hx: torch.Tensor, + weights: torch.Tensor, + bias: torch.Tensor, + num_layers: int, + dropout: float, + training: bool, + bidirectional: bool, +) -> Tuple[StreamTensor, torch.Tensor]: + return rnn(input, batch_sizes, hx, weights, bias, num_layers, dropout, training, bidirectional, torch._VF.rnn_relu) + +# if batch_sizes is None: +# if self.mode == 'RNN_TANH': +# result = _VF.rnn_tanh(input, hx, self._flat_weights, self.bias, self.num_layers, +# self.dropout, self.training, self.bidirectional, +# self.batch_first) +# else: +# result = _VF.rnn_relu(input, hx, self._flat_weights, self.bias, self.num_layers, +# self.dropout, self.training, self.bidirectional, +# self.batch_first) +# else: +# if self.mode == 'RNN_TANH': +# result = _VF.rnn_tanh(input, batch_sizes, hx, self._flat_weights, self.bias, +# self.num_layers, self.dropout, self.training, +# self.bidirectional) +# else: +# result = _VF.rnn_relu(input, batch_sizes, hx, self._flat_weights, self.bias, +# self.num_layers, self.dropout, self.training, +# self.bidirectional) + +def rnn( + input: StreamTensor, + batch_sizes: Optional[torch.Tensor], + hx: torch.Tensor, + weights: torch.Tensor, + bias: torch.Tensor, + num_layers: int, + dropout: float, + training: bool, + bidirectional: bool, + method: Union[torch._VF.rnn_tanh, torch._VF.rnn_relu], +) -> Tuple[StreamTensor, StreamTensor]: + input, meta, names = input.decouple() + if isinstance(hx, StreamTensor): + hx, hx_meta, hx_names = hx.decouple() + + output, hx = method(input, batch_sizes, hx, weights, bias, num_layers, dropout, training, bidirectional) + + output = StreamTensor(output, meta).rename_(*names) + return output, hx diff --git a/dreamstream/patches/__init__.py b/dreamstream/patches/__init__.py index 0961199..e70806a 100644 --- a/dreamstream/patches/__init__.py +++ b/dreamstream/patches/__init__.py @@ -1 +1,3 @@ -from .conv import patch_conv_1d # noqa: F401 +from dreamstream.patches.general import patch, patch_module # noqa: F401 +from dreamstream.patches.modes import add_streaming_modes # noqa: F401 +from dreamstream.patches.conv import patch_conv_1d # noqa: F401 diff --git a/dreamstream/patches/conv.py b/dreamstream/patches/conv.py index 94438b9..06e04c1 100644 --- a/dreamstream/patches/conv.py +++ b/dreamstream/patches/conv.py @@ -1,24 +1,27 @@ -import types +from typing import Tuple, Dict import torch -from dreamstream.tensor import StreamTensor, StreamMetadata -from dreamstream.patches.general import online, offline +from dreamstream.tensor import StreamTensor +from dreamstream.patches.modes import add_streaming_modes from dreamstream.nn.utils import pad_stream_tensor -# TODO (LB): Add support for subsampling convolutions (i.e., stride > kernel_width). +# TODO (LB): #1 Add support for subsampling convolutions (i.e., stride > kernel_width). +# TODO (LB): #2 Generalize to N-D convolutions, where the length dimension can be any dimension. +# TODO (LB): Add support for transposed convolutions. -def conv_1d_pre_hook(self, inputs): +def conv_1d_pre_hook(self, inputs: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]: + is_stream_tensor = isinstance(inputs[0], StreamTensor) if not self.streaming: - if isinstance(inputs[0], StreamTensor): + if is_stream_tensor: raise RuntimeError("Using StreamTensors in offline mode might result in unexpected behavior.") return inputs input = inputs[0] - - assert isinstance(input, StreamTensor), "The input is expected to be StreamTensor when in online mode." + if not is_stream_tensor: + raise RuntimeError("The input is expected to be StreamTensor when in online mode.") # If all inputs are NOT first, collect states for all. if (self.kernel_width > 1) and (not input.meta.all_starting): @@ -27,116 +30,44 @@ def conv_1d_pre_hook(self, inputs): assert buffer_lengths.min() >= 0, "At least one buffer should have length greater than zero." ref_length = buffer_lengths[0] - # If all have the same length, stack them and concatenate with input. if (buffer_lengths == ref_length).all(): + # If all have the same length, stack them and concatenate with input. buffer_data = torch.stack(buffer_data) input = torch.cat([buffer_data, input], dim=-1) - - # If not, split batch into individual inputs and concatenate separately first. else: - # TODO: This needs to be tested. + # If not, split batch into individual inputs and concatenate separately first. + # TODO (LB): This needs to be tested. input = input.unpad_sequence() - input = [a if b is None else torch.cat([b, a], dim=-1) for a, b in zip(input, buffer_data)] + input = [x if b is None else torch.cat([b, x], dim=-1) for x, b in zip(input, buffer_data)] input = pad_stream_tensor(input).permute(1, 2, 0) return input -def conv_1d_post_hook(self, inputs, outputs): - if self.streaming: - outputs, buffer = outputs - self.stream_buffer.update(buffer) +def conv_1d_post_hook(self, inputs: Tuple[torch.Tensor], outputs: torch.Tensor): + if not self.streaming: + return outputs - # TODO: Simplify this. - if outputs.meta.any_end: - for _id, eos in zip(outputs.meta.ids, outputs.meta.eos): - if eos and _id in self.stream_buffer: - del self.stream_buffer[_id] + self.stream_buffer.update(outputs.meta._temp_buffer) + outputs.meta._temp_buffer = None - return outputs + # TODO (LB): Simplify this. + if outputs.meta.any_ending: + for _id, eos in zip(outputs.meta.ids, outputs.meta.eos): + if eos and _id in self.stream_buffer: + del self.stream_buffer[_id] -def patch_conv_1d(module): - # Add streaming mode. - module.online = types.MethodType(online, module) - module.offline = types.MethodType(offline, module) - module.streaming = False +def patch_conv_1d(module) -> torch.nn.Module: + add_streaming_modes(module) - # Add stream_buffer dictionary. - module.stream_buffer = {} + # Add a dictionary for storing input buffers. + module.stream_buffer: Dict[str, torch.Tensor] = {} # Add module-specific attributes. - module.kernel_width = module.kernel_size[0] + (module.kernel_size[0] - 1) * (module.dilation[0] - 1) + module.kernel_width: int = module.kernel_size[0] + (module.kernel_size[0] - 1) * (module.dilation[0] - 1) # Register pre_hook and post_hook. module.register_forward_pre_hook(conv_1d_pre_hook) module.register_forward_hook(conv_1d_post_hook) - return module - - -if __name__ == "__main__": - import os - import argparse - from random import randint - - os.environ["CUDA_VISIBLE_DEVICES"] = "0" - - parser = argparse.ArgumentParser() - - # configuration of layer - parser.add_argument("--in_channels", default=32, type=int) - parser.add_argument("--out_channels", default=64, type=int) - parser.add_argument("--kernel_size", default=3, type=int) - parser.add_argument("--stride", default=2, type=int) - - # configuration of test input - parser.add_argument("--input_min_length", default=1000, type=int) - parser.add_argument("--input_max_length", default=2000, type=int) - parser.add_argument("--input_min_chunk_size", default=50, type=int) - parser.add_argument("--input_max_chunk_size", default=50, type=int) - parser.add_argument("--input_batch_size", default=32, type=int) - - args, _ = parser.parse_known_args() - - assert args.input_min_length <= args.input_max_length - assert args.input_min_chunk_size <= args.input_max_chunk_size - - N = 20_000 - stream_inputs, stream_batch_inputs, inputs = [], [], [] - for n in range(N): - x1 = torch.rand(1, 32, randint(1, 20)) - # x2 = pad_sequence([x1, torch.rand(1, 32, randint(args.kernel_size, 50))], batch_first=True, padding_value=0) - ss1 = StreamMetadata(ids=["test_1"], lengths=[x1.size(-1)], sos=[n == 0], eos=[n == N - 1]) - # ss2 = StreamMetadata(ids=["test_1", "test_2"], lengths=[x1.size(-1), x2.size(-1)], first=[n == 0] * 2, - # last=[n == N - 1] * 2) - xs1 = StreamTensor(x1, meta=ss1) - # xs2 = StreamTensor(x2, meta=ss2) - stream_inputs.append(xs1) - # stream_batch_inputs.append(xs2) - inputs.append(x1) - - x_full = torch.cat(inputs, dim=-1) - s = StreamMetadata(ids=["test_2"], lengths=[x_full.size(-1)], sos=[True], eos=[True]) - x_full_stream = StreamTensor(x_full, meta=s) - - conv1d = torch.nn.Conv1d(args.in_channels, args.out_channels, args.kernel_size, args.stride, padding=2) - - y1 = conv1d(x_full) - - conv1d = patch_conv_1d(conv1d) - conv1d.online() - - y2 = conv1d(x_full_stream) - - ys = [] - for x in stream_inputs: - y = conv1d(x) - if y is None: - print("skipping") - continue - ys.append(y) - y3 = torch.cat([torch.Tensor(y_) for y_ in ys], dim=-1) - - assert torch.allclose(y1, y2, rtol=0, atol=0) - assert torch.allclose(y1, y3, rtol=1e-6, atol=1e-6) diff --git a/dreamstream/patches/general.py b/dreamstream/patches/general.py index 6403830..215b227 100644 --- a/dreamstream/patches/general.py +++ b/dreamstream/patches/general.py @@ -1,15 +1,47 @@ -def online(self, mode=True): - if not isinstance(mode, bool): - raise ValueError("streaming mode is expected to be boolean") - self.streaming = mode - for module in self.children(): - module.online(mode) - return self +import torch.nn as nn +from dreamstream.patches.conv import patch_conv_1d +from dreamstream.patches.rnn import patch_rnn +from dreamstream.patches.modes import add_streaming_modes -def offline(self): - return self.online(mode=False) +MODULE_PATCHERS = { + nn.Conv1d: patch_conv_1d, + nn.RNN: patch_rnn, + nn.LSTM: patch_rnn, + nn.GRU: patch_rnn, +} -def patch(module): - raise NotImplementedError("stream_patch is not implemented") + +def patch(module: nn.Module) -> nn.Module: + """Recursively apply `patch_module` to all modules in `module`.""" + module.apply(patch_module) + return module + + +def patch_module(module) -> None: + """Patch a given module to support streaming mode.""" + patch_method = MODULE_PATCHERS.get(type(module), None) + if patch_method is None: + add_streaming_modes(module) + else: + patch_method(module) + + # Error checking for modules that do NOT have the correct behaviour yet. + if isinstance( + module, + ( + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + # nn.LSTM, + # nn.GRU, + # nn.RNN, + nn.MultiheadAttention, + ), + ): + raise NotImplementedError(f"Module {type(module)} is not supported yet.") + + return module diff --git a/dreamstream/patches/modes.py b/dreamstream/patches/modes.py new file mode 100644 index 0000000..084c27c --- /dev/null +++ b/dreamstream/patches/modes.py @@ -0,0 +1,27 @@ +import types + +import torch.nn as nn + + +def online(self: nn.Module, mode: bool = True) -> nn.Module: + """Set the module to either online (default) or offline mode.""" + if not isinstance(mode, bool): + raise ValueError(f"Streaming `mode` was expected to be boolean but got {mode=}.") + + self.streaming = mode + for module in self.children(): + module.online(mode) + return self + + +def offline(self: nn.Module) -> nn.Module: + """Set the module to offline mode.""" + return self.online(mode=False) + + +def add_streaming_modes(module: nn.Module): + """Equip a module with the streaming mode but no additional functionality.""" + module.online = types.MethodType(online, module) + module.offline = types.MethodType(offline, module) + module.streaming = False + return module diff --git a/dreamstream/patches/rnn.py b/dreamstream/patches/rnn.py new file mode 100644 index 0000000..c8c361a --- /dev/null +++ b/dreamstream/patches/rnn.py @@ -0,0 +1,353 @@ +import collections +import logging + +from typing import Any, List, Optional, Tuple, Union + +import torch + +from dreamstream.tensor import StreamTensor +from dreamstream.patches.modes import add_streaming_modes +from dreamstream.nn.utils import pad_stream_tensor +from dreamstream.utils.flags import BATCH + + +# TODO (JDH): How do we deal with PackedSequence? +# TODO (JDH): When an initial state is provided, we use it only for the first chunk.) +# TODO (JDH): Deal with non-batched inputs and hidden states (note that when input is batched so must be the hidden state). + + +LOGGER = logging.getLogger(__file__) + + +def get_tensor_and_state( + inputs: Union[Tuple[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]] +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if len(inputs) == 2: + return inputs + + return inputs[0], None + + +class StateStore(dict): + """A dictionary that stores states for each batch id but also supports retrieving states for individual examples.""" + + def __init__(self, *args, batch_dim: int, **kwargs): + self.batch_dim = batch_dim + super().__init__(*args, **kwargs) + + self.individual_keys = set() + self.batch_keys = set() + for k in self.keys(): + if isinstance(k, tuple): + self.batch_keys.add(k) + else: + self.individual_keys.add(k) + + def get_batch_key(self, keys: Union[Any, Tuple[Any]]): + return keys if isinstance(keys, tuple) else (keys,) + + def __contains__(self, keys: Union[Any, Tuple[Any]]): + # Return quickly if there is an non-fragmented individual or batch key match. + if keys in self.keys(): # or not isinstance(keys, tuple) and len(keys) == 1: + return True + + # If no batch keys exist, then no fragmented matches will be found. + if not self.batch_keys: + return False + + # Find individual key matches allowing some to be part of batch keys. + return self.__contains_fragmented__(keys) + + def __contains_fragmented__(self, keys: Union[Any, Tuple[Any]]): + """Check if keys exist in the StateStore in a fragmented way (i.e. split across batched and individual keys).""" + batch_key = self.get_batch_key(keys) + found_ids = 0 + for k in batch_key: + if k in self.individual_keys: + found_ids += 1 + else: + for bk in self.batch_keys: + if k in bk: + found_ids += 1 + + return found_ids == len(batch_key) + + # def __contains_fragmented__(self, keys: Union[Any, Tuple[Any]]): + # """Check if keys exist in the StateStore in a fragmented way (i.e. split across batched and individual keys).""" + # batch_key = self.get_batch_key(keys) + # individual_finds = 0 + # sub_batch_finds = 0 + # for k in batch_key: + # if k in self.individual_keys: + # individual_finds += 1 + # else: + # for bk in self.batch_keys: + # if k in bk: + # sub_batch_finds += 1 + + # return individual_finds > 0 and sub_batch_finds > 0 and individual_finds + sub_batch_finds == len(batch_key) + + def __getitem__(self, keys: Union[Any, Tuple[Any]]) -> torch.Tensor: + print(f"__getitem__: {keys}") + # Return quickly if there is an non-fragmented individual or batch key match. + if keys in self.keys(): + print(f"Pulling direct match:\n{super().__getitem__(keys)}") + return super().__getitem__(keys) + # if not isinstance(keys, tuple) and len(keys) == 1: + # return super().__getitem__(keys[0]) + + # Find individual key matches allowing some to be part of batch keys. Returns KeyError if any key is not found. + batch_key = (keys,) if isinstance(keys, str) else keys + vals = [] + for k in batch_key: + if k in self.keys(): + tensor = super().__getitem__(k) + print(f"Pulling individual:\n{tensor}") + vals.append(tensor) + else: + for bk in self.batch_keys: + if k in bk: + tensor = super().__getitem__(bk) + print(f"Pulling batch:\n{tensor}") + sample = torch.select(tensor, dim=self.batch_dim, index=bk.index(k)).unsqueeze(self.batch_dim) + # if isinstance(tensor, StreamTensor) and tensor.names[dim] != BATCH: + # sample = sample.align_to(*tensor.names) + vals.append(sample) + + if len(vals) != len(batch_key): + found_keys = tuple(k for k in batch_key if k in self.keys() or any(k in bk for bk in self.batch_keys)) + missing_keys = tuple(k for k in batch_key if k not in found_keys) + msg = f"Could not find keys {missing_keys}." + if found_keys: + msg += f" Found keys {found_keys}." + raise KeyError(msg) + + return torch.cat(vals, dim=self.batch_dim).rename(*tensor.names) + + def __setitem__(self, key: Union[Any, Tuple[Any]], value: torch.Tensor): + print(f"__setitem__: {key},\n{value}") + if self.__contains_fragmented__(key): + # We must delete the fragmented locations to make sure we correctly overwrite the state. + # TODO (JDH): Implement this + self.__deltitem_fragmented__(key) + + super().__setitem__(key, value) + if isinstance(key, tuple): + self.batch_keys.add(key) + else: + self.individual_keys.add(key) + + def __delitem__(self, keys): + # Delete quickly if there is an non-fragmented individual or batch key match. + print(f"__delitem__: {keys}") + if keys in self.keys(): + super().__delitem__(keys) + if isinstance(keys, tuple): + self.batch_keys.remove(keys) + else: + self.individual_keys.remove(keys) + + return None + + self.__deltitem_fragmented__(keys) + + def __deltitem_fragmented__(self, keys: Union[Any, Tuple[Any]]): + # print(f"{self.individual_keys=}") + # print(f"{self.batch_keys=}") + # print(f"{keys=}") + print(f"__delitem_fragmented__: {keys}") + batch_key = self.get_batch_key(keys) + individual_keys_to_delete = set() + keys_to_delete_in_batch_keys = collections.defaultdict(set) + for k in batch_key: + if k in self.individual_keys: + individual_keys_to_delete.add(k) + else: + for bk in self.batch_keys: + if k in bk: + keys_to_delete_in_batch_keys[bk].add(k) + + self.individual_keys = self.individual_keys - individual_keys_to_delete + for key_to_delete in individual_keys_to_delete: + super().__delitem__(key_to_delete) + + for bk, keys_to_delete in keys_to_delete_in_batch_keys.items(): + self.batch_keys.remove(bk) + if len(keys_to_delete) == len(bk): + # Delete entire batch key + super().__delitem__(bk) + else: + # Delete individual keys in the batch key also indexing the tensor. + # We assume that since we are deleting part of batch key, then the kept subbatch keys should be simply + # stored as individual keys. + tensor = super().__getitem__(bk) + # import IPython + # IPython.embed(using=False) + keys_to_keep = tuple(k for k in bk if k not in keys_to_delete) + tensors_to_keep = [t for k, t in zip(bk, tensor.unbind(self.batch_dim)) if k in keys_to_keep] + for k, tensor in zip(keys_to_keep, tensors_to_keep): + super().__setitem__(k, tensor.unsqueeze(self.batch_dim)) + + self.individual_keys = self.individual_keys.union(keys_to_keep) + super().__delitem__(bk) + + return None + + +class DefaultStateStore(StateStore): + def __init__(self, *args, batch_dim: int, default_state: torch.Tensor, **kwargs): + super().__init__(*args, batch_dim=batch_dim, **kwargs) + + def __getitem__(self, keys: Union[Any, Tuple[Any]]): + raise NotImplementedError() + + +StreamTensorOrPackedSequence = Union[StreamTensor, torch.nn.utils.rnn.PackedSequence] + + +def rnn_pre_hook( + self, inputs: Union[StreamTensorOrPackedSequence, Tuple[StreamTensorOrPackedSequence, Optional[torch.Tensor]]] +): + """ + + Cases: + - Given state: + 1. None + 2. Tensor + - Given ids: + 1. All are first + 2. Some are first + 3. None are first + 4. All are last + 5. Some are last + 6. None are last + 7. All are first and last + 8. Some are first and some are last + 9. None are first or last + - Hidden state store: + 1. All ids are in the hidden state store + 2. Some ids are in the hidden state store + 3. No ids are in the hidden state store + + Get hidden state (num_layers, batch_size, hidden_size). + If any of the ids are in the hidden state store, we must use that state (unless they are the first chunk, but they can't be if they are in the hidden state store). + Any first chunks must use the given initial state. + 1. If all chunks are first chunks, we can use the given initial state whether it is None or Tensor. + 2. If only some chunks are first chunks, we must use the given initial state for those chunks and the hidden + state store for the others by writing selectively. + + Args: + self (nn.RNN): The current module. + inputs (Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]): Inputs to `self.forward`. + + Raises: + RuntimeError: If input is a StreamTensor and the module is not in online mode, or if the input is not a + StreamTensor but the module is in online mode. + + Returns: + (Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]): Modified inputs to `self.forward`. + """ + x, state = get_tensor_and_state(inputs) + + input_is_packed = isinstance(x, torch.nn.utils.rnn.PackedSequence) + input_is_stream_tensor = isinstance(x.data, StreamTensor) if input_is_packed else isinstance(x, StreamTensor) + if not self.streaming: + state_is_stream_tensor = isinstance(state, StreamTensor) if state is not None else False + if input_is_stream_tensor or state_is_stream_tensor: + raise RuntimeError("Using StreamTensors in offline mode might result in unexpected behavior.") + return inputs + + if not input_is_stream_tensor: + raise RuntimeError("The input is expected to be StreamTensor when in online mode.") + + # Make the input a PackedSequence. TODO (JDH): Don't do this if input is a non-padded batch. + if input_is_packed: + ids = x.data.meta.ids + else: + # Pack the sequence and unsort the ids to either i) match the given `state` or ii) return the `state` from the + # `hidden_state_store` in the unsorted order. This will be sorted with `x.sorted_indices` inside the recurrent + # module's `forward` method. + x = torch.nn.utils.rnn.pack_padded_sequence(x, x.meta.lengths, batch_first=self.batch_first, enforce_sorted=False) + ids = tuple(x.data.meta.ids[i] for i in x.unsorted_indices) + + x.data.meta._temp_input_was_packed = input_is_packed + + if isinstance(state, StreamTensor): + state = state.tensor() + + batch_size = len(ids) + all_ids_in_state = ids in self.hidden_state_store + if not all_ids_in_state: + ids_without_state = tuple((i, _id) for i, _id in enumerate(ids) if _id not in self.hidden_state_store) + if len(ids_without_state) == batch_size and state is not None: + # No ids have state but a custom state is given. + # Use the custom state for all examples. + pass + elif len(ids_without_state) == batch_size and state is None: + # No ids have state and no custom state is given. + # Use the module-native default state by passing None state. + state = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=x.data.device, dtype=x.data.dtype) + elif len(ids_without_state) < batch_size and state is None: + # Some ids have state but no custom state is given. + # Write module-native default state into the StateStore for ids without state and gather contiguous state. + default_state = torch.zeros(self.num_layers, len(ids_without_state), self.hidden_size, device=x.data.device, dtype=x.data.dtype) + ids_without_state = tuple(_id for i, _id in ids_without_state) + self.hidden_state_store[ids_without_state] = default_state + state = self.hidden_state_store[ids] + else: + # Some ids have state but a state is given too. + # Write custom state into StateStore for ids with out state and gather contiguous state. + # TODO (JDH): Speed this up by using torch.select_scatter or similar directly on the given state. + for i, _id in ids_without_state: + self.hidden_state_store[_id] = state[:, i, :].unsqueeze(1) + state = self.hidden_state_store[ids] + # # Write default_state into the state tensor at the batch ids that do not have state. + # index = torch.as_tensor(ids_without_state, device=x.device) + # dim = 0 if self.batch_first else 1 + # torch.select_scatter(state, default_state, dim=dim, index=index) + else: + # All ids have state. Return the state. + state = self.hidden_state_store[ids] + + return x, state + + +def rnn_post_hook(self, inputs, outputs): + if not self.streaming: + return outputs + + print("post hook") + input, in_state = get_tensor_and_state(inputs) + output, out_state = get_tensor_and_state(outputs) + + meta = input.meta if isinstance(input, StreamTensor) else input.data.meta + if not meta._temp_input_was_packed: + output.data.meta = input.data.meta + output, lengths = torch.nn.utils.rnn.pad_packed_sequence(output, batch_first=self.batch_first) + meta = output.meta + + # Ensure out_state is a proper StreamTensor. + if isinstance(out_state, StreamTensor): + out_state = out_state.tensor() + + # Store hidden states for this batch + self.hidden_state_store[meta.ids] = out_state + + # Delete hidden states of ending files. + ending_ids = tuple(_id for _id, eos in zip(meta.ids, meta.eos) if eos) + if ending_ids: + del self.hidden_state_store[ending_ids] + + return output, out_state + + +def patch_rnn(module): + add_streaming_modes(module) + + # Add a dictionary for storing previous hidden states of shape (D * num_layers, N, H). + module.hidden_state_store = StateStore(batch_dim=1) + + # Register pre_hook and post_hook. + module.register_forward_pre_hook(rnn_pre_hook) + module.register_forward_hook(rnn_post_hook) + return module diff --git a/dreamstream/tensor.py b/dreamstream/tensor.py index 652e70e..c8911ec 100644 --- a/dreamstream/tensor.py +++ b/dreamstream/tensor.py @@ -1,15 +1,21 @@ import itertools -import warnings from copy import deepcopy -from typing import Any, Callable, List, Optional, Tuple, Union, Mapping, Sequence +from typing import Any, Callable, List, Optional, Tuple, Union, Mapping, Sequence, Dict import torch import numpy as np from torch import Tensor -from dreamstream.func_coverage import DECOUPLE_FUNCTIONS, RECOUPLE_FUNCTIONS, VALID_FUNCTIONS, OVERRIDDEN_FUNCTIONS +from dreamstream.func_coverage import ( + DECOUPLE_FUNCTIONS, + RECOUPLE_FUNCTIONS, + VALID_FUNCTIONS, + CUSTOMIZED_FUNCTIONS, + DEFAULT_VALID_FUNCTIONS, + INPLACE_RECOUPLE_FUNCTIONS, +) from dreamstream.utils.flags import BATCH, LENGTH from dreamstream.utils.numba import ( is_sorted_ascending, @@ -19,46 +25,73 @@ update_eos_from_slice, update_lengths_from_list_of_indices, ) +from dreamstream import warnings # TODO (JDH): Make StreamMetadata methods like cat, split and index lazily evaluated such that they only evaluate when # they are needed. This minimizes overhead computation on StreamTensors that end up as leaf nodes in the graph. -def decouple_recursive(x, metas: Optional[List["StreamMetadata"]] = None, names: Optional[List[str]] = None): - """Recurse a nested structure and decouple all StreamTensors.""" - if isinstance(x, StreamTensor): - if metas is None and names is None: - return x.tensor() +class LazyProxy(object): + """A proxy class that lazily instantiates an object of type cls with arguments *args and **kwargs.""" - tensor, meta, name = x.decouple() - metas.append(meta) - names.append(name) - return tensor + def __init__(self, cls, *args, **kwargs): + self.__dict__["_cls"] = cls + self.__dict__["_args"] = args + self.__dict__["_kwargs"] = kwargs + self.__dict__["_obj"] = None - elif isinstance(x, Mapping): - return type(x)((k, decouple_recursive(v, metas=metas, names=names)) for k, v in x.items()) - elif isinstance(x, Sequence): - return type(x)(decouple_recursive(v, metas=metas, names=names) for v in x) + def __getattr__(self, name): + if self.__dict__["_obj"] is None: + self.__init_obj() - return x + return getattr(self.__dict__["_obj"], name) + def __setattr__(self, name, value): + if self.__dict__["_obj"] is None: + self.__init_obj() -def decouple(func, *args, **kwargs): - """Call function on tensor after decoupling it from StreamMetadata.""" - args, kwargs = decouple_recursive((args, kwargs)) - return func(*args, **kwargs) + setattr(self.__dict__["_obj"], name, value) + + def __getitem__(self, key): + if self.__dict__["_obj"] is None: + self.__init_obj() + return self.__dict__["_obj"].__getitem__(key) + + def __copy__(self): + if self.__dict__["_obj"] is None: + self.__init_obj() + + return self.__dict__["_obj"].__copy__() + + def __eq__(self, other): + if self.__dict__["_obj"] is None: + self.__init_obj() + + return self.__dict__["_obj"].__eq__(other) + + def __len__(self): + if self.__dict__["_obj"] is None: + self.__init_obj() + + return self.__dict__["_obj"].__len__() + + def __repr__(self): + if self.__dict__["_obj"] is None: + return f"LazyProxy({self.__dict__['_cls'].__name__}, {self.__dict__['_args']}, {self.__dict__['_kwargs']})" + return self.__dict__["_obj"].__repr__() + + def __init_obj(self): + self.__dict__["_obj"] = object.__new__(self.__dict__["_cls"]) + self.__dict__["_obj"].__init__(*self.__dict__["_args"], **self.__dict__["_kwargs"]) -def recouple(func, *args, **kwargs): - """Call function on tensor after recoupling it to StreamMetadata and recouple again afterwards.""" - metas, names = [], [] - args, kwargs = decouple_recursive((args, kwargs), metas=metas, names=names) - if not all(metas[0] == m for m in metas): - raise ValueError("All StreamTensors must have the same StreamMetadata.") - tensor = func(*args, **kwargs) - return as_stream_tensor(data=tensor, meta=metas[0], names=names[0]) +class LazyInit(object): + """A class that lazily initializes its attributes.""" + + def __new__(cls, *args, **kwargs): + return LazyProxy(cls, *args, **kwargs) class StreamMetadata: @@ -70,6 +103,9 @@ class StreamMetadata: "_eos", "_lengths", "_chunk_indices", + "_temp_buffer", + "_temp_names", + "_temp_input_was_packed", "_min_length", "_max_length", "_lengths_updated", @@ -84,45 +120,60 @@ class StreamMetadata: def __init__( self, - ids: Union[str, List[str]], - sos: Union[bool, List[bool], torch.BoolTensor], - eos: Union[bool, List[bool], torch.BoolTensor], - lengths: Union[int, List[int], torch.IntTensor], - chunk_indices: Optional[Union[int, List[int], torch.IntTensor]] = None, + ids: Union[str, Tuple[str]], + sos: Union[bool, Tuple[bool], torch.BoolTensor], + eos: Union[bool, Tuple[bool], torch.BoolTensor], + lengths: Union[int, Tuple[int], torch.IntTensor], + chunk_indices: Optional[Union[int, Tuple[int], torch.IntTensor]] = None, + _copy_on_init: bool = False, ): - # TODO: Make initialization lazy such that it only happens when the StreamMetadata is actually used. + super().__init__() + if isinstance(ids, str): - ids = [ids] + ids = (ids,) + elif isinstance(ids, Sequence): + ids = tuple(ids) if isinstance(lengths, int): - lengths = [lengths] + lengths = tuple(lengths) if isinstance(sos, bool): - sos = [sos] + sos = tuple(sos) if isinstance(eos, bool): - eos = [eos] + eos = tuple(eos) if not len(ids) == len(lengths) == len(sos) == len(eos): raise ValueError("ids, lengths, sos and eos must have the same length.") - sos = torch.as_tensor(sos, dtype=torch.bool) - eos = torch.as_tensor(eos, dtype=torch.bool) - lengths = torch.as_tensor(lengths, dtype=torch.int) + sos_tensor = torch.as_tensor(sos, dtype=torch.bool) + eos_tensor = torch.as_tensor(eos, dtype=torch.bool) + lengths_tensor = torch.as_tensor(lengths, dtype=torch.int) + + if _copy_on_init: + if sos_tensor is sos: + sos_tensor = sos_tensor.clone() + if eos_tensor is eos: + eos_tensor = eos_tensor.clone() + if lengths_tensor is lengths: + lengths_tensor = lengths_tensor.clone() if not all(isinstance(i, str) for i in ids): raise ValueError("ids must be a list of strings.") - if sos.ndim > 1 or eos.ndim > 1 or lengths.ndim > 1: + if lengths_tensor.ndim > 1 or eos_tensor.ndim > 1 or lengths_tensor.ndim > 1: raise ValueError("sos, eos and lengths must be 1-dimensional.") self.ids = ids - self._sos = sos - self._eos = eos - self._lengths = lengths + self._sos = sos_tensor + self._eos = eos_tensor + self._lengths = lengths_tensor self._chunk_indices = chunk_indices + self._temp_buffer = None + self._temp_names = None + self._temp_input_was_packed = None + self._min_length = None self._max_length = None self._lengths_updated = True - self._update_lengths() self._any_starting = None self._any_ending = None @@ -131,7 +182,6 @@ def __init__( self._any_starting_or_ending = None self._all_starting_and_ending = None self._sos_or_eos_updated = True - self._update_logicals() @property def sos(self) -> torch.BoolTensor: @@ -263,13 +313,15 @@ def __copy__(self): new_meta.__dict__.update(self.__dict__) return new_meta - def __deepcopy__(self): + def __deepcopy__(self, memo: Optional[dict] = None): """Return a deep copy of the StreamMetadata object.""" return StreamMetadata( - ids=deepcopy(self.ids), - sos=self.sos.clone(), - eos=self.eos.clone(), - lengths=self.lengths.clone(), + ids=self.ids, + sos=self.sos, + eos=self.eos, + lengths=self.lengths, + chunk_indices=self._chunk_indices, + _copy_on_init=True, ) def __eq__(self, other: "StreamMetadata") -> bool: @@ -279,6 +331,10 @@ def __eq__(self, other: "StreamMetadata") -> bool: and self.sos.equal(other.sos) and self.eos.equal(other.eos) and self.lengths.equal(other.lengths) + and ( + (self._chunk_indices is None and other._chunk_indices is None) + or self._chunk_indices.equal(other.lengths) + ) ) def __len__(self) -> int: @@ -309,7 +365,6 @@ def __repr__(self) -> str: i += 1 short_ids_repr = ", ".join(repr_ids) + ", ..., " + repr(last) - print(short_ids_repr, len(short_ids_repr)) else: short_ids_repr = repr(self.ids) @@ -366,28 +421,31 @@ def index_batch( raise IndexError(f"Expected batch indices to be a 1-dimensional tensor, but got {indices.ndim} dimensions.") if isinstance(indices, torch.BoolTensor): - ids = [id for i, id in enumerate(self.ids) if indices[i]] + ids = tuple(id for i, id in enumerate(self.ids) if indices[i]) sos = self.sos[indices] eos = self.eos[indices] lengths = self.lengths[indices] - return StreamMetadata(ids, sos, eos, lengths) + chunk_indices = self._chunk_indices[indices] if self._chunk_indices is not None else None + return StreamMetadata(ids, sos, eos, lengths, chunk_indices) if isinstance(indices, int): - ids = [self.ids[indices]] - sos = self.sos[[indices]] - eos = self.eos[[indices]] - lengths = self.lengths[[indices]] - return StreamMetadata(ids, sos, eos, lengths) + ids = (self.ids[indices],) + sos = self.sos[indices].unsqueeze_(0) + eos = self.eos[indices].unsqueeze_(0) + lengths = self.lengths[indices].unsqueeze_(0) + chunk_indices = self._chunk_indices[indices].unsqueeze_(0) if self._chunk_indices is not None else None + return StreamMetadata(ids, sos, eos, lengths, chunk_indices) if isinstance(indices, slice): ids = self.ids[indices] else: # List[int], Tuple[int, ...], torch.IntTensor - ids = [self.ids[i] for i in indices] + ids = tuple(self.ids[i] for i in indices) sos = self.sos[indices] eos = self.eos[indices] lengths = self.lengths[indices] - return StreamMetadata(ids, sos, eos, lengths) + chunk_indices = self._chunk_indices[indices] if self._chunk_indices is not None else None + return StreamMetadata(ids, sos, eos, lengths, chunk_indices) def index_length( self, indices: Union[None, int, slice, List[int], Tuple[int], torch.IntTensor, torch.BoolTensor] @@ -426,16 +484,18 @@ def index_batch_and_length(self, indices: torch.BoolTensor) -> "StreamMetadata": return self.__deepcopy__() if not keep_ids.all(): - ids = [id for i, id in enumerate(self.ids) if keep_ids[i]] + ids = tuple(id for i, id in enumerate(self.ids) if keep_ids[i]) sos = self.sos[keep_ids] eos = self.eos[keep_ids] lengths = self.lengths[keep_ids] + chunk_indices = self._chunk_indices[keep_ids] if self._chunk_indices is not None else None indices = indices[keep_ids] else: # includes `broadcast_batch == True` ids = deepcopy(self.ids) sos = self.sos eos = self.eos lengths = self.lengths + chunk_indices = self._chunk_indices sos = sos & indices[:, 0] # SOS only if the first index is included. if not broadcast_length: @@ -445,8 +505,9 @@ def index_batch_and_length(self, indices: torch.BoolTensor) -> "StreamMetadata": eos = eos & (start < lengths) & (lengths <= stop) # eos = eos & indices[range(indices.size(0)), lengths - 1] # EOS only if the last non-padding is included. lengths = cumsum[keep_ids, lengths - 1] + chunk_indices = chunk_indices[keep_ids] if chunk_indices is not None else None - return StreamMetadata(ids, sos, eos, lengths) + return StreamMetadata(ids, sos, eos, lengths, chunk_indices) def _index_length_int(self, index: int) -> "StreamMetadata": # Convert negative indices to positive @@ -459,7 +520,8 @@ def _index_length_int(self, index: int) -> "StreamMetadata": # TODO (JDH): numba compiled arithmetic is much faster but slowed down due to conversion to/from numpy # Maybe we should store sos and eos as numpy arrays instead of torch tensors? eos = torch.from_numpy(update_eos_from_integer(self.eos.numpy(), self.lengths.numpy(), index)) - return StreamMetadata(deepcopy(self.ids), sos, eos, lengths) + chunk_indices = self._chunk_indices.clone() if self._chunk_indices is not None else None + return StreamMetadata(deepcopy(self.ids), sos, eos, lengths, chunk_indices) def _index_length_slice(self, slice: slice) -> "StreamMetadata": # Convert start and stop to positive indices @@ -480,7 +542,8 @@ def _index_length_slice(self, slice: slice) -> "StreamMetadata": sos = self.sos.clone() if start == 0 and stop > 0 else torch.zeros_like(self.sos) eos = torch.from_numpy(update_eos_from_slice(self.eos.numpy(), self.lengths.numpy(), start, stop)) - return StreamMetadata(deepcopy(self.ids), sos, eos, lengths) + chunk_indices = self._chunk_indices.clone() if self._chunk_indices is not None else None + return StreamMetadata(deepcopy(self.ids), sos, eos, lengths, chunk_indices) def _index_length_list(self, indices: Union[List[int], Tuple[int]]) -> "StreamMetadata": # Convert to numpy arrays for faster manipulation and numba jit support. @@ -498,7 +561,8 @@ def _index_length_list(self, indices: Union[List[int], Tuple[int]]) -> "StreamMe sos = self.sos.clone() if min_i == 0 else torch.zeros_like(self.sos) eos = torch.from_numpy(update_eos_from_slice(self.eos.numpy(), lengths_np, min_i, max_i)) # TODO (JDH): Keep EOS true if the indexing spans over the last non-padding element. - return StreamMetadata(deepcopy(self.ids), sos, eos, lengths) + chunk_indices = self._chunk_indices.clone() if self._chunk_indices is not None else None + return StreamMetadata(deepcopy(self.ids), sos, eos, lengths, chunk_indices) def _index_length_1d_tensor(self, indices: torch.Tensor) -> "StreamMetadata": if indices.dtype == torch.bool: @@ -532,15 +596,21 @@ def cat_batch(cls, metas: List["StreamMetadata"]) -> "StreamMetadata": StreamMetadata: The concatenated StreamMetadata object. """ + if not all(isinstance(s, StreamMetadata) for s in metas): + raise TypeError("All objects in list must be of type StreamMetadata.") + if len(metas) == 1: return deepcopy(metas[0]) - assert all(isinstance(s, StreamMetadata) for s in metas) - ids = list(itertools.chain.from_iterable([s.ids for s in metas])) + ids = tuple(itertools.chain.from_iterable(s.ids for s in metas)) sos = torch.cat([s.sos for s in metas], dim=0) eos = torch.cat([s.eos for s in metas], dim=0) lengths = torch.cat([s.lengths for s in metas], dim=0) - return cls(ids, sos, eos, lengths) + if all(s.chunk_indices is not None for s in metas): + chunk_indices = torch.cat([s.chunk_indices for s in metas], dim=0) + else: + chunk_indices = None + return cls(ids, sos, eos, lengths, chunk_indices) @classmethod def cat_length(cls, metas: List["StreamMetadata"]) -> "StreamMetadata": @@ -568,7 +638,11 @@ def cat_length(cls, metas: List["StreamMetadata"]) -> "StreamMetadata": sos = metas[0].sos.clone() eos = metas[-1].eos.clone() lengths = sum([s.lengths for s in metas]) - return cls(ids, sos, eos, lengths) + if all(s.chunk_indices is not None for s in metas): + chunk_indices = metas[-1].chunk_indices.clone() # TODO (JDH): This assumes the right-most chunk is the last + else: + chunk_indices = None + return cls(ids, sos, eos, lengths, chunk_indices) def split(self, split_size_or_sections: Union[int, List[int]], dim: str) -> List["StreamMetadata"]: """Split a StreamMetadata object into a list of StreamMetadata objects along a given dimension.""" @@ -592,17 +666,21 @@ def split_batch(self, split_size_or_sections: Union[int, List[int]]) -> List["St if isinstance(split_size_or_sections, int): start = range(0, len(self), split_size_or_sections) - split_ids = [self.ids[i : i + split_size_or_sections] for i in start] + split_ids = tuple(self.ids[i : i + split_size_or_sections] for i in start) else: slices = np.cumsum([0] + split_size_or_sections) - split_ids = [self.ids[i:j] for i, j in zip(slices[:-1], slices[1:])] + split_ids = tuple(self.ids[i:j] for i, j in zip(slices[:-1], slices[1:])) - split_first = self.sos.split(split_size_or_sections) - split_last = self.eos.split(split_size_or_sections) + split_sos = self.sos.split(split_size_or_sections) + split_eos = self.eos.split(split_size_or_sections) split_lengths = self.lengths.split(split_size_or_sections) - args_iter = zip(split_ids, split_first, split_last, split_lengths) + if self.chunk_indices is not None: + split_chunk_indices = self.chunk_indices.split(split_size_or_sections) + else: + split_chunk_indices = [None] * len(split_ids) - return [stream_metadata(*args) for args in args_iter] + args_iter = zip(split_ids, split_sos, split_eos, split_lengths, split_chunk_indices) + return [StreamMetadata(*args) for args in args_iter] def split_length(self, split_size_or_sections: Union[int, List[int]]) -> List["StreamMetadata"]: """Split a StreamMetadata object into a list of StreamMetadata objects along the length dimension. @@ -654,19 +732,19 @@ def __init__(self) -> None: def stream_metadata( - ids: Union[str, List[str]], - sos: Union[bool, List[bool], torch.BoolTensor], - eos: Union[bool, List[bool], torch.BoolTensor], - lengths: Union[int, List[int], torch.IntTensor], - chunk_indices: Optional[Union[int, List[int], torch.IntTensor]] = None, + ids: Union[str, Tuple[str]], + sos: Union[bool, Tuple[bool], torch.BoolTensor], + eos: Union[bool, Tuple[bool], torch.BoolTensor], + lengths: Union[int, Tuple[int], torch.IntTensor], + chunk_indices: Optional[Union[int, Tuple[int], torch.IntTensor]] = None, ) -> StreamMetadata: """Create a StreamMetadata object from the given arguments. Args: - ids (Union[str, List[str]]): The ids of the input tensors. + ids (Union[str, Tuple[str]]): The ids of the input tensors. sos (bool): Whether the input tensors are the first in a batch. eos (bool): Whether the input tensors are the last in a batch. - lengths (Union[int, List[int]]): The lengths of the input tensors. + lengths (Union[int, Tuple[int]]): The lengths of the input tensors. chunk_indices (int): The index of the chunk in the batch. num_chunks (Optional[int], optional): The number of chunks in the batch. Defaults to None. @@ -688,12 +766,10 @@ def __new__(cls, data, meta: StreamMetadata, *args, **kwargs) -> "StreamTensor": """Return a new StreamTensor object.""" return super().__new__(cls, data, *args, **kwargs) - def __init__(self, data, meta: StreamMetadata, *args, names: List[str] = None, **kwargs): + def __init__(self, data, meta: StreamMetadata, *args, **kwargs): """Initialize a StreamTensor object (self is StreamTensor, data is e.g. torch.Tensor).""" super().__init__() self.meta = meta - if names is not None: - self.rename_(*names) def __getstate__(self) -> Tuple[torch.Tensor, StreamMetadata, List[str]]: """Return the state of the StreamTensor object.""" @@ -718,45 +794,65 @@ def __torch_function__(cls, func: Callable, types: List[torch.Tensor], args=(), """ if kwargs is None: kwargs = dict() + + if func in VALID_FUNCTIONS: + return super().__torch_function__(func, types, args, {}) - if func in OVERRIDDEN_FUNCTIONS: - # print(f"\n\n{func.__name__}: STREAM_TENSOR_FUNCTIONS\n\n") - return OVERRIDDEN_FUNCTIONS[func](*args, **kwargs) - - if func in RECOUPLE_FUNCTIONS: - # print(f"\n\n{func.__name__}: RECOUPLE_FUNCTIONS\n\n") - return recouple(func, *args, **kwargs) + if func in CUSTOMIZED_FUNCTIONS: + return CUSTOMIZED_FUNCTIONS[func](*args, **kwargs) if func in DECOUPLE_FUNCTIONS: - # print(f"\n\n{func.__name__}: DECOUPLE_FUNCTIONS\n\n") return decouple(func, *args, **kwargs) - if func in VALID_FUNCTIONS: - # print(f"\n\n{func.__name__}: VALID_FUNCTIONS\n\n") - return super().__torch_function__(func, types, args, kwargs) + if func in RECOUPLE_FUNCTIONS: + return recouple(func, *args, **kwargs) + + if func in INPLACE_RECOUPLE_FUNCTIONS: + return inplace_recouple(func, *args, **kwargs) # Unhandled functions are passed to the torch.Tensor.__torch_function__ method. - warnings.warn( - f"Function `{func.__name__}` is not handled by `StreamTensor.__torch_function__` " - "and may not work as expected." - ) + if func not in DEFAULT_VALID_FUNCTIONS: + warnings.warn( + f"Function {func.__name__} is not handled by StreamTensor.__torch_function__ " + f"and may not work as expected." + ) + + return cls.default_valid(func, types, args, kwargs) + + @classmethod + def default_valid(cls, func, types, args, kwargs): out = super().__torch_function__(func, types, args, kwargs) - metas = [x.meta for x in [*args, *kwargs.values()] if isinstance(x, StreamTensor)] + metas = [x.meta for x in [*args, *kwargs.values()] if isinstance(x, StreamTensor)] # TODO (JDH): Make recursive if not all(s == metas[0] for s in metas[1:]): msg = ( f"Called a torch function ({func.__name__}) which was not handled by " - f"StreamTensor.__torch_function__ with {len(metas)} StreamTensors in the input." + f"StreamTensor.__torch_function__ with {len(metas)} StreamTensors in the input. " f"In this case the function can only be handled if the StreamTensors have equal metadata," f"but they were not equal." ) raise RuntimeError(msg) - if isinstance(out, torch.Tensor): + if isinstance(out, StreamTensor): + out.meta = metas[0] + return out + elif isinstance(out, torch.Tensor): return StreamTensor(out, meta=metas[0]) return out + @property + def real(self): + return torch.real(self) + + @property + def imag(self): + return torch.imag(self) + + @property + def T(self): + return self.permute(*reversed(range(self.ndim))) + @property def has_batch_dim(self) -> bool: return BATCH in self.names @@ -802,10 +898,8 @@ def drop_empty(self) -> "StreamTensor": return None if len(self.meta) == 1 and self.meta.max_length > 0: return self - tensor, meta, names = self.decouple() - batch_dim = names.index(BATCH) - tensor = torch.index_select(tensor, batch_dim, meta.lengths.nonzero().squeeze()) - return as_stream_tensor(data=tensor, meta=meta.drop_empty(), names=names) + + return torch.index_select(self, self.batch_dim, self.meta.lengths.nonzero().squeeze()) def named_tensor(self) -> torch.Tensor: """Return the underlying torch.Tensor with names.""" @@ -819,40 +913,150 @@ def unpad_sequence(self, keep_names: bool = False) -> List["StreamTensor"]: length_dim -= 1 return [x.narrow(length_dim, 0, x.meta.lengths.item()) for x in self.unbind(dim=batch_dim)] - def decouple(self, copy_meta: bool = False) -> Tuple[Tensor, StreamMetadata, Tuple[str]]: + def decouple(self, copy_meta: bool = False, keep_names: bool = False) -> Tuple[Tensor, StreamMetadata, Tuple[str]]: """Decouple the StreamTensor from names and metadata.""" meta = self.meta.clone() if copy_meta else self.meta - return self.tensor(), meta, self.names + return self.tensor(keep_names=keep_names), meta, self.names def as_stream_tensor( data, - meta: StreamMetadata, - names: Tuple[Union[None, int]] = None, + meta: StreamMetadata = None, + names: Tuple[Union[None, str]] = None, dtype: torch.dtype = None, device: torch.device = None, ) -> StreamTensor: """Convert a tensor to a StreamTensor. See also `torch.as_tensor`.""" data = torch.as_tensor(data, dtype=dtype, device=device) - if names: + if names is not None: data = data.refine_names(*names) # Make the tensor named if it isn't already. + if meta is None: + meta = data.meta # If meta is not given, assume it is already on the input data. return StreamTensor(data=data, meta=meta) def stream_tensor( data, meta: StreamMetadata, - names: Tuple[Union[None, int]], + names: Tuple[Union[None, str]], dtype: torch.dtype = None, device: torch.device = None, requires_grad: bool = False, pin_memory: bool = False, ) -> StreamTensor: """Convert a tensor to a StreamTensor. See also `torch.tensor`.""" - if isinstance(data, torch.Tensor) and data.names != names: - data = data.rename(*names) + try: + data = torch.tensor( + data, names=names, dtype=dtype, device=device, requires_grad=requires_grad, pin_memory=pin_memory + ) + except RuntimeError: + data = torch.tensor(data, dtype=dtype, device=device, requires_grad=requires_grad, pin_memory=pin_memory) + data.rename_(*names) - data = torch.tensor( - data, names=names, dtype=dtype, device=device, requires_grad=requires_grad, pin_memory=pin_memory - ) return StreamTensor(data=data, meta=meta) + + +class TestTensor(torch.Tensor): + @staticmethod + def __new__(cls, data, meta, *args, **kwargs) -> "TestTensor": + """Return a new StreamTensor object.""" + return super().__new__(cls, data, *args, **kwargs) + + def __init__(self, data, meta, *args, names: List[str] = None, **kwargs): + """Initialize a StreamTensor object (self is StreamTensor, data is e.g. torch.Tensor).""" + super(TestTensor).__init__() + self.meta = meta + + def tensor(self, keep_names: bool = False) -> torch.Tensor: + """Return the underlying torch.Tensor.""" + tensor = torch.Tensor(self) # 1-2 µs + if not keep_names: + tensor.rename_(None) # 2-3 µs + return tensor + + def clone(self, *args, **kwargs): + """Clone a StreamTensor object.""" + return TestTensor(super().clone(*args, **kwargs), self.meta) + + +def decouple_recursive(x, metas: Optional[List["StreamMetadata"]] = None, names: Optional[List[str]] = None): + """Recurse a nested structure and decouple all StreamTensors.""" + if isinstance(x, StreamTensor): + if metas is None and names is None: + return x.tensor() + + tensor, meta, name = x.decouple() + metas.append(meta) + names.append(name) + return tensor + + elif isinstance(x, Mapping): + return type(x)((k, decouple_recursive(v, metas=metas, names=names)) for k, v in x.items()) + elif isinstance(x, str): # Must handle strings before Sequence, because strings are sequences. + return x + elif isinstance(x, Sequence): + return type(x)(decouple_recursive(v, metas=metas, names=names) for v in x) + + return x + + +# TODO (JDH): Use decouple_recursive instead of the two/three for loops below. + +# def decouple(func, *args, **kwargs): +# """Call function on tensor after decoupling it from StreamMetadata.""" +# args, kwargs = decouple_recursive((args, kwargs)) +# return func(*args, **kwargs) + + +# def recouple(func, *args, **kwargs): +# """Call function on tensor after recoupling it to StreamMetadata and recouple again afterwards.""" +# metas, names = [], [] +# args, kwargs = decouple_recursive((args, kwargs), metas=metas, names=names) +# if not all(metas[0] == m for m in metas): +# raise ValueError("All StreamTensors must have the same StreamMetadata.") + +# tensor = func(*args, **kwargs) +# return as_stream_tensor(data=tensor, meta=metas[0], names=names[0]) + + +def decouple(func, *args, _keep_names=False, _tensor_type=StreamTensor, **kwargs): + """Call function on tensor after decoupling it from StreamMetadata and names.""" + args = [arg.tensor(keep_names=_keep_names) if isinstance(arg, _tensor_type) else arg for arg in args] + kwargs = {k: v.tensor(keep_names=_keep_names) if isinstance(v, _tensor_type) else v for k, v in kwargs.items()} + return func(*args, **kwargs) + + +def recouple(func, *args, _tensor_type=StreamTensor, **kwargs): + """Call function on tensor after decoupling it from StreamMetadata and names and recouple again afterwards.""" + meta_list = [x.meta for x in [*args, *kwargs.values()] if isinstance(x, _tensor_type)] + names_list = [x.names for x in [*args, *kwargs.values()] if isinstance(x, _tensor_type)] + if not (all(m == meta_list[0] for m in meta_list[1:]) and all(n == names_list[0] for n in names_list[1:])): + raise RuntimeError("StreamTensor arguments must have the same metadata and names.") + out = decouple(func, *args, _tensor_type=_tensor_type, **kwargs) + out = _tensor_type(out, meta_list[0]).rename_(*names_list[0]) + return out + + +def inplace_recouple(func, tensor, *args, _tensor_type=StreamTensor, **kwargs): + """Call an in-place function on tensor after decoupling it from StreamMetadata and names, return the original.""" + decouple(func, tensor, *args, _tensor_type=_tensor_type, **kwargs) # Inplace operation by func on tensor. + return tensor + + +def decouple_recursive(x, metas: Optional[List["StreamMetadata"]] = None, names: Optional[List[str]] = None): + """Recurse a nested structure and decouple all StreamTensors.""" + if isinstance(x, StreamTensor): + if metas is None and names is None: + return x.tensor() + + tensor, meta, name = x.decouple() + metas.append(meta) + names.append(name) + return tensor + + elif isinstance(x, Mapping): + return type(x)((k, decouple_recursive(v, metas=metas, names=names)) for k, v in x.items()) + elif isinstance(x, Sequence): + return type(x)(decouple_recursive(v, metas=metas, names=names) for v in x) + + return x diff --git a/dreamstream/tests/test_conv1d.py b/dreamstream/tests/test_conv1d.py deleted file mode 100644 index ca608e0..0000000 --- a/dreamstream/tests/test_conv1d.py +++ /dev/null @@ -1,47 +0,0 @@ -from random import randint -from uuid import uuid4 - -import torch -from torch import nn - -from dreamstream.utils.flags import LENGTH -from dreamstream.nn.utils import pad_full_sequence -from dreamstream.patches import patch_conv_1d -from dreamstream.data import OutputCollector - - -def random_chunks(full_length): - chunks = [] - chunk_sum, remaining = 0, full_length - while remaining > 0: - chunks.append(min(randint(7, 200), remaining)) - chunk_sum = sum(chunks) - remaining = full_length - chunk_sum - return chunks - - -conv = nn.Conv1d(256, 128, 7, padding=3) -conv = patch_conv_1d(conv) - -# TEST 1: Test with multiple streams of different lengths. -sequences = [torch.rand(256, randint(50, 2000)) for i in range(32)] -ids = [str(uuid4()) for i in range(32)] -targets = {_id: conv(s) for _id, s in zip(ids, sequences)} -batch = pad_full_sequence(sequences, names=("F", LENGTH), ids=ids).align_to("B", "F", "L") -chunks = random_chunks(batch.size("L")) - -stream_output = OutputCollector() -conv.online() -for x in batch.split(chunks, dim=2): - x = x.drop_empty() - try: - y = conv(x) - except Exception: - import IPython - - IPython.embed(using=False) - stream_output.update(y) - -for _id, _y in targets.items(): - y = stream_output[_id].tensor() - print(torch.allclose(_y, y), (_y - y).abs().max().item()) diff --git a/dreamstream/tests/test_conv1d_2.py b/dreamstream/tests/test_conv1d_2.py deleted file mode 100644 index f7b35c4..0000000 --- a/dreamstream/tests/test_conv1d_2.py +++ /dev/null @@ -1,58 +0,0 @@ -from random import randint -from uuid import uuid4 -import random - -import torch -from torch import nn - -from dreamstream.utils.flags import LENGTH -from dreamstream.nn.utils import pad_full_sequence, pad_stream_tensor -from dreamstream.patches import patch_conv_1d -from dreamstream.data import OutputCollector - - -def random_chunks(full_length): - chunks = [] - chunk_sum, remaining = 0, full_length - while remaining > 0: - chunks.append(min(randint(7 - 3, 100), remaining)) - chunk_sum = sum(chunks) - remaining = full_length - chunk_sum - return chunks - - -conv = nn.Conv1d(256, 128, 7, stride=6, padding=3) -conv = patch_conv_1d(conv) - -# TEST 1: Test with multiple streams of different lengths. -sequences = [torch.rand(256, randint(50, 2000)) for i in range(32)] - -ids = [str(uuid4()) for i in range(32)] -targets = {_id: conv(s) for _id, s in zip(ids, sequences)} -original_sequences = {_id: s for _id, s in zip(ids, sequences)} - -data = pad_full_sequence(sequences, names=("F", LENGTH), ids=ids).align_to("B", "F", "L") -data = data.unpad_sequence() -data = {_id: s.split(random_chunks(s.size("L")), dim=1) for _id, s in zip(ids, data)} - - -def remaining_chunks(data): - return sum([len(x) for x in data.values()]) - - -batches = [] -while remaining_chunks(data) > 0: - batch = [s.pop(0) for _id, s in data.items() if len(s) > 0 and random.random() < 0.75] - if len(batch) > 0: - batches.append(pad_stream_tensor(batch).align_to("B", "F", "L")) - -stream_output = OutputCollector() -conv.online() -for x in batches: - y = conv(x) - stream_output.update(y) - -for _id, _y in targets.items(): - y = stream_output[_id].tensor() - abs_diff = (_y - y).abs() - print(y.size(-1), torch.allclose(_y, y), abs_diff.max().item(), abs_diff.max(0).values[:10].max().item()) diff --git a/dreamstream/utils/listloaders.py b/dreamstream/utils/listloaders.py new file mode 100644 index 0000000..22563d1 --- /dev/null +++ b/dreamstream/utils/listloaders.py @@ -0,0 +1,35 @@ +import os +from glob import glob + +import torch + + +lists_path = __file__.replace("/dreamstream/utils/listloaders.py", "/doc_scrape/lists") + + +def get_tensor_attr(x): + return getattr(torch.Tensor, x.replace("Tensor.", "")) if x.startswith("Tensor.") else getattr(torch, x) + + +def load_default_valid_pointwise_ops(): + with open(glob(os.path.join(lists_path, "default-valid-pointwise-ops-*.txt"))[0], "r") as file_buffer: + ops_list = {get_tensor_attr(f) for f in file_buffer.read().split("\n")} + return ops_list + + +def load_valid_pointwise_ops(): + with open(glob(os.path.join(lists_path, "valid-pointwise-ops-*.txt"))[0], "r") as file_buffer: + ops_list = {get_tensor_attr(f) for f in file_buffer.read().split("\n")} + return ops_list + + +def load_recouple_pointwise_ops(): + with open(glob(os.path.join(lists_path, "recouple-pointwise-ops-*.txt"))[0], "r") as file_buffer: + ops_list = {get_tensor_attr(f) for f in file_buffer.read().split("\n")} + return ops_list + + +def load_inplace_recouple_pointwise_ops(): + with open(glob(os.path.join(lists_path, "inplace-recouple-pointwise-ops-*.txt"))[0], "r") as file_buffer: + ops_list = {get_tensor_attr(f) for f in file_buffer.read().split("\n")} + return ops_list diff --git a/dreamstream/utils/operations.py b/dreamstream/utils/operations.py new file mode 100644 index 0000000..7ba49d0 --- /dev/null +++ b/dreamstream/utils/operations.py @@ -0,0 +1,43 @@ +import math + +from typing import Union + +import torch + + +def sequence_mask( + seq_lens: Union[list, torch.Tensor], + max_len: int = None, + invert: bool = False, + dtype: torch.dtype = torch.bool, + device: torch.device = None, +): + """ + Creates a binary sequence mask where all entries up to seq_lens are 1 and the remaining are 0. + + Args: + seq_lens (Tensor): The sequence lengths from which to construct the mask. Should be shape N with dtype == int64. + max_len (int): The temporal dimension of the sequence mask. If None, will use max of seq_lens. + dtype (torch.dtype): The type of the mask. Default is torch.bool. + invert (bool): If False, `m[i]` is `True` for `i < x_sl` and False for `i >= x_sl`. + If True, returns the inverse i.e. `~m`. Default is False. + Returns: + Tensor: The sequence mask of shape (N, T). + """ + if isinstance(seq_lens, torch.Tensor): + device = seq_lens.device if device is None else device + if device != seq_lens.device: + seq_lens = seq_lens.to(device) + else: + seq_lens = torch.tensor(seq_lens, device=device, dtype=int) + + T = max_len or math.ceil(seq_lens.max()) + + step_ids = torch.arange(T, device=device).unsqueeze(0) # (1, T) + + if invert: + seq_mask = step_ids >= seq_lens.unsqueeze(1) # broadcast over batch, (N, T) + return seq_mask.to(dtype) + + seq_mask = step_ids < seq_lens.unsqueeze(1) # broadcast over batch, (N, T) + return seq_mask.to(dtype) diff --git a/dreamstream/utils/timing.py b/dreamstream/utils/timing.py new file mode 100644 index 0000000..d2b19c6 --- /dev/null +++ b/dreamstream/utils/timing.py @@ -0,0 +1,124 @@ +import warnings +import timeit as timeit_module + +from types import SimpleNamespace +from typing import Callable, Optional, Union + +import numpy as np +import rich + + +UNITS = {"ns": 1e-9, "µs": 1e-6, "ms": 1e-3, "s": 1.0} + + +def format_time(dt, unit=None, precision=4): + """Format a time in seconds by rescaling and appending appropriate unit + + The returned string will always have length at most 5 (precision) + 2 (unit) + 1 (space) = 7 + """ + if unit is not None: + scale = UNITS[unit] + else: + scales = [(scale, unit) for unit, scale in UNITS.items()] + scales.sort(reverse=True) + for scale, unit in scales: + if dt >= scale: + break + + return "%.*g%s" % (precision, dt / scale, unit) + + +def timeit( + statement: Union["str", Callable], + setup: str = "pass", + timer=timeit_module.default_timer, + globals: dict = None, + inner_duration: float = 0.2, + repeats: int = 10, + number: Optional[int] = None, + print_results: bool = False, + print_suffix: str = None, +): + r"""Time the execution of `statement` using the `timeit` package similar to the IPython magic `%timeit`. + + Example:: + + import random + + def a(): + return random.random() + random.random() + + timeit(a) + + >>> namespace(min=1.844983547925949e-07, + max=2.3509945720434188e-07, + mean=1.979972431436181e-07, + median=1.9188937079161407e-07, + std=1.4883481745944718e-08, + number=1000000, + repeats=10) + + Args: + statement (Union[str, Callable]): Statement to `exec` or a callable. + setup (str, optional): Any setup steps to perform before repeats. Defaults to "pass". + timer (optional): Timer to use. Defaults to timeit_module.default_timer. + globals (dict, optional): Namespace for timing. Defaults to None. + inner_duration (float, optional): Minimum number of seconds to use iterating over `statement`. Defaults to 0.2. + repeats (int, optional): Number of times to repeat the inner loop. + number (int, optional): Overrules the `inner_duration` argument and directly sets the `number` of inner iters. + + Returns: + SimpleNamespace: Namespace with min, max, mean, median, std, number and repeats attributes + """ + timer = timeit_module.Timer(statement, setup=setup, timer=timer, globals=globals) + + if number is None: + # Autorange twice to overcome overhead + number, time_taken = timer.autorange() + number, time_taken = timer.autorange() + + multiplier = inner_duration // time_taken + 1 + number = int(multiplier * number) + + # Time + timings = timer.repeat(repeat=repeats, number=number) + timings = np.array(timings) / number + + # Collect results + min = np.min(timings) + max = np.max(timings) + mean = np.mean(timings) + median = np.median(timings) + std = np.std(timings) + + results = SimpleNamespace(min=min, max=max, mean=mean, median=median, std=std, number=number, repeats=repeats) + + if max >= min * 4: + warnings.warn_explicit( + "The test results are likely unreliable. " + "The worst time (%s) was more than four times " + "slower than the best time (%s)." % (format_time(max), format_time(min)), + UserWarning, + "", + 0, + ) + + if print_results: + report_timings(results, suffix=str(statement) if print_suffix is None else print_suffix) + + return results + + +def report_timings(timings, prefix: Optional[str] = None, suffix: Optional[str] = None): + if prefix: + s = f"{prefix:15s} | " + else: + s = "" + + s += f"number={timings.number:>5d} | [{format_time(timings.min):>8s}, {format_time(timings.max):>8s}] | {format_time(timings.median):>8s} | {format_time(timings.mean):>8s} +- {format_time(timings.std):>8s}" + + if suffix: + s += f" | {suffix}" + + rich.print(s) + return s diff --git a/dreamstream/warnings.py b/dreamstream/warnings.py index d860138..91a41ab 100644 --- a/dreamstream/warnings.py +++ b/dreamstream/warnings.py @@ -5,6 +5,10 @@ class TorchStreamWarning(RuntimeWarning): pass +def warn(message: str): + warnings.warn(message, TorchStreamWarning, stacklevel=2) + + class TorchStreamFallbackWarning(TorchStreamWarning): pass @@ -20,7 +24,7 @@ class TorchStreamFallbackWarning(TorchStreamWarning): "open an issue on GitHub." ) -REASON_METADATA_AMBIGUOUS = "since combining the StreamMetadata of the inputs is ambiguous" +DESCRIPTION_METADATA_AMBIGUOUS = "since combining the StreamMetadata of the inputs is ambiguous" def fallback_operation_warning(operation: str, description: str = ""): diff --git a/examples/wav2letter.py b/examples/wav2letter.py new file mode 100644 index 0000000..aaa18b6 --- /dev/null +++ b/examples/wav2letter.py @@ -0,0 +1,73 @@ +import random +import torch + +from torchaudio.models import Wav2Letter + +from dreamstream.patches.general import patch +from dreamstream.utils.flags import LENGTH +from dreamstream.nn.utils import pad_full_sequence, pad_stream_tensor +from dreamstream.data import OutputCollector + + +def random_chunks(full_length, min_size: int = 1000, max_size: int = 8000): + chunks = [] + chunk_sum, remaining = 0, full_length + while remaining > 0: + chunks.append(min(random.randint(min_size, max_size), remaining)) + chunk_sum = sum(chunks) + remaining = full_length - chunk_sum + return chunks + + +def run(): + random.seed(42) + torch.manual_seed(42) + + BATCH_SIZE = 32 + + model = Wav2Letter(num_classes=40, input_type="waveform") + model = patch(model) + + # Test with multiple streams of different lengths. + sequences = [torch.rand(1, random.randint(16000 * 5, 16000 * 30)) for i in range(BATCH_SIZE)] + + ids = [str(hash(i)) for i in range(BATCH_SIZE)] + targets = {_id: model(s.unsqueeze(0)) for _id, s in zip(ids, sequences)} + + data = pad_full_sequence(sequences, names=("F", LENGTH), ids=ids).align_to("B", "F", "L") + data = data.unpad_sequence() + data = { + _id: s.split(random_chunks(s.size("L"), min_size=16000, max_size=32000), dim=1) for _id, s in zip(ids, data) + } + + def remaining_chunks(data): + return sum([len(x) for x in data.values()]) + + batches = [] + while remaining_chunks(data) > 0: + batch = [s.pop(0) for _id, s in data.items() if len(s) > 0 and random.random() < 0.75] + if len(batch) > 0: + batches.append(pad_stream_tensor(batch).align_to("B", "F", "L")) + + stream_output = OutputCollector() + model.online() + ys = [] + for x in batches: + y = model(x) + ys.append(y) + stream_output.update(y) + + for _id, _y in targets.items(): + y = stream_output[_id].tensor() + abs_diff = (_y - y).abs() + print( + _id, + y.shape, + torch.allclose(_y, y, atol=1e-6), + abs_diff.max().item(), + abs_diff.max(0).values[:10].max().item(), + ) + + +if __name__ == "__main__": + run() diff --git a/pyproject.toml b/pyproject.toml index 73ecefc..9a3fbf2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,22 @@ +[project] +name = "TorchStream" +version = "0.1.0" +authors = [ + {name = "Jakob Drachmann Havtorn", email = "jdh@corti.ai"}, + {name = "Lasse Borgholt", email = "lb@corti.ai"}, +] +description = "Plug-and-play data streaming for PyTorch" +readme = "README.md" +license = {file = "LICENSE"} +requires-python = ">=3.10" +keywords = ["streaming", "online", "automatic speech recognition", "asr", "neural networks", "pytorch", "torch"] +dependencies = [ + "torch", +] + +[tool.setuptools] +py-modules = ["dreamstream"] + [tool.ruff] line-length = 120 diff --git a/requirements.txt b/requirements.txt index ad61aa3..57e8aac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,6 +10,7 @@ nbstripout num2words numba numpy +onnx pandas pylint pytest diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..92f5f35 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,129 @@ +import random +from typing import List + +import pytest +import torch + +from dreamstream.nn.utils.pad_sequence import pad_full_sequence, pad_stream_tensor +from dreamstream.utils.flags import BATCH, LENGTH + + +SEED = 42 +torch.manual_seed(SEED) +random.seed(SEED) + + +BATCH_SIZE = 8 + +# Test parameters for waveforms +SAMPLE_RATE = 1000 +WAVEFORM_DIM = 1 # Number of waveform channels +WAVEFORM_MIN_SECONDS = 3 +WAVEFORM_MAX_SECONDS = 7 +WAVEFORM_CHUNK_MIN_SECONDS = 1 +WAVEFORM_CHUNK_MAX_SECONDS = 2 +WAVEFORM_CHUNK_SECONDS = 1.5 +WAVEFORM_MIN_SIZE = round(WAVEFORM_MIN_SECONDS * SAMPLE_RATE) +WAVEFORM_MAX_SIZE = round(WAVEFORM_MAX_SECONDS * SAMPLE_RATE) +WAVEFORM_CHUNK_MIN_SIZE = round(WAVEFORM_CHUNK_MIN_SECONDS * SAMPLE_RATE) +WAVEFORM_CHUNK_MAX_SIZE = round(WAVEFORM_CHUNK_MAX_SECONDS * SAMPLE_RATE) +WAVEFORM_CHUNK_SIZE = round(WAVEFORM_CHUNK_SECONDS * SAMPLE_RATE) + +# Test parameters for sequences of "tokens" (shorter than waveforms) +# TOKEN_DIM = 4 +# TOKEN_MIN_SIZE = 50 +# TOKEN_MAX_SIZE = 500 +# TOKEN_CHUNK_MIN_SIZE = 10 +# TOKEN_CHUNK_MAX_SIZE = 30 +# TOKEN_CHUNK_SIZE = 20 +TOKEN_DIM = 4 +TOKEN_MIN_SIZE = 50 +TOKEN_MAX_SIZE = 500 +TOKEN_CHUNK_MIN_SIZE = 10 +TOKEN_CHUNK_MAX_SIZE = 30 +TOKEN_CHUNK_SIZE = 20 + + +def random_chunks(full_length, min_size: int = 1000, max_size: int = 8000) -> List[int]: + """Return a list of chunk sizes to use for `torch.split`.""" + chunks = [] + chunk_sum, remaining = 0, full_length + while remaining > 0: + chunks.append(min(random.randint(min_size, max_size), remaining)) + chunk_sum = sum(chunks) + remaining = full_length - chunk_sum + return chunks + + +@pytest.fixture +def waveforms(): + """A `BATCH_SIZE` list of length-varying waveform sequences of size (WAVEFORM_DIM, L) of `torch.rand` values.""" + return [torch.rand(WAVEFORM_DIM, random.randint(WAVEFORM_MIN_SIZE, WAVEFORM_MAX_SIZE)) for _ in range(BATCH_SIZE)] + + +@pytest.fixture +def tokens(): + """A `BATCH_SIZE` list of length-varying token sequences of size (TOKEN_DIM, L) of `torch.rand` values.""" + return [torch.rand(TOKEN_DIM, random.randint(TOKEN_MIN_SIZE, TOKEN_MAX_SIZE)) for _ in range(BATCH_SIZE)] + + +@pytest.fixture +def ids(): + """A number of unique ids of size `BATCH_SIZE`.""" + return [str(hash(i)) for i in range(BATCH_SIZE)] + + +def create_random_batches(data, ids, min_size: int, max_size: int): + """Create a number of batches of shape (BATCH, F, RANDOM_CHUNK_SIZE) from variable length sequences (F, LENGTH).""" + data = pad_full_sequence(data, names=("F", LENGTH), ids=ids).align_to("B", "F", "L") + data = data.unpad_sequence() + data = { + _id: s.split(random_chunks(s.size("L"), min_size=min_size, max_size=max_size), dim=1) + for _id, s in zip(ids, data) + } + + def num_remaining_chunks(data): + return sum([len(x) for x in data.values()]) + + batches = [] + while num_remaining_chunks(data) > 0: + batch = [s.pop(0) for _id, s in data.items() if len(s) > 0 and random.random() < 0.75] + if len(batch) > 0: + batches.append(pad_stream_tensor(batch).align_to(BATCH, "F", LENGTH)) + + return batches + + +def create_structured_batches(data, ids, chunk_size: int): + """Create a number of batches of shape (BATCH, F, CHUNK_SIZE) from variable length sequences (F, LENGTH).""" + data = pad_full_sequence(data, names=("F", LENGTH), ids=ids).align_to("B", "F", "L") + data = data.unpad_sequence() + data = {_id: s.split(chunk_size, dim=1) for _id, s in zip(ids, data)} + + def num_remaining_chunks(data): + return sum([len(x) for x in data.values()]) + + batches = [] + while num_remaining_chunks(data) > 0: + batch = [s.pop(0) for _id, s in data.items() if len(s) > 0] + if len(batch) > 0: + batches.append(pad_stream_tensor(batch).align_to(BATCH, "F", LENGTH)) + + return batches + + +# @pytest.fixture +# def batches_of_waveform_chunks(waveforms, ids): +# """Batches of chunks of waveforms of varying lengths from the `waveforms` data.""" +# return create_random_batches( +# waveforms, +# ids, +# min_size=SAMPLE_RATE * WAVEFORM_CHUNK_MIN_SECONDS, +# max_size=SAMPLE_RATE * WAVEFORM_CHUNK_MAX_SECONDS, +# ) + + +# @pytest.fixture +# def batches_of_token_chunks(tokens, ids): +# """Batches of chunks of tokens of varying lengths from the `tokens` data.""" +# return create_random_batches(tokens, ids, min_size=TOKENS_CHUNK_MIN_SIZE, max_size=TOKENS_CHUNK_MAX_SIZE) diff --git a/tests/data/test_stream_dataset.py b/tests/data/test_stream_dataset.py index dcde685..cc1ce65 100644 --- a/tests/data/test_stream_dataset.py +++ b/tests/data/test_stream_dataset.py @@ -134,7 +134,7 @@ def test_split(self, test_source_df, use_file_lengths, num_workers, batch_size, class TestMultiStreamDataLoader: @pytest.mark.parametrize( - "shuffle, drop_last, non_overlapping_batches", + "shuffle, drop_last, overlapping_batches", [ (False, False, False), (True, False, False), @@ -146,7 +146,7 @@ class TestMultiStreamDataLoader: (True, True, True), ], ) - def test_instantiate(self, test_source_df, shuffle, drop_last, non_overlapping_batches): + def test_instantiate(self, test_source_df, shuffle, drop_last, overlapping_batches): dataset = AudioStreamDataset( file_list=test_source_df.filename, file_lengths=test_source_df.length, @@ -158,10 +158,10 @@ def test_instantiate(self, test_source_df, shuffle, drop_last, non_overlapping_b "num_workers": 5, "shuffle": shuffle, "drop_last": drop_last, - "non_overlapping_batches": non_overlapping_batches, + "overlapping_batches": overlapping_batches, } - if drop_last and non_overlapping_batches: + if drop_last and not overlapping_batches: with pytest.raises(ValueError): MultiStreamDataLoader(dataset, **kwargs) return None @@ -171,9 +171,9 @@ def test_instantiate(self, test_source_df, shuffle, drop_last, non_overlapping_b assert isinstance(iter(dataloader), Generator) @pytest.mark.parametrize( - "num_workers, shuffle, non_overlapping_batches", itertools.product(*[[0, 1, 5], [True, False], [True, False]]) + "num_workers, shuffle, overlapping_batches", itertools.product(*[[0, 1, 5], [True, False], [True, False]]) ) - def test_iterate_drop_last0(self, test_source_df, num_workers, shuffle, non_overlapping_batches): + def test_iterate_drop_last0(self, test_source_df, num_workers, shuffle, overlapping_batches): dataset = AudioStreamDataset( file_list=test_source_df.filename, file_lengths=test_source_df.length, @@ -185,7 +185,7 @@ def test_iterate_drop_last0(self, test_source_df, num_workers, shuffle, non_over num_workers=num_workers, shuffle=shuffle, drop_last=False, - non_overlapping_batches=non_overlapping_batches, + overlapping_batches=overlapping_batches, collate_fn=dataset.custom_collate, ) @@ -205,7 +205,7 @@ def test_iterate_drop_last0(self, test_source_df, num_workers, shuffle, non_over samples_seen.extend(sample_ids) sos = [int(sos) for sos in batch.meta.sos] - if non_overlapping_batches: + if not overlapping_batches: assert all([sos[i] == sos[0] for i in range(len(sos))]), "All files must start simultaneously." # names = [id.split("/")[-1].split(".")[0] for id in batch.meta.ids] @@ -234,7 +234,7 @@ def test_iterate_drop_last1(self, test_source_df, shuffle): num_workers=0, shuffle=shuffle, drop_last=True, - non_overlapping_batches=False, + overlapping_batches=True, collate_fn=dataset.custom_collate, ) diff --git a/tests/patches/test_conv.py b/tests/patches/test_conv.py new file mode 100644 index 0000000..34e024d --- /dev/null +++ b/tests/patches/test_conv.py @@ -0,0 +1,77 @@ +import torch +import torch.nn as nn +import pytest +import torchaudio + +from dreamstream import patch +from dreamstream.data.data_objects import OutputCollector +from tests.conftest import ( + WAVEFORM_CHUNK_SIZE, + WAVEFORM_CHUNK_MIN_SIZE, + WAVEFORM_CHUNK_MAX_SIZE, + WAVEFORM_DIM, + create_structured_batches, + create_random_batches, +) + + +class TestConvs: + test_modules = [ + nn.Conv1d(WAVEFORM_DIM, 4, kernel_size=1, padding=0), + nn.Conv1d(WAVEFORM_DIM, 4, kernel_size=1, padding=1), + nn.Conv1d(WAVEFORM_DIM, 4, kernel_size=3, padding=0), + nn.Conv1d(WAVEFORM_DIM, 4, kernel_size=3, padding=1), + nn.Conv1d(WAVEFORM_DIM, 4, kernel_size=4, padding=0), + nn.Conv1d(WAVEFORM_DIM, 4, kernel_size=4, padding=2), + nn.Conv1d(WAVEFORM_DIM, 4, kernel_size=5, padding=0), + nn.Conv1d(WAVEFORM_DIM, 4, kernel_size=5, padding=2), + nn.Conv1d(WAVEFORM_DIM, 4, kernel_size=5, padding=2, stride=2), + nn.Conv1d(WAVEFORM_DIM, 4, kernel_size=5, padding=2, stride=2), + nn.Sequential( + nn.Conv1d(WAVEFORM_DIM, 4, kernel_size=3, padding=1), + nn.Conv1d(4, 4, kernel_size=5, padding=2), + ), + nn.Sequential( + nn.Conv1d(WAVEFORM_DIM, 4, kernel_size=3, padding=1, stride=2), + nn.Conv1d(4, 4, kernel_size=5, padding=2, stride=2), + ), + torchaudio.models.Wav2Letter( + num_classes=40, input_type="mfcc", num_features=WAVEFORM_DIM + ), # Too slow for remote CI. + ] + + def recursive_assert(self, module): + assert hasattr(module, "online") + assert hasattr(module, "offline") + if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + assert hasattr(module, "stream_buffer") + assert hasattr(module, "kernel_width") + + @pytest.mark.parametrize("module", test_modules) + def test_patch(self, module): + patch(module) + module.apply(self.recursive_assert) + + @pytest.mark.parametrize("module", test_modules) + @pytest.mark.parametrize("is_structured_batches", [True, False]) + def test_equivalence(self, waveforms, ids, module, is_structured_batches): + if is_structured_batches: + batches = create_structured_batches(waveforms, ids, chunk_size=WAVEFORM_CHUNK_SIZE) + else: + batches = create_random_batches(waveforms, ids, min_size=WAVEFORM_CHUNK_MIN_SIZE, max_size=WAVEFORM_CHUNK_MAX_SIZE) + + with torch.inference_mode(): + # Offline targets + module.offline() + targets = {_id: module(s.unsqueeze(0)) for _id, s in zip(ids, waveforms)} + + # Online outputs + stream_output = OutputCollector() + module.online() + for x in batches: + y = module(x) + stream_output.update(y) + + for _id, _y in targets.items(): + y = stream_output[_id].tensor() + assert torch.allclose(_y, y, atol=1e-6), f"{_id} failed" diff --git a/tests/patches/test_rnn.py b/tests/patches/test_rnn.py new file mode 100644 index 0000000..9570e74 --- /dev/null +++ b/tests/patches/test_rnn.py @@ -0,0 +1,132 @@ +import collections +import random + +import torch +import torch.nn as nn +import pytest + +from dreamstream import patch +from dreamstream.data.data_objects import OutputCollector +from dreamstream.utils.flags import BATCH, LENGTH +from tests.conftest import ( + create_random_batches, + create_structured_batches, + TOKEN_CHUNK_SIZE, + TOKEN_CHUNK_MIN_SIZE, + TOKEN_CHUNK_MAX_SIZE, + TOKEN_DIM, + BATCH_SIZE +) + + +class TestRNNs: + # initial_h = torch.randn(1, BATCH_SIZE, 16) + # initial_c = torch.randn(1, BATCH_SIZE, 16) + initial_h = torch.randn(1, 1, 4) + initial_c = torch.randn(1, 1, 4) + + test_modules = [ + nn.RNN(TOKEN_DIM, 4, num_layers=1, bidirectional=False, batch_first=True), + nn.RNN(TOKEN_DIM, 4, num_layers=2, bidirectional=False, batch_first=True), + # nn.GRU(TOKEN_DIM, 16, num_layers=1, bidirectional=False, batch_first=True), + # nn.GRU(TOKEN_DIM, 16, num_layers=2, bidirectional=False, batch_first=True), + # nn.LSTM(TOKEN_DIM, 16, num_layers=1, bidirectional=False, batch_first=True), + # nn.LSTM(TOKEN_DIM, 16, num_layers=2, bidirectional=False, batch_first=True), + ] + + def recursive_assert(self, module): + assert hasattr(module, "online") + assert hasattr(module, "offline") + + if isinstance(module, (nn.RNN, nn.LSTM, nn.GRU)): + assert hasattr(module, "hidden_state_store") + + @pytest.mark.parametrize("module", test_modules) + def test_patch(self, module): + patch(module) + module.apply(self.recursive_assert) + + # @pytest.mark.parametrize("use_initial_state", [False, True]) + # @pytest.mark.parametrize("is_packed_sequence", [False, True]) + # @pytest.mark.parametrize("is_structured_batches", [True, False]) + @pytest.mark.parametrize("use_initial_state", [False]) + @pytest.mark.parametrize("is_packed_sequence", [True]) + @pytest.mark.parametrize("is_structured_batches", [True, False]) + @pytest.mark.parametrize("module", test_modules) + def test_equivalence(self, tokens, ids, module, use_initial_state, is_packed_sequence, is_structured_batches): + lengths = torch.tensor([t.shape[1] for t in tokens]) + # Create batches. + if is_structured_batches: + batches = create_structured_batches(tokens, ids, chunk_size=TOKEN_CHUNK_SIZE) + else: + batches = create_random_batches(tokens, ids, min_size=TOKEN_CHUNK_MIN_SIZE, max_size=TOKEN_CHUNK_MAX_SIZE) + + # (Maybe) Create packed sequences + tokens = [t.transpose(0, 1) for t in tokens] # (D, L) -> (L, D) + batches = [b.align_to(BATCH, LENGTH, "F") for b in batches] # (B, L, F) + if is_packed_sequence: + tokens = [torch.nn.utils.rnn.pack_sequence([t]) for t in tokens] + batches = [torch.nn.utils.rnn.pack_padded_sequence(b.align_to(BATCH, LENGTH, "F"), b.meta.lengths, batch_first=True, enforce_sorted=False) for b in batches] + + # (Maybe) Set initial state + if use_initial_state: + initial_state = (self.initial_h, self.initial_c) if isinstance(module, nn.LSTM) else self.initial_h + else: + initial_state = None + + # Run offline and online versions + with torch.inference_mode(): + # Offline targets + module.offline() + targets = { + _id: module(t, initial_state.squeeze(0) if use_initial_state else None) + for _id, t in zip(ids, tokens) + } + if is_packed_sequence: + targets = { + _id: (torch.nn.utils.rnn.pad_packed_sequence(y, batch_first=True)[0].squeeze(0), out_states) + for _id, (y, out_states) in targets.items() + } + + # Online outputs + module.online() + stream_output = OutputCollector(collection="cat") + # stream_states = OutputCollector(collection="append") + stream_states = dict() + for j, x in enumerate(batches): + y, out_states = module(x, initial_state) + if is_packed_sequence: + y = torch.nn.utils.rnn.pad_packed_sequence(y, batch_first=True)[0] + stream_output.update(y) + for i, _id in enumerate(y.meta.ids): + # print(j, out_states[:, i], targets[_id][1]) + if y.meta.eos[i]: + stream_states[_id] = out_states[:, i] + # stream_states.update(out_states) + + # Compare outputs + failed_outputs = [] + failed_states = [] + for _id, (_y, _state) in targets.items(): + y = stream_output[_id].tensor() + state = stream_states[_id] + + # print(f"Online output: {y}") + # print(f"Offline output: {_y}") + print(_id) + print(f"Online state: {state}") + print(f"Offline state: {_state}") + print((_y - y).abs().sum(0).max()) + print((_state - state).abs().sum(0).max()) + + import IPython + IPython.embed(using=False) + + if not torch.allclose(_y, y, atol=1e-6): + num_errs = ((_y - y).abs().sum(-1) > 1e-6).sum() + failed_outputs.append(f"{_id} failed on n={num_errs} of {y.shape[0]} output steps.") + if not torch.allclose(_state, state, atol=1e-6): + failed_states.append(f"{_id} failed on state, {_state} != {state}") + + if any(failed_states) or any(failed_outputs): + raise AssertionError(f"Failed on:\n{failed_outputs}\n{failed_states}") diff --git a/tests/test_stream_metadata.py b/tests/test_stream_metadata.py new file mode 100644 index 0000000..0cf5cb9 --- /dev/null +++ b/tests/test_stream_metadata.py @@ -0,0 +1,45 @@ +import time + +from dreamstream.tensor import LazyInit, LazyProxy +from dreamstream.utils.timing import timeit + + +class DummyObject(): + def __init__(self, *args, **kwargs): + super().__init__() + self.args = args + self.kwargs = kwargs + time.sleep(0.1) + + +class LazyDummyObject(LazyInit): + def __init__(self, *args, **kwargs): + super().__init__() + self.args = args + self.kwargs = kwargs + + +class TestLazy(): + def test_lazy_proxy_laziness(self): + lazy_proxy = LazyProxy(DummyObject, 1, 2, 3, a=1, b=2, c=3) + assert lazy_proxy.__dict__["_cls"] is DummyObject + assert lazy_proxy.__dict__["_args"] == (1, 2, 3) + assert lazy_proxy.__dict__["_kwargs"] == {'a': 1, 'b': 2, 'c': 3} + assert lazy_proxy.__dict__["_obj"] is None + + def test_lazy_proxy_initialization(self): + lazy_proxy = LazyProxy(DummyObject, 1, 2, 3, a=1, b=2, c=3) + _ = lazy_proxy.args # trigger initialization + assert isinstance(lazy_proxy.__dict__["_obj"], DummyObject) + assert lazy_proxy.__dict__["_obj"].args == (1, 2, 3) + assert lazy_proxy.__dict__["_obj"].kwargs == {'a': 1, 'b': 2, 'c': 3} + + def test_lazy_init(self): + lazy_init = LazyDummyObject(1, 2, 3, a=1, b=2, c=3) + assert isinstance(lazy_init, LazyProxy) + + def test_lazy_init_time_saved(self): + """Test that lazy initialization indeed saves time.""" + lazy_timing = timeit("LazyDummyObject(1, 2, 3, a=1, b=2, c=3)", globals=globals()) + init_timing = timeit("DummyObject(1, 2, 3, a=1, b=2, c=3)", globals=globals()) + assert lazy_timing.median < init_timing.median / 100 diff --git a/tests/test_tensor.py b/tests/test_tensor.py index 99759c3..15d97c3 100644 --- a/tests/test_tensor.py +++ b/tests/test_tensor.py @@ -1,17 +1,17 @@ -import collections +# import collections import pytest import torch from dreamstream.tensor import StreamTensor, as_stream_tensor, stream_tensor, stream_metadata, LENGTH, BATCH -from dreamstream.func_coverage import ( - DECOUPLE_FUNCTIONS, - FLAT_OVERRIDABLE_FUNCTIONS, - OVERRIDDEN_FUNCTIONS, - RECOUPLE_FUNCTIONS, - UNSUPPORTED_FUNCTIONS, - VALID_FUNCTIONS, -) +# from dreamstream.func_coverage import ( +# DECOUPLE_FUNCTIONS, +# FLAT_OVERRIDABLE_FUNCTIONS, +# CUSTOMIZED_FUNCTIONS, +# RECOUPLE_FUNCTIONS, +# UNSUPPORTED_FUNCTIONS, +# VALID_FUNCTIONS, +# ) from dreamstream.overrides import join_dim_names @@ -25,7 +25,7 @@ def test_data_3d(): def test_meta_kwargs(): return dict( - ids=["first", "middle", "last"], + ids=("first", "middle", "last"), sos=[True, False, False], eos=[False, False, True], lengths=[3, 3, 2], @@ -107,7 +107,7 @@ def stream_tensor_bfld_fixture(): ) def test_instantiate_stream_tensor(data): """Test that we can instantiate a StreamTensor from different kinds of data.""" - meta = stream_metadata(ids=["a", "b"], sos=[True, False], eos=[False, True], lengths=[3, 3]) + meta = stream_metadata(ids=("a", "b"), sos=[True, False], eos=[False, True], lengths=[3, 3]) tensor = stream_tensor(data, meta, names=(BATCH, LENGTH)) assert isinstance(tensor, StreamTensor) @@ -198,37 +198,37 @@ def to_torch_tensor_recursive(x): TEST_SKIP_EQUALITY_CHECK = {torch.Tensor.__repr__, torch.Tensor.__str__} -def test_valid_coupled_recoupled_functions(): - """Iterate over all valid, coupled and recoupled functions and check that they work as expected.""" - INPUTS = collections.defaultdict(Inputs) - INPUTS.update(TEST_INPUTS_VALID_FUNCTIONS | TEST_INPUTS_DECOUPLE_FUNCTIONS | TEST_INPUTS_RECOUPLE_FUNCTIONS) +# def test_valid_coupled_recoupled_functions(): +# """Iterate over all valid, coupled and recoupled functions and check that they work as expected.""" +# INPUTS = collections.defaultdict(Inputs) +# INPUTS.update(TEST_INPUTS_VALID_FUNCTIONS | TEST_INPUTS_DECOUPLE_FUNCTIONS | TEST_INPUTS_RECOUPLE_FUNCTIONS) - FUNCTIONS = VALID_FUNCTIONS | DECOUPLE_FUNCTIONS | RECOUPLE_FUNCTIONS +# FUNCTIONS = VALID_FUNCTIONS | DECOUPLE_FUNCTIONS | RECOUPLE_FUNCTIONS - failed = [] - for function in FUNCTIONS: - try: - args, kwargs = INPUTS[function] - stream_tensor_out = function(*args, **kwargs) - torch_tensor_out = function(*to_torch_tensor_recursive(args), **to_torch_tensor_recursive(kwargs)) +# failed = [] +# for function in FUNCTIONS: +# try: +# args, kwargs = INPUTS[function] +# stream_tensor_out = function(*args, **kwargs) +# torch_tensor_out = function(*to_torch_tensor_recursive(args), **to_torch_tensor_recursive(kwargs)) - if function in TEST_SKIP_EQUALITY_CHECK: - continue +# if function in TEST_SKIP_EQUALITY_CHECK: +# continue - if isinstance(torch_tensor_out, torch.Tensor): - stream_tensor_out = stream_tensor_out.rename(None) - torch_tensor_out = torch_tensor_out.rename(None) - assert torch.equal(stream_tensor_out, torch_tensor_out) - else: - assert stream_tensor_out == torch_tensor_out +# if isinstance(torch_tensor_out, torch.Tensor): +# stream_tensor_out = stream_tensor_out.rename(None) +# torch_tensor_out = torch_tensor_out.rename(None) +# assert torch.equal(stream_tensor_out, torch_tensor_out) +# else: +# assert stream_tensor_out == torch_tensor_out - except Exception as e: - failed.append((function, e)) +# except Exception as e: +# failed.append((function, e)) - if any(failed): - failed_str = " - " + "\n\n - ".join([f"{f.__name__}: {e}" for f, e in failed]) - err = f"The following functions claimed to be valid, were not:\n{failed_str}" - raise AssertionError("\n\n" + err) +# if any(failed): +# failed_str = " - " + "\n - ".join([f"{f.__name__}: {e}" for f, e in failed]) +# err = f"The following {len(failed)} functions claimed to be valid, were not:\n{failed_str}" +# raise AssertionError("\n" + err) # def test_unsupported_functions(stream_tensor_bfl_fixture): @@ -244,18 +244,18 @@ def test_valid_coupled_recoupled_functions(): # raise AssertionError(f"The following functions claimed to be invalid, were not:\n{failed_str}") -def test_function_coverage(): - """Test that we have covered all functions in torch.nn.functional.""" - num_overridden = len(OVERRIDDEN_FUNCTIONS) - num_valid = len(VALID_FUNCTIONS) - num_invalid = len(UNSUPPORTED_FUNCTIONS) - num_total = num_overridden + num_valid + num_invalid +# def test_function_coverage(): +# """Test that we have covered all functions in torch.nn.functional.""" +# num_customized = len(CUSTOMIZED_FUNCTIONS) +# num_valid = len(VALID_FUNCTIONS) +# num_invalid = len(UNSUPPORTED_FUNCTIONS) +# num_total = num_customized + num_valid + num_invalid - # fraction_working = (num_overridden + num_valid) / num_total +# fraction_working = (num_customized + num_valid) / num_total - msg = "Total number of functions must be the number of overridable functions plus any dunder methods." - assert num_total >= len(FLAT_OVERRIDABLE_FUNCTIONS), msg - # assert fraction_working > 0.8, f"Only {fraction_working*100:.1f} % of torch functions are covered (req >80%)." +# msg = "Total number of functions must be the number of overridable functions plus any dunder methods." +# assert num_total >= len(FLAT_OVERRIDABLE_FUNCTIONS), msg +# assert fraction_working > 0.8, f"Only {fraction_working*100:.1f} % of torch functions are covered (req >80%)." ## Indexing functions @@ -271,7 +271,7 @@ def assert_stream_tensor_and_meta_correct(stream_tensor, func, *args, ids, lengt t = func(torch_tensor, *args, **kwargs) assert isinstance(s, StreamTensor) assert torch.equal(s.tensor(), t) - assert s.meta.ids == ids + assert s.meta.ids == tuple(ids) assert torch.equal(s.meta.sos, torch.tensor(sos)) assert torch.equal(s.meta.eos, torch.tensor(eos)) assert torch.equal(s.meta.lengths, torch.tensor(lengths)) @@ -294,7 +294,7 @@ class TestNarrow: def test_narrow_batch(self, stream_tensor_bfl_fixture): """Test `torch.narrow` on a StreamTensor when applied to batch, length, and feature dimensions.""" assert_kwargs = dict( - ids=["middle", "last"], + ids=("middle", "last"), lengths=[3, 2], sos=[False, False], eos=[False, True], @@ -313,7 +313,7 @@ def test_narrow_feature(self, stream_tensor_bfl_fixture): s = stream_tensor_bfl_fixture.narrow(dim=1, start=1, length=1) assert isinstance(s, StreamTensor) assert torch.equal(s.tensor(), stream_tensor_bfl_fixture.tensor().narrow(dim=1, start=1, length=1)) - assert s.meta.ids == ["first", "middle", "last"] + assert s.meta.ids == ("first", "middle", "last") assert torch.equal(s.meta.sos, torch.tensor([True, False, False])) assert torch.equal(s.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s.meta.lengths, torch.tensor([3, 3, 2])) @@ -323,7 +323,7 @@ def test_narrow_length(self, stream_tensor_bfl_fixture): s = stream_tensor_bfl_fixture.narrow(dim=2, start=1, length=2) assert isinstance(s, StreamTensor) assert torch.equal(s.tensor(), stream_tensor_bfl_fixture.tensor().narrow(dim=2, start=1, length=2)) - assert s.meta.ids == ["first", "middle", "last"] + assert s.meta.ids == ("first", "middle", "last") assert torch.equal(s.meta.sos, torch.tensor([False, False, False])) assert torch.equal(s.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s.meta.lengths, torch.tensor([2, 2, 1])) @@ -370,7 +370,7 @@ def test_gather_feature(self, stream_tensor_bfl_fixture): s = stream_tensor_bfl_fixture.gather(dim=1, index=self.full_index) assert isinstance(s, StreamTensor) assert torch.equal(s.tensor(), stream_tensor_bfl_fixture.tensor().gather(dim=1, index=self.full_index)) - assert s.meta.ids == ["first", "middle", "last"] + assert s.meta.ids == ("first", "middle", "last") assert torch.equal(s.meta.sos, torch.tensor([True, False, False])) assert torch.equal(s.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s.meta.lengths, torch.tensor([3, 3, 2])) @@ -380,7 +380,7 @@ def test_gather_feature_truncated(self, stream_tensor_bfl_fixture): s = stream_tensor_bfl_fixture.gather(dim=1, index=self.truncated_index) assert isinstance(s, StreamTensor) assert torch.equal(s.tensor(), stream_tensor_bfl_fixture.tensor().gather(dim=1, index=self.truncated_index)) - assert s.meta.ids == ["first", "middle"] + assert s.meta.ids == ("first", "middle") assert torch.equal(s.meta.sos, torch.tensor([True, False])) assert torch.equal(s.meta.eos, torch.tensor([False, False])) assert torch.equal(s.meta.lengths, torch.tensor([2, 2])) @@ -403,7 +403,7 @@ def test_take_along_feature(self, stream_tensor_bfl_fixture): assert torch.equal( s.tensor(), stream_tensor_bfl_fixture.tensor().take_along_dim(dim=1, indices=TestGather.full_index) ) - assert s.meta.ids == ["first", "middle", "last"] + assert s.meta.ids == ("first", "middle", "last") assert torch.equal(s.meta.sos, torch.tensor([True, False, False])) assert torch.equal(s.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s.meta.lengths, torch.tensor([3, 3, 2])) @@ -424,7 +424,7 @@ def test_select_batch(self, stream_tensor_bfl_fixture): s = stream_tensor_bfl_fixture.select(dim=0, index=1) assert isinstance(s, StreamTensor) assert torch.equal(s.tensor(), stream_tensor_bfl_fixture.tensor().select(dim=0, index=1)) - assert s.meta.ids == ["middle"] + assert s.meta.ids == ("middle",) assert torch.equal(s.meta.sos, torch.tensor([False])) assert torch.equal(s.meta.eos, torch.tensor([False])) assert torch.equal(s.meta.lengths, torch.tensor([3])) @@ -434,7 +434,7 @@ def test_select_feature(self, stream_tensor_bfl_fixture): s = stream_tensor_bfl_fixture.select(dim=1, index=1) assert isinstance(s, StreamTensor) assert torch.equal(s.tensor(), stream_tensor_bfl_fixture.tensor().select(dim=1, index=1)) - assert s.meta.ids == ["first", "middle", "last"] + assert s.meta.ids == ("first", "middle", "last") assert torch.equal(s.meta.sos, torch.tensor([True, False, False])) assert torch.equal(s.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s.meta.lengths, torch.tensor([3, 3, 2])) @@ -444,7 +444,7 @@ def test_select_length(self, stream_tensor_bfl_fixture): s = stream_tensor_bfl_fixture.select(dim=2, index=1) assert isinstance(s, StreamTensor) assert torch.equal(s.tensor(), stream_tensor_bfl_fixture.tensor().select(dim=2, index=1)) - assert s.meta.ids == ["first", "middle", "last"] + assert s.meta.ids == ("first", "middle", "last") assert torch.equal(s.meta.sos, torch.tensor([False, False, False])) assert torch.equal(s.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s.meta.lengths, torch.tensor([1, 1, 1])) @@ -458,7 +458,7 @@ def test_take_batch(self, stream_tensor_bfl_fixture): s1 = stream_tensor_bfl_fixture.take(indices) assert isinstance(s1, StreamTensor) assert torch.equal(s1.tensor(), stream_tensor_bfl_fixture.tensor().take(indices)) - assert s1.meta.ids == ["first"] + assert s1.meta.ids == ("first",) assert torch.equal(s1.meta.sos, torch.tensor([True])) assert torch.equal(s1.meta.eos, torch.tensor([False])) assert torch.equal(s1.meta.lengths, torch.tensor([3])) @@ -469,7 +469,7 @@ def test_take_batch(self, stream_tensor_bfl_fixture): s2 = stream_tensor_bfl_fixture.take(indices) assert isinstance(s2, StreamTensor) assert torch.equal(s2.tensor(), stream_tensor_bfl_fixture.tensor().take(indices)) - assert s2.meta.ids == ["first"] + assert s2.meta.ids == ("first",) assert torch.equal(s2.meta.sos, torch.tensor([False])) assert torch.equal(s2.meta.eos, torch.tensor([False])) assert torch.equal(s2.meta.lengths, torch.tensor([2])) @@ -481,7 +481,7 @@ def test_take_feature(self, stream_tensor_bfl_fixture): s1 = stream_tensor_bfl_fixture.take(indices) assert isinstance(s1, StreamTensor) assert torch.equal(s1.tensor(), stream_tensor_bfl_fixture.tensor().take(indices)) - assert s1.meta.ids == ["first", "middle", "last"] + assert s1.meta.ids == ("first", "middle", "last") assert torch.equal(s1.meta.sos, torch.tensor([True, False, False])) assert torch.equal(s1.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s1.meta.lengths, torch.tensor([3, 3, 2])) @@ -493,7 +493,7 @@ def test_take_length(self, stream_tensor_bfl_fixture): s1 = stream_tensor_bfl_fixture.take(indices) assert isinstance(s1, StreamTensor) assert torch.equal(s1.tensor(), stream_tensor_bfl_fixture.tensor().take(indices)) - assert s1.meta.ids == ["first", "middle", "last"] + assert s1.meta.ids == ("first", "middle", "last") assert torch.equal(s1.meta.sos, torch.tensor([True, False, False])) assert torch.equal(s1.meta.eos, torch.tensor([False, False, False])) assert torch.equal(s1.meta.lengths, torch.tensor([1, 1, 1])) @@ -504,7 +504,7 @@ def test_take_length(self, stream_tensor_bfl_fixture): s1 = stream_tensor_bfl_fixture.take(indices) assert isinstance(s1, StreamTensor) assert torch.equal(s1.tensor(), stream_tensor_bfl_fixture.tensor().take(indices)) - assert s1.meta.ids == ["first", "middle", "last"] + assert s1.meta.ids == ("first", "middle", "last") assert torch.equal(s1.meta.sos, torch.tensor([False, False, False])) assert torch.equal(s1.meta.eos, torch.tensor([False, False, False])) assert torch.equal(s1.meta.lengths, torch.tensor([1, 1, 0])) @@ -516,7 +516,7 @@ def test_take_every_other(self, stream_tensor_bfl_fixture): s1 = stream_tensor_bfl_fixture.take(indices) assert isinstance(s1, StreamTensor) assert torch.equal(s1.tensor(), stream_tensor_bfl_fixture.tensor().take(indices)) - assert s1.meta.ids == ["first", "middle", "last"] + assert s1.meta.ids == ("first", "middle", "last") assert torch.equal(s1.meta.sos, torch.tensor([True, False, False])) assert torch.equal(s1.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s1.meta.lengths, torch.tensor([3, 3, 2])) @@ -531,7 +531,7 @@ def test_index_select_batch(self, stream_tensor_bfl_fixture): torch.index_select, dim=0, index=torch.tensor([0, 2]), - ids=["first", "last"], + ids=("first", "last"), sos=torch.tensor([True, False]), eos=torch.tensor([False, True]), lengths=torch.tensor([3, 2]), @@ -608,7 +608,7 @@ def test_feature_indexing_integer(self, stream_tensor_bfl_fixture): s1 = stream_tensor_bfl_fixture[:, 0, :] assert isinstance(s1, StreamTensor) assert torch.equal(s1.tensor(), stream_tensor_bfl_fixture.tensor()[:, 0, :]) - assert s1.meta.ids == ["first", "middle", "last"] + assert s1.meta.ids == ("first", "middle", "last") assert torch.equal(s1.meta.sos, torch.tensor([True, False, False])) assert torch.equal(s1.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s1.meta.lengths, torch.tensor([3, 3, 2])) @@ -620,7 +620,7 @@ def test_batch_indexing_integer(self, stream_tensor_bfl_fixture): s1 = stream_tensor_bfl_fixture[0] assert isinstance(s1, StreamTensor) assert torch.equal(s1.tensor(), stream_tensor_bfl_fixture.tensor()[0]) - assert s1.meta.ids == ["first"] # changed to only the first example + assert s1.meta.ids == ("first",) # changed to only the first example assert torch.equal(s1.meta.sos, torch.tensor([True])) assert torch.equal(s1.meta.eos, torch.tensor([False])) assert torch.equal(s1.meta.lengths, torch.tensor([3])) @@ -631,7 +631,7 @@ def test_batch_indexing_slice(self, stream_tensor_bfl_fixture): s1 = stream_tensor_bfl_fixture[1:] assert isinstance(s1, StreamTensor) assert torch.equal(s1.tensor(), stream_tensor_bfl_fixture.tensor()[1:]) - assert s1.meta.ids == ["middle", "last"] + assert s1.meta.ids == ("middle", "last") assert torch.equal(s1.meta.sos, torch.tensor([False, False])) assert torch.equal(s1.meta.eos, torch.tensor([False, True])) assert torch.equal(s1.meta.lengths, torch.tensor([3, 2])) @@ -641,7 +641,7 @@ def test_batch_indexing_tuple(self, stream_tensor_bfl_fixture): s1 = stream_tensor_bfl_fixture[(0,)] assert isinstance(s1, StreamTensor) assert torch.equal(s1.tensor(), stream_tensor_bfl_fixture.tensor()[(0,)]) - assert s1.meta.ids == ["first"] + assert s1.meta.ids == ("first",) assert torch.equal(s1.meta.sos, torch.tensor([True])) assert torch.equal(s1.meta.eos, torch.tensor([False])) assert torch.equal(s1.meta.lengths, torch.tensor([3])) @@ -650,7 +650,7 @@ def test_batch_indexing_tuple(self, stream_tensor_bfl_fixture): s2 = stream_tensor_bfl_fixture[(0, 1)] assert isinstance(s2, StreamTensor) assert torch.equal(s2.tensor(), stream_tensor_bfl_fixture.tensor()[(0, 1)]) - assert s2.meta.ids == ["first"] + assert s2.meta.ids == ("first",) assert torch.equal(s2.meta.sos, torch.tensor([True])) assert torch.equal(s2.meta.eos, torch.tensor([False])) assert torch.equal(s2.meta.lengths, torch.tensor([3])) @@ -659,7 +659,7 @@ def test_batch_indexing_tuple(self, stream_tensor_bfl_fixture): s3 = stream_tensor_bfl_fixture[(0, 1, 2)] assert isinstance(s3, StreamTensor) assert torch.equal(s3.tensor(), stream_tensor_bfl_fixture.tensor()[(0, 1, 2)]) - assert s3.meta.ids == ["first"] + assert s3.meta.ids == ("first",) assert torch.equal(s3.meta.sos, torch.tensor([False])) assert torch.equal(s3.meta.eos, torch.tensor([False])) assert torch.equal(s3.meta.lengths, torch.tensor([1])) @@ -668,7 +668,7 @@ def test_batch_indexing_tuple(self, stream_tensor_bfl_fixture): s4 = stream_tensor_bfl_fixture[(0, 1, 0)] assert isinstance(s4, StreamTensor) assert torch.equal(s4.tensor(), stream_tensor_bfl_fixture.tensor()[(0, 1, 0)]) - assert s4.meta.ids == ["first"] + assert s4.meta.ids == ("first",) assert torch.equal(s4.meta.sos, torch.tensor([True])) assert torch.equal(s4.meta.eos, torch.tensor([False])) assert torch.equal(s4.meta.lengths, torch.tensor([1])) @@ -679,7 +679,7 @@ def test_batch_indexing_list(self, stream_tensor_bfl_fixture): s1 = stream_tensor_bfl_fixture[[1]] assert isinstance(s1, StreamTensor) assert torch.equal(s1.tensor(), stream_tensor_bfl_fixture.tensor()[[1]]) - assert s1.meta.ids == ["middle"] + assert s1.meta.ids == ("middle",) assert torch.equal(s1.meta.sos, torch.tensor([False])) assert torch.equal(s1.meta.eos, torch.tensor([False])) assert torch.equal(s1.meta.lengths, torch.tensor([3])) @@ -688,7 +688,7 @@ def test_batch_indexing_list(self, stream_tensor_bfl_fixture): s2 = stream_tensor_bfl_fixture[[0, 2]] assert isinstance(s2, StreamTensor) assert torch.equal(s2.tensor(), stream_tensor_bfl_fixture.tensor()[[0, 2]]) - assert s2.meta.ids == ["first", "last"] + assert s2.meta.ids == ("first", "last") assert torch.equal(s2.meta.sos, torch.tensor([True, False])) assert torch.equal(s2.meta.eos, torch.tensor([False, True])) assert torch.equal(s2.meta.lengths, torch.tensor([3, 2])) @@ -699,7 +699,7 @@ def test_batch_indexing_booltensor(self, stream_tensor_bfl_fixture): s1 = stream_tensor_bfl_fixture[torch.tensor([False, True, True])] assert isinstance(s1, StreamTensor) assert torch.equal(s1.tensor(), stream_tensor_bfl_fixture.tensor()[torch.tensor([False, True, True])]) - assert s1.meta.ids == ["middle", "last"] + assert s1.meta.ids == ("middle", "last") assert torch.equal(s1.meta.sos, torch.tensor([False, False])) assert torch.equal(s1.meta.eos, torch.tensor([False, True])) assert torch.equal(s1.meta.lengths, torch.tensor([3, 2])) @@ -710,7 +710,7 @@ def test_batch_indexing_inttensor(self, stream_tensor_bfl_fixture): s1 = stream_tensor_bfl_fixture[torch.tensor([1, 2])] assert isinstance(s1, StreamTensor) assert torch.equal(s1.tensor(), stream_tensor_bfl_fixture.tensor()[torch.tensor([1, 2])]) - assert s1.meta.ids == ["middle", "last"] + assert s1.meta.ids == ("middle", "last") assert torch.equal(s1.meta.sos, torch.tensor([False, False])) assert torch.equal(s1.meta.eos, torch.tensor([False, True])) assert torch.equal(s1.meta.lengths, torch.tensor([3, 2])) @@ -724,7 +724,7 @@ def test_length_indexing_integer(self, stream_tensor_bfl_fixture): s2 = stream_tensor_bfl_fixture[:, :, 0] # first length index assert isinstance(s2, StreamTensor) assert torch.equal(s2.tensor(), stream_tensor_bfl_fixture.tensor()[:, :, 0]) - assert s2.meta.ids == ["first", "middle", "last"] + assert s2.meta.ids == ("first", "middle", "last") assert torch.equal(s2.meta.sos, torch.tensor([True, False, False])) assert torch.equal(s2.meta.eos, torch.tensor([False, False, False])) # changed to False assert torch.equal(s2.meta.lengths, torch.tensor([1, 1, 1])) # changed to 1 @@ -733,7 +733,7 @@ def test_length_indexing_integer(self, stream_tensor_bfl_fixture): s3 = stream_tensor_bfl_fixture[:, :, 1] # middle length index assert isinstance(s3, StreamTensor) assert torch.equal(s3.tensor(), stream_tensor_bfl_fixture.tensor()[:, :, 1]) - assert s3.meta.ids == ["first", "middle", "last"] + assert s3.meta.ids == ("first", "middle", "last") assert torch.equal(s3.meta.sos, torch.tensor([False, False, False])) # changed to False assert torch.equal(s3.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s3.meta.lengths, torch.tensor([1, 1, 1])) # changed to 1 @@ -742,7 +742,7 @@ def test_length_indexing_integer(self, stream_tensor_bfl_fixture): s4 = stream_tensor_bfl_fixture[:, :, -1] # last length index assert isinstance(s4, StreamTensor) assert torch.equal(s4.tensor(), stream_tensor_bfl_fixture.tensor()[:, :, -1]) - assert s4.meta.ids == ["first", "middle", "last"] + assert s4.meta.ids == ("first", "middle", "last") assert torch.equal(s4.meta.sos, torch.tensor([False, False, False])) # change to False assert torch.equal(s4.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s4.meta.lengths, torch.tensor([1, 1, 0])) # changed to 1 and 0 @@ -758,7 +758,7 @@ def test_length_indexing_slice(self, stream_tensor_bfl_fixture): s2 = stream_tensor_bfl_fixture[:, :, 1:] # remove first length index assert isinstance(s2, StreamTensor) assert torch.equal(s2.tensor(), stream_tensor_bfl_fixture.tensor()[:, :, 1:]) - assert s2.meta.ids == ["first", "middle", "last"] + assert s2.meta.ids == ("first", "middle", "last") assert torch.equal(s2.meta.sos, torch.tensor([False, False, False])) # changed to False assert torch.equal(s2.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s2.meta.lengths, torch.tensor([2, 2, 1])) # minus 1 for all) @@ -767,7 +767,7 @@ def test_length_indexing_slice(self, stream_tensor_bfl_fixture): s3 = stream_tensor_bfl_fixture[:, :, :-1] # remove last length index assert isinstance(s3, StreamTensor) assert torch.equal(s3.tensor(), stream_tensor_bfl_fixture.tensor()[:, :, :-1]) - assert s3.meta.ids == ["first", "middle", "last"] + assert s3.meta.ids == ("first", "middle", "last") assert torch.equal(s3.meta.sos, torch.tensor([True, False, False])) assert torch.equal(s3.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s3.meta.lengths, torch.tensor([2, 2, 2])) # minus 1 for all but "last" since it was padding @@ -776,7 +776,7 @@ def test_length_indexing_slice(self, stream_tensor_bfl_fixture): s4 = stream_tensor_bfl_fixture[:, :, 1:-1] # remove first and last length index assert isinstance(s4, StreamTensor) assert torch.equal(s4.tensor(), stream_tensor_bfl_fixture.tensor()[:, :, 1:-1]) - assert s4.meta.ids == ["first", "middle", "last"] + assert s4.meta.ids == ("first", "middle", "last") assert torch.equal(s4.meta.sos, torch.tensor([False, False, False])) # changed to False assert torch.equal(s4.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s4.meta.lengths, torch.tensor([1, 1, 1])) # minus 2 for all but "last" since it was padding @@ -785,7 +785,7 @@ def test_length_indexing_slice(self, stream_tensor_bfl_fixture): s5 = stream_tensor_bfl_fixture[:, :, :-2] # remove two last length indices assert isinstance(s5, StreamTensor) assert torch.equal(s5.tensor(), stream_tensor_bfl_fixture.tensor()[:, :, :-2]) - assert s5.meta.ids == ["first", "middle", "last"] + assert s5.meta.ids == ("first", "middle", "last") assert torch.equal(s5.meta.sos, torch.tensor([True, False, False])) # changed to False assert torch.equal(s5.meta.eos, torch.tensor([False, False, False])) assert torch.equal(s5.meta.lengths, torch.tensor([1, 1, 1])) # minus 2 for all but "last" since it was padding @@ -794,7 +794,7 @@ def test_length_indexing_slice(self, stream_tensor_bfl_fixture): s6 = stream_tensor_bfl_fixture[:, :, ::2] # remove every other length index from start to end assert isinstance(s6, StreamTensor) assert torch.equal(s6.tensor(), stream_tensor_bfl_fixture.tensor()[:, :, ::2]) - assert s6.meta.ids == ["first", "middle", "last"] + assert s6.meta.ids == ("first", "middle", "last") assert torch.equal(s6.meta.sos, torch.tensor([True, False, False])) assert torch.equal(s6.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s6.meta.lengths, torch.tensor([2, 2, 1])) # minus 1 for all @@ -812,7 +812,7 @@ def test_length_indexing_tuple_list(self, stream_tensor_bfl_fixture, indices): s1 = stream_tensor_bfl_fixture[:, :, indices[0]] # remove first length index assert isinstance(s1, StreamTensor) assert torch.equal(s1.tensor(), stream_tensor_bfl_fixture.tensor()[:, :, indices[0]]) - assert s1.meta.ids == ["first", "middle", "last"] + assert s1.meta.ids == ("first", "middle", "last") assert torch.equal(s1.meta.sos, torch.tensor([False, False, False])) # changed to False assert torch.equal(s1.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s1.meta.lengths, torch.tensor([2, 2, 1])) # minus 1 for all @@ -821,7 +821,7 @@ def test_length_indexing_tuple_list(self, stream_tensor_bfl_fixture, indices): s2 = stream_tensor_bfl_fixture[:, :, indices[1]] # remove middle length index assert isinstance(s2, StreamTensor) assert torch.equal(s2.tensor(), stream_tensor_bfl_fixture.tensor()[:, :, indices[1]]) - assert s2.meta.ids == ["first", "middle", "last"] + assert s2.meta.ids == ("first", "middle", "last") assert torch.equal(s2.meta.sos, torch.tensor([True, False, False])) assert torch.equal(s2.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s2.meta.lengths, torch.tensor([2, 2, 1])) # minus 1 for all @@ -830,7 +830,7 @@ def test_length_indexing_tuple_list(self, stream_tensor_bfl_fixture, indices): s3 = stream_tensor_bfl_fixture[:, :, indices[2]] # remove last length index assert isinstance(s3, StreamTensor) assert torch.equal(s3.tensor(), stream_tensor_bfl_fixture.tensor()[:, :, indices[2]]) - assert s3.meta.ids == ["first", "middle", "last"] + assert s3.meta.ids == ("first", "middle", "last") assert torch.equal(s3.meta.sos, torch.tensor([True, False, False])) assert torch.equal(s3.meta.eos, torch.tensor([False, False, False])) # changed to False assert torch.equal(s3.meta.lengths, torch.tensor([2, 2, 2])) # minus 1 for all but "last" since was padding @@ -841,7 +841,7 @@ def test_length_indexing_list(self, stream_tensor_bfl_fixture): s1 = stream_tensor_bfl_fixture[:, :, [1, 2]] # remove first length index assert isinstance(s1, StreamTensor) assert torch.equal(s1.tensor(), stream_tensor_bfl_fixture.tensor()[:, :, [1, 2]]) - assert s1.meta.ids == ["first", "middle", "last"] + assert s1.meta.ids == ("first", "middle", "last") assert torch.equal(s1.meta.sos, torch.tensor([False, False, False])) # changed to False assert torch.equal(s1.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s1.meta.lengths, torch.tensor([2, 2, 1])) # minus 1 for all @@ -850,7 +850,7 @@ def test_length_indexing_list(self, stream_tensor_bfl_fixture): s2 = stream_tensor_bfl_fixture[:, :, [0, 2]] # remove middle length index assert isinstance(s2, StreamTensor) assert torch.equal(s2.tensor(), stream_tensor_bfl_fixture.tensor()[:, :, [0, 2]]) - assert s2.meta.ids == ["first", "middle", "last"] + assert s2.meta.ids == ("first", "middle", "last") assert torch.equal(s2.meta.sos, torch.tensor([True, False, False])) assert torch.equal(s2.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s2.meta.lengths, torch.tensor([2, 2, 1])) # minus 1 for all @@ -859,7 +859,7 @@ def test_length_indexing_list(self, stream_tensor_bfl_fixture): s3 = stream_tensor_bfl_fixture[:, :, [0, 1]] # remove last length index assert isinstance(s3, StreamTensor) assert torch.equal(s3.tensor(), stream_tensor_bfl_fixture.tensor()[:, :, [0, 1]]) - assert s3.meta.ids == ["first", "middle", "last"] + assert s3.meta.ids == ("first", "middle", "last") assert torch.equal(s3.meta.sos, torch.tensor([True, False, False])) assert torch.equal(s3.meta.eos, torch.tensor([False, False, False])) # changed to False assert torch.equal(s3.meta.lengths, torch.tensor([2, 2, 2])) # minus 1 for all but "last" since it was padding @@ -870,7 +870,7 @@ def test_length_indexing_1d_inttensor(self, stream_tensor_bfl_fixture): s1 = stream_tensor_bfl_fixture[:, :, torch.tensor([1, 2])] # remove first length index assert isinstance(s1, StreamTensor) assert torch.equal(s1.tensor(), stream_tensor_bfl_fixture.tensor()[:, :, torch.tensor([1, 2])]) - assert s1.meta.ids == ["first", "middle", "last"] + assert s1.meta.ids == ("first", "middle", "last") assert torch.equal(s1.meta.sos, torch.tensor([False, False, False])) # changed to False assert torch.equal(s1.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s1.meta.lengths, torch.tensor([2, 2, 1])) # minus 1 for all @@ -879,7 +879,7 @@ def test_length_indexing_1d_inttensor(self, stream_tensor_bfl_fixture): s2 = stream_tensor_bfl_fixture[:, :, torch.tensor([0, 2])] # remove middle length index assert isinstance(s2, StreamTensor) assert torch.equal(s2.tensor(), stream_tensor_bfl_fixture.tensor()[:, :, torch.tensor([0, 2])]) - assert s2.meta.ids == ["first", "middle", "last"] + assert s2.meta.ids == ("first", "middle", "last") assert torch.equal(s2.meta.sos, torch.tensor([True, False, False])) assert torch.equal(s2.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s2.meta.lengths, torch.tensor([2, 2, 1])) # minus 1 for all @@ -888,7 +888,7 @@ def test_length_indexing_1d_inttensor(self, stream_tensor_bfl_fixture): s3 = stream_tensor_bfl_fixture[:, :, torch.tensor([0, 1])] # remove last length index assert isinstance(s3, StreamTensor) assert torch.equal(s3.tensor(), stream_tensor_bfl_fixture.tensor()[:, :, torch.tensor([0, 1])]) - assert s3.meta.ids == ["first", "middle", "last"] + assert s3.meta.ids == ("first", "middle", "last") assert torch.equal(s3.meta.sos, torch.tensor([True, False, False])) assert torch.equal(s3.meta.eos, torch.tensor([False, False, False])) # changed to False assert torch.equal(s3.meta.lengths, torch.tensor([2, 2, 2])) # minus 1 for all but "last" since it was padding @@ -899,7 +899,7 @@ def test_length_indexing_1d_booltensor(self, stream_tensor_bfl_fixture): s1 = stream_tensor_bfl_fixture[:, :, torch.tensor([False, True, True])] # remove first length index assert isinstance(s1, StreamTensor) assert torch.equal(s1.tensor(), stream_tensor_bfl_fixture.tensor()[:, :, torch.tensor([False, True, True])]) - assert s1.meta.ids == ["first", "middle", "last"] + assert s1.meta.ids == ("first", "middle", "last") assert torch.equal(s1.meta.sos, torch.tensor([False, False, False])) # changed to False assert torch.equal(s1.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s1.meta.lengths, torch.tensor([2, 2, 1])) # minus 1 for all @@ -908,7 +908,7 @@ def test_length_indexing_1d_booltensor(self, stream_tensor_bfl_fixture): s2 = stream_tensor_bfl_fixture[:, :, torch.tensor([True, True, False])] # remove last length index assert isinstance(s2, StreamTensor) assert torch.equal(s2.tensor(), stream_tensor_bfl_fixture.tensor()[:, :, torch.tensor([True, True, False])]) - assert s2.meta.ids == ["first", "middle", "last"] + assert s2.meta.ids == ("first", "middle", "last") assert torch.equal(s2.meta.sos, torch.tensor([True, False, False])) assert torch.equal(s2.meta.eos, torch.tensor([False, False, False])) # changed to False assert torch.equal(s2.meta.lengths, torch.tensor([2, 2, 2])) # minus 1 for all but last since this was padding. @@ -920,7 +920,7 @@ def test_batch_and_feature_indexing_2d_booltensor(self, stream_tensor_bfl_fixtur s1 = stream_tensor_bfl_fixture[indices] assert isinstance(s1, StreamTensor) assert torch.equal(s1.tensor(), stream_tensor_bfl_fixture.tensor()[indices]) - assert s1.meta.ids == ["first", "middle", "last"] + assert s1.meta.ids == ("first", "middle", "last") assert torch.equal(s1.meta.sos, torch.tensor([True, False, False])) assert torch.equal(s1.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s1.meta.lengths, torch.tensor([3, 3, 2])) @@ -934,7 +934,7 @@ def test_length_and_feature_indexing_2d_booltensor(self, stream_tensor_bfl_fixtu s1 = stream_tensor_bfl_fixture[:, indices] assert isinstance(s1, StreamTensor) assert torch.equal(s1.tensor(), stream_tensor_bfl_fixture.tensor()[:, indices]) - assert s1.meta.ids == ["first", "middle", "last"] + assert s1.meta.ids == ("first", "middle", "last") assert torch.equal(s1.meta.sos, torch.tensor([True, False, False])) assert torch.equal(s1.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s1.meta.lengths, torch.tensor([3, 3, 2])) @@ -944,7 +944,7 @@ def test_length_and_feature_indexing_2d_booltensor(self, stream_tensor_bfl_fixtu s2 = stream_tensor_bfl_fixture[:, indices] assert isinstance(s2, StreamTensor) assert torch.equal(s2.tensor(), stream_tensor_bfl_fixture.tensor()[:, indices]) - assert s2.meta.ids == ["first", "middle", "last"] + assert s2.meta.ids == ("first", "middle", "last") assert torch.equal(s2.meta.sos, torch.tensor([True, False, False])) assert torch.equal(s2.meta.eos, torch.tensor([False, False, False])) assert torch.equal(s2.meta.lengths, torch.tensor([1, 1, 1])) # changed to 1 @@ -994,7 +994,7 @@ def test_feature_indexing_2d_inttensor(self, stream_tensor_bfl_fixture): s1 = stream_tensor_bfl_fixture[:, indices] assert isinstance(s1, StreamTensor) assert torch.equal(s1.tensor(), stream_tensor_bfl_fixture.tensor()[:, indices]) - assert s1.meta.ids == ["first", "middle", "last"] + assert s1.meta.ids == ("first", "middle", "last") assert torch.equal(s1.meta.sos, torch.tensor([True, False, False])) assert torch.equal(s1.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s1.meta.lengths, torch.tensor([3, 3, 2])) @@ -1020,7 +1020,7 @@ def test_length_indexing_integer_multidimensional(self, stream_tensor_bfl_fixtur s1 = stream_tensor_bfl_fixture[:, 0, 0] # first length index and first feature index assert isinstance(s1, StreamTensor) assert torch.equal(s1.tensor(), stream_tensor_bfl_fixture.tensor()[:, 0, 0]) - assert s1.meta.ids == ["first", "middle", "last"] + assert s1.meta.ids == ("first", "middle", "last") assert torch.equal(s1.meta.sos, torch.tensor([True, False, False])) assert torch.equal(s1.meta.eos, torch.tensor([False, False, False])) # changed to False assert torch.equal(s1.meta.lengths, torch.tensor([1, 1, 1])) # changed to 1 @@ -1034,7 +1034,7 @@ def test_length_indexing_slice_multidimensional(self, stream_tensor_bfl_fixture) s1 = stream_tensor_bfl_fixture[:, 0, 1:] # remove first length index and first feature index assert isinstance(s1, StreamTensor) assert torch.equal(s1.tensor(), stream_tensor_bfl_fixture.tensor()[:, 0, 1:]) - assert s1.meta.ids == ["first", "middle", "last"] + assert s1.meta.ids == ("first", "middle", "last") assert torch.equal(s1.meta.sos, torch.tensor([False, False, False])) # changed to False assert torch.equal(s1.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s1.meta.lengths, torch.tensor([2, 2, 1])) # minus 1 for all @@ -1048,7 +1048,7 @@ def test_batch_and_length_indexing_slice(self, stream_tensor_bfl_fixture): stream_tensor_bfl_fixture, torch.Tensor.__getitem__, (slice(None, -1), slice(None), slice(1, None)), # [:-1, :, 1:] - ids=["first", "middle"], + ids=("first", "middle"), sos=torch.tensor([False, False]), eos=torch.tensor([False, False]), lengths=torch.tensor([2, 2]), @@ -1060,7 +1060,7 @@ def test_indexing_ellipsis(self, stream_tensor_bfl_fixture): s1_1 = stream_tensor_bfl_fixture[0, ...] # keep only first batch example assert isinstance(s1_1, StreamTensor) assert torch.equal(s1_1.tensor(), stream_tensor_bfl_fixture.tensor()[0, ...]) - assert s1_1.meta.ids == ["first"] + assert s1_1.meta.ids == ("first",) assert torch.equal(s1_1.meta.sos, torch.tensor([True])) assert torch.equal(s1_1.meta.eos, torch.tensor([False])) assert torch.equal(s1_1.meta.lengths, torch.tensor([3])) @@ -1069,7 +1069,7 @@ def test_indexing_ellipsis(self, stream_tensor_bfl_fixture): s1_2 = stream_tensor_bfl_fixture[0, ..., :] # keep only first batch example assert isinstance(s1_2, StreamTensor) assert torch.equal(s1_2.tensor(), stream_tensor_bfl_fixture.tensor()[0, ..., :]) - assert s1_2.meta.ids == ["first"] + assert s1_2.meta.ids == ("first",) assert torch.equal(s1_2.meta.sos, torch.tensor([True])) assert torch.equal(s1_2.meta.eos, torch.tensor([False])) assert torch.equal(s1_2.meta.lengths, torch.tensor([3])) @@ -1078,7 +1078,7 @@ def test_indexing_ellipsis(self, stream_tensor_bfl_fixture): s1_3 = stream_tensor_bfl_fixture[0, :, ...] # keep only first batch example assert isinstance(s1_3, StreamTensor) assert torch.equal(s1_1.tensor(), stream_tensor_bfl_fixture.tensor()[0, :, ...]) - assert s1_3.meta.ids == ["first"] + assert s1_3.meta.ids == ("first",) assert torch.equal(s1_3.meta.sos, torch.tensor([True])) assert torch.equal(s1_3.meta.eos, torch.tensor([False])) assert torch.equal(s1_3.meta.lengths, torch.tensor([3])) @@ -1087,7 +1087,7 @@ def test_indexing_ellipsis(self, stream_tensor_bfl_fixture): s1_4 = stream_tensor_bfl_fixture[0, ..., ...] # keep only first batch example assert isinstance(s1_4, StreamTensor) assert torch.equal(s1_1.tensor(), stream_tensor_bfl_fixture.tensor()[0, ..., ...]) - assert s1_4.meta.ids == ["first"] + assert s1_4.meta.ids == ("first",) assert torch.equal(s1_4.meta.sos, torch.tensor([True])) assert torch.equal(s1_4.meta.eos, torch.tensor([False])) assert torch.equal(s1_4.meta.lengths, torch.tensor([3])) @@ -1096,7 +1096,7 @@ def test_indexing_ellipsis(self, stream_tensor_bfl_fixture): s2 = stream_tensor_bfl_fixture[..., 0] # keep only first length index assert isinstance(s2, StreamTensor) assert torch.equal(s2.tensor(), stream_tensor_bfl_fixture.tensor()[..., 0]) - assert s2.meta.ids == ["first", "middle", "last"] + assert s2.meta.ids == ("first", "middle", "last") assert torch.equal(s2.meta.sos, torch.tensor([True, False, False])) assert torch.equal(s2.meta.eos, torch.tensor([False, False, False])) # changed to False assert torch.equal(s2.meta.lengths, torch.tensor([1, 1, 1])) # set to 1 @@ -1105,7 +1105,7 @@ def test_indexing_ellipsis(self, stream_tensor_bfl_fixture): s3 = stream_tensor_bfl_fixture[..., 0, ...] # keep only first feature dim assert isinstance(s3, StreamTensor) assert torch.equal(s3.tensor(), stream_tensor_bfl_fixture.tensor()[..., 0, ...]) - assert s3.meta.ids == ["first", "middle", "last"] + assert s3.meta.ids == ("first", "middle", "last") assert torch.equal(s3.meta.sos, torch.tensor([True, False, False])) assert torch.equal(s3.meta.eos, torch.tensor([False, False, True])) assert torch.equal(s3.meta.lengths, torch.tensor([3, 3, 2])) @@ -1342,3 +1342,25 @@ def test_squeeze_all_singleton_dims(self, stream_tensor_bfl_fixture, stream_meta torch.squeeze, **stream_meta_kwargs_fixture, ) + + +class TestUnbind: + def test_unbind_batch(self, stream_tensor_bfl_fixture): + batch_size = stream_tensor_bfl_fixture.size(0) + tensors = stream_tensor_bfl_fixture.unbind(dim=0) + assert len(tensors) == batch_size + for i in range(batch_size): + assert stream_tensor_bfl_fixture.meta.ids[i] == tensors[i].meta.ids[0] + assert stream_tensor_bfl_fixture.meta.sos[i] == tensors[i].meta.sos[0] + assert stream_tensor_bfl_fixture.meta.eos[i] == tensors[i].meta.eos[0] + assert stream_tensor_bfl_fixture.meta.lengths[i] == tensors[i].meta.lengths[0] + + def test_unbind_feature(self, stream_tensor_bfl_fixture): + tensors = stream_tensor_bfl_fixture.unbind(dim=1) + for tensor in tensors: + assert tensor.meta == stream_tensor_bfl_fixture.meta + + def test_unbind_length(self, stream_tensor_bfl_fixture): + # TODO (JDH): Implement this. + with pytest.raises(ValueError): + stream_tensor_bfl_fixture.unbind(dim=2)