diff --git a/src/json_stream/base.py b/src/json_stream/base.py
index 8e03bc8..6da4db5 100644
--- a/src/json_stream/base.py
+++ b/src/json_stream/base.py
@@ -1,8 +1,9 @@
 import collections
 import copy
 from abc import ABC
-from collections import OrderedDict
-from itertools import chain
+from collections import OrderedDict, deque
+from collections.abc import Mapping
+from itertools import chain, count as itertools_count, islice
 from typing import Optional, Iterator, Any
 
 from json_stream.tokenizer import TokenType
@@ -106,6 +107,14 @@ def __len__(self) -> int:
     def __repr__(self):  # pragma: no cover
         return f"<{type(self).__name__}: {repr(self._data)}, {'STREAMING' if self.streaming else 'DONE'}>"
 
+    def __contains__(self, item):
+        """Return True if *item* is in the container, reading the stream only if needed."""
+        # Data that is already loaded can answer without consuming more of the stream.
+        if item in self._data:
+            return True
+        self.read_all()
+        return item in self._data
+
 
 class TransientStreamingJSONBase(StreamingJSONBase, ABC):
     def __init__(self, token_stream):
@@ -158,6 +167,25 @@ def _load_item(self):
     def _get__iter__(self):
         return self._iter_items()
 
+    def index(self, item, start=0, stop=None):
+        """Return the position of the first item equal to *item*.
+
+        Positions are relative to what ``_iter_items()`` yields; *stop* is exclusive, as in
+        ``list.index``. Raises ValueError if nothing in ``[start, stop)`` matches.
+        """
+        # islice never pulls the item at position ``stop``, so the stream is not over-consumed.
+        for i, v in enumerate(islice(self._iter_items(), stop)):
+            if i >= start and (v is item or v == item):
+                return i
+        raise ValueError(f"{item!r} is not in list")
+
+    @staticmethod
+    def _index_args(*args):
+        """Turn optional ``(start, stop)`` values into positional arguments for ``list.index``."""
+        # Trailing Nones are dropped; a None before a real value becomes 0 to keep positions aligned.
+        last = max((i for i, a in enumerate(args) if a is not None), default=-1)
+        return tuple(0 if a is None else a for a in args[:last + 1])
+
 
 class PersistentStreamingJSONList(PersistentStreamingJSONBase, StreamingJSONList):
     def _init_persistent_data(self):
@@ -183,6 +211,37 @@ def __getitem__(self, k) -> Any:
             pass
         return self._find_item(k)
 
+    def index(self, item, /, start=None, stop=None):
+        """Return the index of the first occurrence of *item*, streaming more data only if needed.
+
+        *start* and *stop* are absolute positions with ``list.index`` semantics (*stop* exclusive).
+        Raises ValueError if *item* is not found in that range.
+        """
+        args = self._index_args(start, stop)
+        if any(a < 0 for a in args):
+            # Negative bounds count from the end, which is only known once everything is loaded.
+            self.read_all()
+            return self._data.index(item, *args)
+        try:
+            return self._data.index(item, *args)
+        except ValueError:
+            pass
+        # NOTE(review): assumes _iter_items() yields only items not yet in _data -- confirm.
+        loaded = len(self._data)  # captured before streaming, in case it grows _data
+        lo = max(args[0] - loaded, 0) if args else 0
+        hi = max(args[1] - loaded, 0) if len(args) > 1 else None
+        return loaded + super().index(item, lo, hi)
+
+    def count(self, item):
+        """Return the number of occurrences of *item*; reads the whole stream first."""
+        self.read_all()
+        return self._data.count(item)
+
+    def __reversed__(self):
+        """Return a reverse iterator over the items; reads the whole stream first."""
+        self.read_all()
+        return reversed(self._data)
+
 
 class TransientStreamingJSONList(TransientStreamingJSONBase, StreamingJSONList):
     def __init__(self, token_stream):
@@ -202,6 +261,38 @@ def _find_item(self, i):
                 return v
         raise IndexError(f"Index {i} out of range")
 
+    def index(self, item, /, start=None, stop=None):
+        """Return the absolute index of the next occurrence of *item* in the stream.
+
+        The search runs over whatever ``_iter_items()`` still yields. *start* and *stop* are
+        absolute positions (*stop* exclusive). Raises IndexError for negative bounds and
+        ValueError if *item* is not found.
+        """
+        if (start is not None and start < 0) or (stop is not None and stop < 0):
+            raise IndexError("Negative indices not supported for transient lists")
+        # NOTE(review): assumes self._index is the position of the last item read -- confirm.
+        consumed = self._index + 1  # captured before the search advances the stream
+        lo = max((start or 0) - consumed, 0)
+        hi = None if stop is None else max(stop - consumed, 0)
+        return consumed + super().index(item, lo, hi)
+
+    def count(self, item):
+        """Count the items equal to *item* by consuming the stream."""
+        self._check_started()
+        # equivalent to but faster than sum(1 for i in self if i is item or i == item)
+        counter = itertools_count()
+        # zip pulls from the filtered generator first, so counter advances once per match
+        deque(zip((i for i in self._iter_items() if i is item or i == item), counter), maxlen=0)  # (consume at C speed)
+        return next(counter)
+
+    def __reversed__(self):
+        """Yield the items in reverse order; the whole stream is buffered first."""
+        self._check_started()
+        # this approach releases memory as iterator advances
+        stack = deque(self._iter_items())
+        while stack:
+            yield stack.pop()
+
 
 class StreamingJSONObject(StreamingJSONBase, ABC):
     INCOMPLETE_ERROR = "Unterminated object at end of file"
@@ -275,6 +366,20 @@ def __getitem__(self, k) -> Any:
             pass
         return self._find_item(k)
 
+    def __eq__(self, other):
+        """Compare with any Mapping; the whole stream is read first."""
+        if not isinstance(other, Mapping):
+            return NotImplemented
+        self.read_all()
+        return self._data == other
+
+    def __ne__(self, other):
+        """Inverse of ``__eq__``; the whole stream is read first."""
+        if not isinstance(other, Mapping):
+            return NotImplemented
+        self.read_all()
+        return self._data != other
+
 
 class TransientStreamingJSONObject(TransientStreamingJSONBase, StreamingJSONObject):
     def _find_item(self, k):
@@ -299,3 +404,20 @@ def keys(self):
     def values(self):
         self._check_started()
         return (v for k, v in self._iter_items())
+
+    def __eq__(self, other):
+        """Compare with any Mapping while streaming, ignoring key order like ``dict`` equality.
+
+        Consumes the stream up to the first mismatch.
+        """
+        if not isinstance(other, Mapping):
+            return NotImplemented
+        missing = object()  # sentinel for keys that *other* does not have
+        matched = 0
+        for key, value in self.items():
+            expected = other.get(key, missing)
+            if expected is missing or not (value is expected or value == expected):
+                return False
+            matched += 1
+        # NOTE(review): a key repeated in the stream is counted twice and compares unequal.
+        return matched == len(other)