diff --git a/pyproject.toml b/pyproject.toml index 3baa10033..74a4c4014 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,7 @@ all = [ "python-dateutil", "python-jose[cryptography]", "python-multipart", + "ragged", "redis", "rich", "sparse >=0.15.5", @@ -126,6 +127,7 @@ client = [ "numpy", "pandas", "pyarrow >=14.0.1", # includes fix to CVE 2023-47248 + "ragged", "rich", "sparse >=0.15.5", "stamina", @@ -260,6 +262,7 @@ server = [ "python-dateutil", "python-jose[cryptography]", "python-multipart", + "ragged", "sparse >=0.15.5", "stamina", "redis", diff --git a/tiled/_tests/adapters/test_sql.py b/tiled/_tests/adapters/test_sql.py index 0d627004e..91b98c9cf 100644 --- a/tiled/_tests/adapters/test_sql.py +++ b/tiled/_tests/adapters/test_sql.py @@ -5,6 +5,7 @@ import pyarrow as pa import pytest +from tiled.adapters.array import ArrayAdapter from tiled.adapters.sql import ( COLUMN_NAME_PATTERN, TABLE_NAME_PATTERN, @@ -21,20 +22,25 @@ data0 = [ pa.array([1, 2, 3, 4, 5]), pa.array([1.0, 2.0, 3.0, 4.0, 5.0]), - pa.array(["foo0", "bar0", "baz0", None, "goo0"]), - pa.array([True, None, False, True, None]), + # pa.array(["foo0", "bar0", "baz0", None, "goo0"]), + # pa.array([True, None, False, True, None]), + pa.array(["foo0", "bar0", "baz0", "None", "goo0"]), + pa.array([True, bool(None), False, True, bool(None)]), ] data1 = [ pa.array([6, 7, 8, 9, 10, 11, 12]), pa.array([6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]), - pa.array(["foo1", "bar1", None, "baz1", "biz", None, "goo"]), - pa.array([None, True, True, False, False, None, True]), + # pa.array(["foo1", "bar1", None, "baz1", "biz", None, "goo"]), + # pa.array([None, True, True, False, False, None, True]), + pa.array(["foo1", "bar1", "None", "baz1", "biz", "None", "goo"]), + pa.array([bool(None), True, True, False, False, bool(None), True]), ] data2 = [ pa.array([13, 14]), pa.array([13.0, 14.0]), pa.array(["foo2", "baz2"]), - pa.array([False, None]), + # pa.array([False, None]), + pa.array([False, bool(None)]), ] batch0 = pa.record_batch(data0, names=names) @@ -797,3 +803,38 @@ def deep_array_equal(a1: Any, a2: Any) -> bool: assert deep_array_equal(result_part, result_full) storage.dispose() # Close all connections + + +@pytest.mark.parametrize( + "sql_adapter_name", + [ + "adapter_duckdb_many_partitions", + "adapter_psql_many_partitions", + "adapter_sqlite_many_partitions", + ], +) +@pytest.mark.parametrize("field", names) +def test_compare_field_data_from_array_adapter( + sql_adapter_name: str, + field: str, + request: pytest.FixtureRequest, +) -> None: + # get adapter from fixture + sql_adapter: SQLAdapter = request.getfixturevalue(sql_adapter_name) + + table = pa.Table.from_batches([batch0, batch1, batch2]) + sql_adapter.append_partition(0, table) + + array_adapter = sql_adapter[field] + assert isinstance(array_adapter, ArrayAdapter) + + result_read = array_adapter.read() + field_index = names.index(field) + assert np.array_equal( + [ + *data0[field_index].tolist(), + *data1[field_index].tolist(), + *data2[field_index].tolist(), + ], + result_read.tolist(), + ) diff --git a/tiled/_tests/adapters/test_sql_arrays.py b/tiled/_tests/adapters/test_sql_arrays.py index 109dc1640..c474f5d6d 100644 --- a/tiled/_tests/adapters/test_sql_arrays.py +++ b/tiled/_tests/adapters/test_sql_arrays.py @@ -1,5 +1,6 @@ -from typing import Callable, cast +from typing import Callable, Dict, Type, Union, cast +import awkward as ak import numpy as np import pyarrow as pa import pytest @@ -9,6 +10,9 @@ from tiled._tests.adapters.test_sql import 
adapter_psql_many_partitions # noqa: F401 from tiled._tests.adapters.test_sql import adapter_psql_one_partition # noqa: F401 from tiled._tests.adapters.test_sql import assert_same_rows +from tiled.adapters.array import ArrayAdapter +from tiled.adapters.awkward import AwkwardAdapter +from tiled.adapters.ragged import RaggedAdapter from tiled.adapters.sql import SQLAdapter from tiled.storage import SQLStorage, parse_storage, register_storage from tiled.structures.core import StructureFamily @@ -17,57 +21,29 @@ rng = np.random.default_rng(42) -names = ["i0", "i1", "i2", "i3", "f4", "f5"] +names_adapters: Dict[str, Type[Union[ArrayAdapter, AwkwardAdapter, RaggedAdapter]]] = { + "integers": ArrayAdapter, + "floats": ArrayAdapter, + "ragged_floats": RaggedAdapter, +} +names = list(names_adapters.keys()) batch_size = 5 data0 = [ - pa.array( - [rng.integers(-100, 100, size=10, dtype=np.int8) for _ in range(batch_size)] - ), - pa.array( - [rng.integers(-100, 100, size=11, dtype=np.int16) for _ in range(batch_size)] - ), - pa.array( - [rng.integers(-100, 100, size=12, dtype=np.int32) for _ in range(batch_size)] - ), - pa.array( - [rng.integers(-100, 100, size=13, dtype=np.int64) for _ in range(batch_size)] - ), - pa.array([rng.random(size=14, dtype=np.float32) for _ in range(batch_size)]), - pa.array([rng.random(size=15, dtype=np.float64) for _ in range(batch_size)]), + pa.array([rng.integers(-100, 100, size=10) for _ in range(batch_size)]), + pa.array([rng.random(size=15) for _ in range(batch_size)]), + pa.array([rng.random(size=rng.integers(1, 10)) for _ in range(batch_size)]), ] batch_size = 8 data1 = [ - pa.array( - [rng.integers(-100, 100, size=10, dtype=np.int8) for _ in range(batch_size)] - ), - pa.array( - [rng.integers(-100, 100, size=11, dtype=np.int16) for _ in range(batch_size)] - ), - pa.array( - [rng.integers(-100, 100, size=12, dtype=np.int32) for _ in range(batch_size)] - ), - pa.array( - [rng.integers(-100, 100, size=13, dtype=np.int64) for _ in range(batch_size)] - ), - pa.array([rng.random(size=14, dtype=np.float32) for _ in range(batch_size)]), - pa.array([rng.random(size=15, dtype=np.float64) for _ in range(batch_size)]), + pa.array([rng.integers(-100, 100, size=10) for _ in range(batch_size)]), + pa.array([rng.random(size=15) for _ in range(batch_size)]), + pa.array([rng.random(size=rng.integers(1, 10)) for _ in range(batch_size)]), ] batch_size = 3 data2 = [ - pa.array( - [rng.integers(-100, 100, size=10, dtype=np.int8) for _ in range(batch_size)] - ), - pa.array( - [rng.integers(-100, 100, size=11, dtype=np.int16) for _ in range(batch_size)] - ), - pa.array( - [rng.integers(-100, 100, size=12, dtype=np.int32) for _ in range(batch_size)] - ), - pa.array( - [rng.integers(-100, 100, size=13, dtype=np.int64) for _ in range(batch_size)] - ), - pa.array([rng.random(size=14, dtype=np.float32) for _ in range(batch_size)]), - pa.array([rng.random(size=15, dtype=np.float64) for _ in range(batch_size)]), + pa.array([rng.integers(-100, 100, size=10) for _ in range(batch_size)]), + pa.array([rng.random(size=15) for _ in range(batch_size)]), + pa.array([rng.random(size=rng.integers(1, 10)) for _ in range(batch_size)]), ] batch0 = pa.record_batch(data0, names=names) @@ -90,7 +66,7 @@ def _data_source_from_init_storage( assets=[], ) - storage = cast(SQLStorage, parse_storage(data_uri)) + storage = cast("SQLStorage", parse_storage(data_uri)) register_storage(storage) return SQLAdapter.init_storage(data_source=data_source, storage=storage) @@ -240,17 +216,53 @@ def 
test_write_read_one_batch_many_part( # read a specific field result_read = adapter.read_partition(0, fields=[field]) field_index = names.index(field) - assert np.array_equal( + assert ak.array_equal( [*data0[field_index].tolist(), *data2[field_index].tolist()], result_read[field].tolist(), ) result_read = adapter.read_partition(1, fields=[field]) - assert np.array_equal( + assert ak.array_equal( [*data1[field_index].tolist(), *data0[field_index].tolist()], result_read[field].tolist(), ) result_read = adapter.read_partition(2, fields=[field]) - assert np.array_equal( + assert ak.array_equal( [*data2[field_index].tolist(), *data1[field_index].tolist()], result_read[field].tolist(), ) + + +@pytest.mark.parametrize( + "sql_adapter_name", + [("adapter_duckdb_many_partitions"), ("adapter_psql_many_partitions")], +) +@pytest.mark.parametrize(("field", "array_adapter_type"), [*names_adapters.items()]) +def test_compare_field_data_from_array_adapter( + sql_adapter_name: str, + field: str, + array_adapter_type: type, + request: pytest.FixtureRequest, +) -> None: + # get adapter from fixture + sql_adapter: SQLAdapter = request.getfixturevalue(sql_adapter_name) + + table = pa.Table.from_batches([batch0, batch1, batch2]) + sql_adapter.append_partition(0, table) + + array_adapter = sql_adapter[field] + assert isinstance(array_adapter, array_adapter_type) + + field_index = names.index(field) + if isinstance(array_adapter, AwkwardAdapter): + result_read = array_adapter.read() # smoke test + raise NotImplementedError + else: + result_read = array_adapter.read() + assert ak.array_equal( + [ + *data0[field_index].tolist(), + *data1[field_index].tolist(), + *data2[field_index].tolist(), + ], + result_read.tolist(), # type: ignore[attr-defined] + ) diff --git a/tiled/_tests/adapters/test_sql_types.py b/tiled/_tests/adapters/test_sql_types.py index e9bc0923e..18e435147 100644 --- a/tiled/_tests/adapters/test_sql_types.py +++ b/tiled/_tests/adapters/test_sql_types.py @@ -206,6 +206,25 @@ def duckdb_uri(tmp_path: Path) -> Generator[str, None, None]: "duckdb": (["DECIMAL(5, 2) NULL"], pa.schema([("x", pa.decimal128(5, 2))])), }, ), + "ragged_lists": ( + pa.Table.from_arrays( + [ + pa.array([[1], [2, 3], [4, 5, 6]], pa.list_(pa.int32())), + pa.array([[1.1, 2.2, 3.3], [4.4, 5.5], [6.6]], pa.list_(pa.float32())), + ], + names=["x", "y"], + ), + { + "duckdb": ( + ["INTEGER[] NULL", "REAL[] NULL"], + pa.schema([("x", pa.list_(pa.int32())), ("y", pa.list_(pa.float32()))]), + ), + "postgresql": ( + ["INTEGER ARRAY NULL", "REAL ARRAY NULL"], + pa.schema([("x", pa.list_(pa.int32())), ("y", pa.list_(pa.float32()))]), + ), + }, + ), } diff --git a/tiled/_tests/test_ragged.py b/tiled/_tests/test_ragged.py new file mode 100644 index 000000000..c628f23b5 --- /dev/null +++ b/tiled/_tests/test_ragged.py @@ -0,0 +1,194 @@ +import awkward as ak +import numpy as np +import pyarrow.feather +import pyarrow.parquet +import pytest +import ragged + +from tiled.catalog import in_memory +from tiled.client import Context, from_context, record_history +from tiled.serialization.ragged import ( + from_flattened_array, + from_flattened_octet_stream, + from_json, + to_flattened_array, + to_flattened_octet_stream, + to_json, +) +from tiled.server.app import build_app +from tiled.structures.ragged import RaggedStructure +from tiled.utils import APACHE_ARROW_FILE_MIME_TYPE + + +@pytest.fixture +def catalog(tmpdir): + catalog = in_memory(writable_storage=str(tmpdir)) + yield catalog + + +@pytest.fixture +def app(catalog): + app = build_app(catalog) 
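+    # build_app wraps the catalog adapter in Tiled's ASGI server application.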
+ yield app + + +@pytest.fixture +def context(app): + with Context.from_app(app) as context: + yield context + + +@pytest.fixture +def client(context): + client = from_context(context) + yield client + + +RNG = np.random.default_rng(42) + +arrays = { + # "empty_1d": ragged.array([]), + # "empty_nd": ragged.array([[], [], []]), + "numpy_1d": ragged.array(RNG.random(10)), + "numpy_nd": ragged.array(RNG.random((2, 3, 4))), + "ragged_simple": ragged.array( + [RNG.random(3).tolist(), RNG.random(5).tolist(), RNG.random(8).tolist()], + ), + "ragged_simple_nd": ragged.array( + [RNG.random((2, 3, 4)).tolist(), RNG.random((3, 4, 5)).tolist()], + ), + "ragged_complex": ragged.array( + [ + [RNG.random(10).tolist()], + [RNG.random(8).tolist(), []], + [RNG.random(5).tolist(), RNG.random(2).tolist()], + [[], RNG.random(7).tolist()], + ], + ), + "ragged_complex_nd": ragged.array( + [ + [RNG.random((4, 3)).tolist()], + [RNG.random((2, 8)).tolist(), [[]]], + [RNG.random((5, 2)).tolist(), RNG.random((3, 3)).tolist()], + [[[]], RNG.random((7, 1)).tolist()], + ], + ), +} + + +@pytest.mark.parametrize("name", arrays.keys()) +def test_structure(name): + array = arrays[name] + expected_form, expected_len, expected_nodes = ak.to_buffers( + array._impl, # noqa: SLF001 + ) + + structure = RaggedStructure.from_array(array) + form = ak.forms.from_dict(structure.form) + + assert expected_form == form + assert expected_len == structure.shape[0] + assert len(expected_nodes) == len(structure.offsets) + 1 + + +@pytest.mark.parametrize("name", arrays.keys()) +def test_serialization_roundtrip(name): + array = arrays[name] + structure = RaggedStructure.from_array(array) + + # Test JSON serialization. + json_contents = to_json("application/json", array, metadata={}) + array_from_json = from_json( + json_contents, + dtype=array.dtype.type, + offsets=structure.offsets, + shape=structure.shape, + ) + assert ak.array_equal(array._impl, array_from_json._impl) # noqa: SLF001 + + # Test flattened numpy array. + flattened_array = to_flattened_array(array) + array_from_flattened = from_flattened_array( + flattened_array, + dtype=array.dtype.type, + offsets=structure.offsets, + shape=structure.shape, + ) + assert ak.array_equal(array._impl, array_from_flattened._impl) # noqa: SLF001 + + # Test flattened octet-stream serialization. + octet_stream_contents = to_flattened_octet_stream( + "application/octet-stream", array, metadata={} + ) + array_from_octet_stream = from_flattened_octet_stream( + octet_stream_contents, + dtype=array.dtype.type, + offsets=structure.offsets, + shape=structure.shape, + ) + assert ak.array_equal(array._impl, array_from_octet_stream._impl) # noqa: SLF001 + + +@pytest.mark.parametrize("name", arrays.keys()) +def test_slicing(client, name): + # Write data into catalog. + array = arrays[name] + returned = client.write_ragged(array, key="test") + # Test with client returned, and with client from lookup. + for rac in [returned, client["test"]]: + # Read the data back out from the RaggedClient, progressively sliced. + result = rac.read() + # ragged does not have an array_equal(a, b) equivalent. Use awkward. + assert ak.array_equal(result._impl, array._impl) # noqa: SLF001 + + # When sliced, the server sends less data. 
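+        # record_history captures the HTTP requests/responses so that the number
+        # of requests (and, eventually, payload sizes) can be asserted below.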
+ with record_history() as h: + full_result = rac[:] + assert ak.array_equal(full_result._impl, array._impl) # noqa: SLF001 + assert len(h.responses) == 1 # sanity check + # full_response_size = len(h.responses[0].content) + # with record_history() as h: + # sliced_result = rac[1] + # assert ak.array_equal(sliced_result._impl, array[1]._impl) # noqa: SLF001 + # assert len(h.responses) == 1 # sanity check + # sliced_response_size = len(h.responses[0].content) + # assert sliced_response_size < full_response_size + + +@pytest.mark.parametrize("name", arrays.keys()) +def test_export_json(tmpdir, client, name): + array = arrays[name] + rac = client.write_ragged(array, key="test") + + filepath = tmpdir / "actual.json" + rac.export(str(filepath), format="application/json") + actual = filepath.read_text(encoding="utf-8") + assert actual == ak.to_json(array._impl) # noqa: SLF001 + + +@pytest.mark.parametrize("name", arrays.keys()) +def test_export_arrow(tmpdir, client, name): + # Write data into catalog. It will be stored as directory of buffers + # named like 'node0-offsets' and 'node2-data'. + array = arrays[name] + rac = client.write_ragged(array, key="test") + + filepath = tmpdir / "actual.arrow" + rac.export(str(filepath), format=APACHE_ARROW_FILE_MIME_TYPE) + actual = pyarrow.feather.read_table(filepath) + expected = ak.to_arrow_table(array._impl) # noqa: SLF001 + assert actual == expected + + +@pytest.mark.parametrize("name", arrays.keys()) +def test_export_parquet(tmpdir, client, name): + # Write data into catalog. It will be stored as directory of buffers + # named like 'node0-offsets' and 'node2-data'. + array = arrays[name] + rac = client.write_ragged(array, key="test") + + filepath = tmpdir / "actual.parquet" + rac.export(str(filepath), format="application/x-parquet") + actual = pyarrow.parquet.read_table(filepath) + expected = ak.to_arrow_table(array._impl) # noqa: SLF001 + assert actual == expected diff --git a/tiled/adapters/array.py b/tiled/adapters/array.py index 3e74c2d01..08b88c019 100644 --- a/tiled/adapters/array.py +++ b/tiled/adapters/array.py @@ -1,13 +1,13 @@ import contextlib -from typing import Any, List, Optional, Tuple +from collections.abc import Sequence +from typing import Any, List, Optional, Tuple, cast import dask.array import numpy import pandas from numpy.typing import NDArray -from tiled.adapters.core import Adapter - +from ..adapters.core import Adapter from ..ndslice import NDSlice from ..structures.array import ArrayStructure from ..structures.core import Spec, StructureFamily @@ -66,7 +66,7 @@ def from_array( if is_array_of_arrays: with contextlib.suppress(ValueError): # only uniform arrays (with same dimensions) are stackable - array = numpy.vstack(array) + array = numpy.vstack(cast("Sequence[numpy.ndarray]", array)) # Convert (experimental) pandas.StringDtype to numpy's unicode string dtype is_likely_string_dtype = isinstance(array.dtype, pandas.StringDtype) or ( diff --git a/tiled/adapters/protocols.py b/tiled/adapters/protocols.py index fd7028f2c..980844dec 100644 --- a/tiled/adapters/protocols.py +++ b/tiled/adapters/protocols.py @@ -4,9 +4,12 @@ import dask.dataframe import pandas +import ragged import sparse from numpy.typing import NDArray +from tiled.structures.ragged import RaggedStructure + from ..ndslice import NDSlice from ..storage import Storage from ..structures.array import ArrayStructure @@ -82,6 +85,22 @@ def write(self, container: DirectoryContainer) -> None: pass +class RaggedAdapter(BaseAdapter, Protocol): + structure_family: 
Literal[StructureFamily.ragged] = StructureFamily.ragged
+
+    @abstractmethod
+    def structure(self) -> RaggedStructure:
+        pass
+
+    @abstractmethod
+    def read(self, slice: NDSlice) -> ragged.array:
+        pass
+
+    @abstractmethod
+    def read_block(self, block: Tuple[int, ...]) -> ragged.array:
+        pass
+
+
 class SparseAdapter(BaseAdapter, Protocol):
     structure_family: Literal[StructureFamily.sparse] = StructureFamily.sparse
 
@@ -127,5 +146,10 @@ def __getitem__(self, key: str) -> ArrayAdapter:
 
 
 AnyAdapter = Union[
-    ArrayAdapter, AwkwardAdapter, ContainerAdapter, SparseAdapter, TableAdapter
+    ArrayAdapter,
+    AwkwardAdapter,
+    ContainerAdapter,
+    RaggedAdapter,
+    SparseAdapter,
+    TableAdapter,
 ]
diff --git a/tiled/adapters/ragged.py b/tiled/adapters/ragged.py
new file mode 100644
index 000000000..400c6bd39
--- /dev/null
+++ b/tiled/adapters/ragged.py
@@ -0,0 +1,130 @@
+from __future__ import annotations
+
+import sys
+from typing import TYPE_CHECKING, Any
+
+if sys.version_info >= (3, 11):
+    from typing import Self
+else:
+    from typing_extensions import Self
+
+import numpy as np
+import ragged
+
+from tiled.adapters.core import Adapter
+from tiled.adapters.utils import init_adapter_from_catalog
+from tiled.catalog.orm import Node
+from tiled.ndslice import NDSlice
+from tiled.structures.core import Spec, StructureFamily
+from tiled.structures.data_source import DataSource
+from tiled.structures.ragged import RaggedStructure
+
+if TYPE_CHECKING:
+    from collections.abc import Iterable
+
+    import awkward
+    from numpy.typing import NDArray
+
+    from tiled.type_aliases import JSON
+
+
+class RaggedAdapter(Adapter[RaggedStructure]):
+    structure_family = StructureFamily.ragged
+
+    def __init__(
+        self,
+        array: ragged.array | None,
+        structure: RaggedStructure,
+        metadata: JSON | None = None,
+        specs: list[Spec] | None = None,
+    ) -> None:
+        """
+
+        Parameters
+        ----------
+        array : ragged.array or None
+            The in-memory array to serve. None when a subclass loads data lazily.
+        structure : RaggedStructure
+            Structure describing the dtype, shape, and per-dimension offsets.
+        metadata : JSON, optional
+            User metadata to attach to this node.
+        specs : list[Spec], optional
+            Specs labeling this node.
+        """
+        self._array = array
+        self._structure = structure
+        self._metadata = metadata or {}
+        self.specs = list(specs or [])
+
+    @classmethod
+    def from_catalog(
+        cls,
+        data_source: DataSource[RaggedStructure],
+        node: Node,
+        /,
+        **kwargs: Any | None,
+    ) -> Self:
+        return init_adapter_from_catalog(cls, data_source, node, **kwargs)
+
+    @classmethod
+    def from_array(
+        cls,
+        array: ragged.array | awkward.Array | NDArray[Any] | Iterable[Iterable[Any]],
+        metadata: JSON | None = None,
+        specs: list[Spec] | None = None,
+    ) -> Self:
+        """
+
+        Parameters
+        ----------
+        array : ragged.array, awkward.Array, ndarray, or nested iterable
+            Array-like data to coerce to a ragged.array.
+        metadata : JSON, optional
+            User metadata to attach to this node.
+        specs : list[Spec], optional
+            Specs labeling this node.
+
+        Returns
+        -------
+        A RaggedAdapter wrapping the coerced array.
+        """
+        # An object-dtype ndarray of arrays is routed through a Python list so
+        # that ragged can infer the ragged dimensions.
+        array = (
+            ragged.array(list(array))
+            if isinstance(array, np.ndarray)
+            else ragged.asarray(array)
+        )
+
+        structure = RaggedStructure.from_array(array)
+        return cls(
+            array,
+            structure,
+            metadata=metadata,
+            specs=specs,
+        )
+
+    def read(
+        self,
+        slice: NDSlice = NDSlice(...),
+    ) -> ragged.array:
+        """
+
+        Parameters
+        ----------
+        slice : NDSlice
+            Selection to apply; defaults to the whole array.
+
+        Returns
+        -------
+        The requested ragged.array, sliced if a selection was given.
+        """
+        if self._array is None:
+            raise NotImplementedError
+        # _array[...] requires an actual tuple, not just a subclass of tuple
+        return self._array[tuple(slice)] if slice else self._array
+
+    def write(
+        self,
+        array: ragged.array,
+    ) -> None:
+        raise NotImplementedError
+
+    def metadata(self) -> JSON:
+        return self._metadata
+
+    def __repr__(self) -> str:
+        return f"{type(self).__name__}({self._structure})"
+
+    def structure(self) -> RaggedStructure:
+        return self._structure
diff --git a/tiled/adapters/ragged_npy_store.py b/tiled/adapters/ragged_npy_store.py
new file mode 100644
index 000000000..4827fa98c
--- /dev/null
+++ b/tiled/adapters/ragged_npy_store.py
@@ -0,0 +1,95 @@
+from __future__ import annotations
+
+import copy
+from urllib.parse import quote_plus
+
+import numpy
+import ragged
+
+from tiled.adapters.ragged import RaggedAdapter
+from tiled.adapters.resource_cache import with_resource_cache
+from tiled.ndslice import NDSlice
+from tiled.serialization.ragged import from_flattened_array, to_flattened_array
+from tiled.storage import FileStorage, Storage
+from tiled.structures.core import Spec
+from tiled.structures.data_source import Asset, DataSource
+from tiled.structures.ragged import RaggedStructure
+from tiled.type_aliases import JSON
+from tiled.utils import path_from_uri
+
+
+class RaggedNPYAdapter(RaggedAdapter):
+    def __init__(
+        self,
+        data_uri: str,
+        structure: RaggedStructure,
+        metadata: JSON | None = None,
+        specs: list[Spec] | None = None,
+    ) -> None:
+        super().__init__(None, structure, metadata, specs)
+        self._filepath = path_from_uri(data_uri)
+
+    @classmethod
+    def supported_storage(cls) -> set[type[Storage]]:
+        return {FileStorage}
+
+    @classmethod
+    def init_storage(
+        cls,
+        storage: Storage,
+        data_source: DataSource[RaggedStructure],
+        path_parts: list[str],
+    ) -> DataSource[RaggedStructure]:
+        """
+
+        Parameters
+        ----------
+        storage : Storage
+            The file-based storage under which to allocate a directory.
+        data_source : DataSource[RaggedStructure]
+            The data source to which the new Asset is attached.
+        path_parts : list[str]
+            Path segments locating this node under the storage root.
+
+        Returns
+        -------
+        A copy of data_source with an Asset pointing at the new .npy file.
+        """
+        data_source = copy.deepcopy(data_source)  # Do not mutate caller input.
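+        # Allocate a directory for this node under the storage root, URL-quoting
+        # each path segment; the flattened data lives in one .npy file inside it.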
+        data_uri = storage.uri + "".join(
+            f"/{quote_plus(segment)}" for segment in path_parts
+        )
+        directory = path_from_uri(data_uri)
+        directory.mkdir(parents=True, exist_ok=True)
+        data_source.assets.append(
+            Asset(
+                data_uri=f"{data_uri}/ragged-data.npy",
+                is_directory=False,
+                parameter="data_uri",
+            ),
+        )
+        return data_source
+
+    def read(self, slice: NDSlice = NDSlice(...)) -> ragged.array:
+        cache_key = (numpy.load, self._filepath)
+        data = with_resource_cache(cache_key, numpy.load, self._filepath)
+        array = from_flattened_array(
+            data,
+            self._structure.data_type.to_numpy_dtype(),
+            self._structure.offsets,
+            self._structure.shape,
+        )
+        # As in RaggedAdapter.read, indexing requires an actual tuple, not an
+        # NDSlice (a tuple subclass).
+        return array[tuple(slice)] if slice else array
+
+    def write(
+        self,
+        array: ragged.array,
+    ) -> None:
+        """
+
+        Parameters
+        ----------
+        array : ragged.array
+            The array to flatten and persist as a .npy file.
+        """
+        data = to_flattened_array(array)
+        numpy.save(self._filepath, data)
diff --git a/tiled/adapters/sql.py b/tiled/adapters/sql.py
index d2b08280e..bb52940e6 100644
--- a/tiled/adapters/sql.py
+++ b/tiled/adapters/sql.py
@@ -1,9 +1,21 @@
 import copy
 import hashlib
+import logging
 import re
 from collections.abc import Set
 from contextlib import closing
-from typing import Any, Callable, Iterator, List, Literal, Optional, Tuple, Union, cast
+from typing import (
+    Any,
+    Callable,
+    Iterator,
+    List,
+    Literal,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+    cast,
+)
 
 import numpy
 import pandas
@@ -25,6 +37,8 @@
 from ..structures.table import TableStructure
 from ..type_aliases import JSON
 from .array import ArrayAdapter
+from .awkward import AwkwardAdapter
+from .ragged import RaggedAdapter
 from .utils import init_adapter_from_catalog
 
 DIALECTS = Literal["postgresql", "sqlite", "duckdb"]
@@ -56,6 +70,8 @@
 # e.g. "A" and "a" will raise an error.
 # Furthermore, user-specified table names can only be in lower case.
 
+logger = logging.getLogger(__name__)
+
 
 class SQLAdapter(Adapter[TableStructure]):
     """SQLAdapter Class
 
@@ -216,7 +232,7 @@ def structure(self) -> TableStructure:
         """
         return self._structure
 
-    def get(self, key: str) -> Union[ArrayAdapter, None]:
+    def get(self, key: str) -> Union[ArrayAdapter, AwkwardAdapter, RaggedAdapter, None]:
         """Get the data for a specific key
 
         Parameters
         ----------
@@ -231,7 +247,9 @@ def get(self, key: str) -> Union[ArrayAdapter, None]:
             return None
         return self[key]
 
-    def __getitem__(self, key: str) -> ArrayAdapter:
+    def __getitem__(
+        self, key: str
+    ) -> Union[ArrayAdapter, AwkwardAdapter, RaggedAdapter]:
         """Get the data for a specific key.
 
         Parameters
         ----------
@@ -244,9 +262,46 @@
         """
         # Must compute to determine shape.
-        return ArrayAdapter.from_array(self.read([key])[key].values)
+        array = self.read([key])[key].infer_objects().to_numpy()
+        if array.dtype.name != "object":
+            return ArrayAdapter.from_array(array)
+
+        if (
+            array.dtype.name == "object"
+            and len(array)
+            and isinstance(array[0], numpy.ndarray)
+        ):
+            # accumulate errors until an attempt succeeds
+            errors: List[Exception] = []
+
+            # Uniform inner arrays stack into a regular N-D array.
+            try:
+                array = numpy.vstack(cast("Sequence[numpy.ndarray]", array))
+                return ArrayAdapter.from_array(array)
+            except ValueError as err:
+                errors.append(err)
+
+            # Non-uniform inner arrays: try a ragged representation next.
+            try:
+                return RaggedAdapter.from_array(array)
+            except Exception as err:
+                errors.append(err)
+
+            try:
+                return AwkwardAdapter.from_array(array)
+            except Exception as err:
+                errors.append(err)
+
+            logger.error(
+                "No adapter found that accepts object-array at key %s. 
(%s)", + key, + errors, + ) + # fallback to string representation conversion in ArrayAdapter - def items(self) -> Iterator[Tuple[str, ArrayAdapter]]: + return ArrayAdapter.from_array(array) + + def items( + self, + ) -> Iterator[Tuple[str, Union[ArrayAdapter, AwkwardAdapter, RaggedAdapter]]]: """Iterate over the SQLAdapter data. Returns diff --git a/tiled/catalog/adapter.py b/tiled/catalog/adapter.py index abb6aa465..85c8d358e 100644 --- a/tiled/catalog/adapter.py +++ b/tiled/catalog/adapter.py @@ -63,6 +63,7 @@ AWKWARD_BUFFERS_MIMETYPE, DEFAULT_ADAPTERS_BY_MIMETYPE, PARQUET_MIMETYPE, + RAGGED_MIMETYPE, SPARSE_BLOCKS_PARQUET_MIMETYPE, TILED_SQL_TABLE_MIMETYPE, ZARR_MIMETYPE, @@ -110,6 +111,7 @@ StructureFamily.array: ZARR_MIMETYPE, StructureFamily.awkward: AWKWARD_BUFFERS_MIMETYPE, StructureFamily.table: PARQUET_MIMETYPE, + StructureFamily.ragged: RAGGED_MIMETYPE, StructureFamily.sparse: SPARSE_BLOCKS_PARQUET_MIMETYPE, } @@ -137,6 +139,9 @@ TILED_SQL_TABLE_MIMETYPE: lambda: importlib.import_module( "...adapters.sql", __name__ ).SQLAdapter, + RAGGED_MIMETYPE: lambda: importlib.import_module( + "...adapters.ragged_npy_store", __name__ + ).RaggedNPYAdapter, } ) @@ -1206,6 +1211,11 @@ async def write(self, media_type, deserializer, entry, body, persist=True): data = await ensure_awaitable(deserializer, body, dtype, shape) elif entry.structure_family == "sparse": data = await ensure_awaitable(deserializer, body) + elif entry.structure_family == "ragged": + dtype = entry.structure().data_type.to_numpy_dtype() + offsets = entry.structure().offsets + shape = entry.structure().shape + data = await ensure_awaitable(deserializer, body, dtype, offsets, shape) else: raise NotImplementedError(entry.structure_family) return await ensure_awaitable((await self.get_adapter()).write, data) @@ -1227,6 +1237,11 @@ async def write_block( data = await ensure_awaitable(deserializer, body, dtype, shape) elif entry.structure_family == "sparse": data = await ensure_awaitable(deserializer, body) + elif entry.structure_family == "ragged": + dtype = entry.structure().data_type.to_numpy_dtype() + offsets = entry.structure().offsets + shape = entry.structure().shape + data = await ensure_awaitable(deserializer, body, dtype, offsets, shape) else: raise NotImplementedError(entry.structure_family) return await ensure_awaitable( @@ -1286,6 +1301,10 @@ async def write(self, *args, **kwargs): return await ensure_awaitable((await self.get_adapter()).write, *args, **kwargs) +class CatalogRaggedAdapter(CatalogArrayAdapter): + pass + + class CatalogSparseAdapter(CatalogArrayAdapter): pass @@ -1859,6 +1878,7 @@ def node_from_segments(segments, root_id=0): StructureFamily.array: CatalogArrayAdapter, StructureFamily.awkward: CatalogAwkwardAdapter, StructureFamily.container: CatalogContainerAdapter, + StructureFamily.ragged: CatalogRaggedAdapter, StructureFamily.sparse: CatalogSparseAdapter, StructureFamily.table: CatalogTableAdapter, } diff --git a/tiled/client/container.py b/tiled/client/container.py index 6c19aba39..e19336beb 100644 --- a/tiled/client/container.py +++ b/tiled/client/container.py @@ -973,6 +973,64 @@ def write_awkward( client.write(container) return client + def write_ragged( + self, + array, + *, + key=None, + metadata=None, + dims=None, + specs=None, + access_tags=None, + ): + import ragged + + from tiled.structures.ragged import RaggedStructure + + if not (hasattr(array, "shape") and hasattr(array, "dtype")): + # This does not implement enough of the array-like interface. + # Coerce to numpy-like ragged array. 
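+            # Preserve the source dtype when the input exposes one, so the
+            # recorded structure matches the uploaded data.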
+ array = ( + ragged.array(array, dtype=array.dtype) + if hasattr(array, "dtype") + else ragged.array(array) + ) + + # TODO + from dask.array.core import normalize_chunks + + if hasattr(array, "chunks"): + chunks = normalize_chunks( + array.chunks, + limit=self._SUGGESTED_MAX_UPLOAD_SIZE, + dtype=array.dtype, + shape=None, + ) + else: + chunks = normalize_chunks( + tuple("auto" for _ in array.shape), + limit=self._SUGGESTED_MAX_UPLOAD_SIZE, + dtype=array.dtype, + shape=tuple(d if d is not None else array.size for d in array.shape), + ) + + structure = RaggedStructure.from_array(array, chunks=chunks, dims=dims) + + client = self.new( + StructureFamily.ragged, + [ + DataSource( + structure=structure, structure_family=StructureFamily.ragged + ), + ], + key=key, + metadata=metadata, + specs=specs, + access_tags=access_tags, + ) + client.write(array) + return client + def write_sparse( self, coords, @@ -1298,6 +1356,7 @@ def _write_partition(x, partition_info, client): "dataframe": _LazyLoad( ("..dataframe", Container.__module__), "DataFrameClient" ), + "ragged": _LazyLoad(("..ragged", Container.__module__), "RaggedClient"), "sparse": _LazyLoad(("..sparse", Container.__module__), "SparseClient"), "table": _LazyLoad( ("..dataframe", Container.__module__), "DataFrameClient" diff --git a/tiled/client/ragged.py b/tiled/client/ragged.py new file mode 100644 index 000000000..a92d806fb --- /dev/null +++ b/tiled/client/ragged.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, cast +from urllib.parse import parse_qs, urlparse + +import ragged + +from tiled.client.base import BaseClient +from tiled.client.utils import chunks_repr, export_util, handle_error, retry_context +from tiled.ndslice import NDSlice +from tiled.serialization.ragged import ( + from_flattened_octet_stream, + to_flattened_octet_stream, +) + +if TYPE_CHECKING: + import awkward as ak + + from tiled.structures.ragged import RaggedStructure + + +class RaggedClient(BaseClient): + def write(self, array: ragged.array | ak.Array | list[list]): + array = ( + ragged.array(array, dtype=array.dtype) + if hasattr(array, "dtype") + else ragged.array(array) + ) + mimetype = "application/octet-stream" + for attempt in retry_context(): + with attempt: + handle_error( + self.context.http_client.put( + self.item["links"]["full"], + content=to_flattened_octet_stream( + mimetype=mimetype, + array=array, + metadata={}, + ), + headers={"Content-Type": mimetype}, + ), + ) + + def write_block(self, block: int, array: ragged.array | ak.Array | list[list]): + # TODO: investigate + raise NotImplementedError + + def read(self, slice: NDSlice | None = None) -> ragged.array: + structure = cast("RaggedStructure", self.structure()) + url_path = self.item["links"]["full"] + url_params: dict[str, Any] = {**parse_qs(urlparse(url_path).query)} + if isinstance(slice, NDSlice): + url_params["slice"] = slice.to_numpy_str() + for attempt in retry_context(): + with attempt: + content = handle_error( + self.context.http_client.get( + url_path, + headers={"Accept": "application/octet-stream"}, + params=url_params, + ), + ).read() + # shape = ( + # reshape_from_slice(structure.shape, slice) + # if isinstance(slice, NDSlice) + # else structure.shape + # ) + shape = structure.shape + return from_flattened_octet_stream( + buffer=content, + dtype=structure.data_type.to_numpy_dtype(), + offsets=structure.offsets, + shape=shape, + ) + + def read_block(self, block: int, slice: NDSlice | None = None) -> ragged.array: + # TODO: investigate 
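+        # (Blockwise reads are not yet designed for ragged arrays; a block would
+        # need to map onto a contiguous range of the offsets and flattened buffer.)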
+ raise NotImplementedError + + def __getitem__( + self, slice: NDSlice + ) -> ragged.array: # this is true even when slicing to return a single item + # TODO: should we be smarter, and return the scalar rather a singular array + return self.read(slice=NDSlice(slice)) + + def export(self, filepath, *, format=None): + return export_util( + filepath, + format, + self.context.http_client.get, + self.item["links"]["full"], + params={}, + ) + + @property + def dims(self): + structure = cast("RaggedStructure", self.structure()) + return structure.dims + + @property + def shape(self): + structure = cast("RaggedStructure", self.structure()) + return structure.shape + + @property + def size(self): + structure = cast("RaggedStructure", self.structure()) + return structure.size + + @property + def dtype(self): + structure = cast("RaggedStructure", self.structure()) + return structure.data_type.to_numpy_dtype() + + @property + def nbytes(self): + structure = cast("RaggedStructure", self.structure()) + itemsize = structure.data_type.to_numpy_dtype().itemsize + return structure.size * itemsize + + @property + def chunks(self): + structure = cast("RaggedStructure", self.structure()) + return structure.chunks + + @property + def ndim(self): + structure = cast("RaggedStructure", self.structure()) + return len(structure.shape) + + def __repr__(self): + structure = cast("RaggedStructure", self.structure()) + attrs = { + "shape": structure.shape, + "chunks": chunks_repr(structure.chunks), + "dtype": structure.data_type.to_numpy_dtype(), + } + if structure.dims: + attrs["dims"] = structure.dims + return ( + f"<{type(self).__name__}" + + "".join(f" {k}={v}" for k, v in attrs.items()) + + ">" + ) + + +def reshape_from_slice( + _shape: tuple[int | None, ...], + _slice: NDSlice | None, +) -> tuple[int | None, ...]: + if not _slice: + return _shape + new_shape = [] + for dim_size, dim_slice in zip(_shape, _slice): + if isinstance(dim_slice, slice): + if dim_size is None: + new_shape.append(None) + else: + start, stop, step = dim_slice.indices(dim_size) + length = max(0, (stop - start + (step - 1)) // step) + new_shape.append(length) + # elif dim_slice == Ellipsis: + # remaining_dims = len(_shape) - len(_slice) + 1 + # new_shape.extend(_shape[len(new_shape) : len(new_shape) + remaining_dims]) + else: + new_shape.append(1) + return tuple(new_shape) diff --git a/tiled/links.py b/tiled/links.py index b63aa6764..b1fe195e0 100644 --- a/tiled/links.py +++ b/tiled/links.py @@ -37,6 +37,14 @@ def links_for_container(structure_family, structure, base_url, path_str): return links +def links_for_ragged(structure_family, structure, base_url, path_str): + links = {} + block_template = ",".join(f"{{{index}}}" for index in range(len(structure.shape))) + links["full"] = f"{base_url}/ragged/full/{path_str}" + links["block"] = f"{base_url}/ragged/block/{path_str}?block={block_template}" + return links + + def links_for_table(structure_family, structure, base_url, path_str): links = {} links["partition"] = f"{base_url}/table/partition/{path_str}?partition={{index}}" @@ -48,6 +56,7 @@ def links_for_table(structure_family, structure, base_url, path_str): StructureFamily.array: links_for_array, StructureFamily.awkward: links_for_awkward, StructureFamily.container: links_for_container, + StructureFamily.ragged: links_for_ragged, StructureFamily.sparse: links_for_array, # spare and array are the same StructureFamily.table: links_for_table, } diff --git a/tiled/mimetypes.py b/tiled/mimetypes.py index 4e13f9497..ca35bdf91 100644 --- 
a/tiled/mimetypes.py +++ b/tiled/mimetypes.py @@ -11,6 +11,7 @@ SPARSE_BLOCKS_PARQUET_MIMETYPE = "application/x-parquet;structure=sparse" ZARR_MIMETYPE = "application/x-zarr" AWKWARD_BUFFERS_MIMETYPE = "application/x-awkward-buffers" +RAGGED_MIMETYPE = "application/x-ragged" TILED_SQL_TABLE_MIMETYPE = "application/x-tiled-sql-table" # TODO: make type[Adapter] after #1047 DEFAULT_ADAPTERS_BY_MIMETYPE = OneShotCachedMap[str, type]( @@ -67,6 +68,9 @@ AWKWARD_BUFFERS_MIMETYPE: lambda: importlib.import_module( "..adapters.awkward_buffers", __name__ ).AwkwardBuffersAdapter, + RAGGED_MIMETYPE: lambda: importlib.import_module( + "..adapters.ragged_npy_store", __name__ + ).RaggedNPYAdapter, APACHE_ARROW_FILE_MIME_TYPE: lambda: importlib.import_module( "..adapters.arrow", __name__ ).ArrowAdapter, @@ -79,8 +83,7 @@ DEFAULT_REGISTRATION_ADAPTERS_BY_MIMETYPE = copy.deepcopy(DEFAULT_ADAPTERS_BY_MIMETYPE) DEFAULT_REGISTRATION_ADAPTERS_BY_MIMETYPE.set( - "text/csv", - lambda: importlib.import_module("..adapters.csv", __name__).CSVAdapter, + "text/csv", lambda: importlib.import_module("..adapters.csv", __name__).CSVAdapter ) diff --git a/tiled/serialization/__init__.py b/tiled/serialization/__init__.py index 368e09a14..19a9b17a7 100644 --- a/tiled/serialization/__init__.py +++ b/tiled/serialization/__init__.py @@ -22,6 +22,10 @@ def register_builtin_serializers(): from ..serialization import table as _table # noqa: F401 del _table + if modules_available("ragged"): + from ..serialization import ragged as _ragged # noqa: F401 + + del _ragged if modules_available("sparse"): from ..serialization import sparse as _sparse # noqa: F401 diff --git a/tiled/serialization/ragged.py b/tiled/serialization/ragged.py new file mode 100644 index 000000000..08d15e97d --- /dev/null +++ b/tiled/serialization/ragged.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import awkward +import numpy as np +import orjson +import ragged + +from tiled.media_type_registration import ( + default_deserialization_registry, + default_serialization_registry, +) +from tiled.mimetypes import APACHE_ARROW_FILE_MIME_TYPE, PARQUET_MIMETYPE +from tiled.serialization import awkward as awkward_serialization +from tiled.structures.core import StructureFamily +from tiled.utils import modules_available, safe_json_dump + + +@default_serialization_registry.register(StructureFamily.ragged, "application/json") +def to_json( + mimetype: str, array: ragged.array, metadata: dict # noqa: ARG001 +) -> bytes: + return safe_json_dump(array.tolist()) + + +@default_deserialization_registry.register(StructureFamily.ragged, "application/json") +def from_json( + contents: str | bytes, + dtype: type, + offsets: list[list[int]], + shape: tuple[int | None, ...], +) -> ragged.array: + lists_of_lists = orjson.loads(contents) + if all(shape) and not any(offsets): + # No raggedness, but array is not strictly N-D. Map to numpy array first. + # Otherwise, it will infer an offset array of type='x0 * Any * ... * Any * dtype' + # rather than a simple numpy array of type='x0 * x1 * ... * xN * dtype'. 
+        return ragged.array(np.array(lists_of_lists, dtype=dtype))
+    return ragged.array(lists_of_lists, dtype=dtype)
+
+
+def to_flattened_array(array: ragged.array) -> np.ndarray:
+    content = array._impl.layout  # noqa: SLF001
+    while isinstance(
+        content, (awkward.contents.ListOffsetArray, awkward.contents.ListArray)
+    ):
+        content = content.content
+    return awkward.to_numpy(content)
+
+
+@default_serialization_registry.register(
+    StructureFamily.ragged, "application/octet-stream"
+)
+def to_flattened_octet_stream(
+    mimetype: str, array: ragged.array, metadata: dict  # noqa: ARG001
+) -> bytes:
+    return np.asarray(to_flattened_array(array)).tobytes()
+
+
+def from_flattened_array(
+    array: np.ndarray,
+    dtype: type,
+    offsets: list[list[int]],
+    shape: tuple[int | None, ...],
+) -> ragged.array:
+    if all(shape) and not any(offsets):
+        # No raggedness, but need to reshape the flat array
+        return ragged.array(array.reshape(shape), dtype=dtype)
+    # return ragged.reshape(ragged.array(array, dtype=dtype), shape)
+
+    def rebuild(offsets: list[list[int]]) -> awkward.contents.Content:
+        if not offsets:
+            # NumpyArray wants a real ndarray buffer, not a Python list.
+            return awkward.contents.NumpyArray(np.asarray(array))
+        return awkward.contents.ListOffsetArray(
+            # Coerce the stored Python lists of offsets to an integer buffer.
+            offsets=awkward.index.Index(np.asarray(offsets[0], dtype=np.int64)),
+            content=rebuild(offsets[1:]),
+        )
+
+    return ragged.array(rebuild(offsets), dtype=dtype)
+
+
+@default_deserialization_registry.register(
+    StructureFamily.ragged, "application/octet-stream"
+)
+def from_flattened_octet_stream(
+    buffer: bytes, dtype: type, offsets: list[list[int]], shape: tuple[int | None, ...]
+) -> ragged.array:
+    return from_flattened_array(
+        np.frombuffer(buffer, dtype=dtype), dtype, offsets, shape
+    )
+
+
+if modules_available("pyarrow"):
+
+    @default_serialization_registry.register(
+        StructureFamily.ragged, APACHE_ARROW_FILE_MIME_TYPE
+    )
+    def to_arrow(mimetype: str, array: ragged.array, metadata: dict):
+        components = awkward.to_buffers(array._impl)  # noqa: SLF001
+        return awkward_serialization.to_arrow(mimetype, components, metadata)
+
+    @default_serialization_registry.register(StructureFamily.ragged, PARQUET_MIMETYPE)
+    def to_parquet(mimetype: str, array: ragged.array, metadata: dict):
+        components = awkward.to_buffers(array._impl)  # noqa: SLF001
+        return awkward_serialization.to_parquet(mimetype, components, metadata)
diff --git a/tiled/server/core.py b/tiled/server/core.py
index dd699e2e9..c11a99ee0 100644
--- a/tiled/server/core.py
+++ b/tiled/server/core.py
@@ -317,6 +317,7 @@ async def construct_entries_response(
 DEFAULT_MEDIA_TYPES = {
     StructureFamily.array: {"*/*": "application/octet-stream", "image/*": "image/png"},
     StructureFamily.awkward: {"*/*": "application/zip"},
+    StructureFamily.ragged: {"*/*": "application/octet-stream"},
     StructureFamily.table: {"*/*": APACHE_ARROW_FILE_MIME_TYPE},
     StructureFamily.container: {"*/*": "application/x-hdf5"},
     StructureFamily.sparse: {"*/*": APACHE_ARROW_FILE_MIME_TYPE},
diff --git a/tiled/server/router.py b/tiled/server/router.py
index fee73b651..30e63d0e2 100644
--- a/tiled/server/router.py
+++ b/tiled/server/router.py
@@ -745,6 +745,236 @@ async def websocket_endpoint(
             handler = entry.make_ws_handler(websocket, formatter, uri)
             await handler(start)
 
+    @router.get(
+        "/ragged/full/{path:path}", response_model=schemas.Response, name="ragged full"
+    )
+    async def get_ragged_full(
+        request: Request,
+        path: str,
+        slice=Depends(NDSlice.from_query),
+        expected_shape=Depends(expected_shape),
+        format: Optional[str] = None,
+        filename: Optional[str] = None,
+        settings: Settings = Depends(get_settings),
+        principal: Optional[Principal] = Depends(get_current_principal),
+        root_tree=Depends(get_root_tree),
+        session_state: dict = Depends(get_session_state),
+        authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags),
+        authn_scopes: Scopes = Depends(get_current_scopes),
+        _=Security(check_scopes, scopes=["read:data"]),
+    ):
+        entry = await get_entry(
+            path=path,
+            security_scopes=["read:data"],
+            principal=principal,
+            authn_access_tags=authn_access_tags,
+            authn_scopes=authn_scopes,
+            root_tree=root_tree,
+            session_state=session_state,
+            metrics=request.state.metrics,
+            structure_families={StructureFamily.ragged},
+            access_policy=getattr(request.app.state, "access_policy", None),
+        )
+        structure_family = entry.structure_family
+
+        import ragged
+
+        with record_timing(request.state.metrics, "read"):
+            ragged_array: ragged.array = await ensure_awaitable(entry.read, slice)
+
+        if ragged_array._impl.nbytes > settings.response_bytesize_limit:  # noqa: SLF001
+            raise HTTPException(
+                status_code=HTTP_400_BAD_REQUEST,
+                detail=(
+                    f"Response would exceed {settings.response_bytesize_limit}. "
+                    "Use slicing ('?slice=...') to request smaller chunks."
+                ),
+            )
+        try:
+            with record_timing(request.state.metrics, "pack"):
+                return await construct_data_response(
+                    structure_family,
+                    serialization_registry,
+                    ragged_array,  # entry.read(slice) above already applied the slice
+                    entry.metadata(),
+                    request,
+                    format,
+                    specs=getattr(entry, "specs", []),
+                    expires=getattr(entry, "content_stale_at", None),
+                    filename=filename,
+                )
+        except UnsupportedMediaTypes as err:
+            raise HTTPException(
+                status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]
+            ) from err
+
+    @router.get(
+        "/ragged/block/{path:path}",
+        response_model=schemas.Response,
+        name="ragged block",
+    )
+    async def get_ragged_block(
+        request: Request,
+        path: str,
+        block=Depends(block),
+        slice=Depends(NDSlice.from_query),
+        expected_shape=Depends(expected_shape),
+        format: Optional[str] = None,
+        filename: Optional[str] = None,
+        settings: Settings = Depends(get_settings),
+        principal: Optional[Principal] = Depends(get_current_principal),
+        root_tree=Depends(get_root_tree),
+        session_state: dict = Depends(get_session_state),
+        authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags),
+        authn_scopes: Scopes = Depends(get_current_scopes),
+        _=Security(check_scopes, scopes=["read:data"]),
+    ):
+        raise NotImplementedError
+
+    @router.put("/ragged/full/{path:path}")
+    async def put_ragged_full(
+        request: Request,
+        path: str,
+        persist: bool = Query(True, description="Persist data to storage"),
+        principal: Optional[Principal] = Depends(get_current_principal),
+        root_tree=Depends(get_root_tree),
+        session_state: dict = Depends(get_session_state),
+        authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags),
+        authn_scopes: Scopes = Depends(get_current_scopes),
+        _=Security(check_scopes, scopes=["write:data"]),
+    ):
+        entry = await get_entry(
+            path,
+            ["write:data"],
+            principal,
+            authn_access_tags,
+            authn_scopes,
+            root_tree,
+            session_state,
+            request.state.metrics,
+            {StructureFamily.ragged},
+            getattr(request.app.state, "access_policy", None),
+        )
+        body = await request.body()
+        if not hasattr(entry, "write"):
+            raise HTTPException(
+                status_code=HTTP_405_METHOD_NOT_ALLOWED,
+                detail="This node cannot accept array data.",
+            )
+        media_type = request.headers["content-type"]
+        if entry.structure_family == "ragged":
+            deserializer = deserialization_registry.dispatch("ragged", media_type)
+        else:
+            raise 
NotImplementedError(entry.structure_family) + await ensure_awaitable( + entry.write, media_type, deserializer, entry, body, persist + ) + return json_or_msgpack(request, None) + + @router.put("/ragged/block/{path:path}") + async def put_ragged_block( + request: Request, + path: str, + block=Depends(block), + persist: bool = Query(True, description="Persist data to storage"), + principal: Optional[Principal] = Depends(get_current_principal), + root_tree=Depends(get_root_tree), + session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), + authn_scopes: Scopes = Depends(get_current_scopes), + _=Security(check_scopes, scopes=["write:data"]), + ): + raise NotImplementedError + entry = await get_entry( + path, + ["write:data"], + principal, + authn_access_tags, + authn_scopes, + root_tree, + session_state, + request.state.metrics, + {StructureFamily.ragged}, + getattr(request.app.state, "access_policy", None), + ) + if not hasattr(entry, "write_block"): + raise HTTPException( + status_code=HTTP_405_METHOD_NOT_ALLOWED, + detail="This node cannot accept array data.", + ) + + body = await request.body() + media_type = request.headers["content-type"] + deserializer = deserialization_registry.dispatch( + entry.structure_family, media_type + ) + await ensure_awaitable( + entry.write_block, block, media_type, deserializer, entry, body, persist + ) + return json_or_msgpack(request, None) + + @router.patch("/ragged/full/{path:path}") + async def patch_ragged_full( + request: Request, + path: str, + offset=Depends(offset_param), + shape=Depends(shape_param), + extend: bool = False, + persist: bool = Query(True, description="Persist data to storage"), + principal: Optional[Principal] = Depends(get_current_principal), + root_tree=Depends(get_root_tree), + session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), + authn_scopes: Scopes = Depends(get_current_scopes), + _=Security(check_scopes, scopes=["write:data"]), + ): + if extend and not persist: + bad_args_message = ( + "Cannot PATCH an array with both parameters" + " extend=True and persist=False." + " To extend the array, you must persist the changes." + " To skip persisting the changes, you must not extend the array." 
+ ) + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=bad_args_message, + ) + entry = await get_entry( + path, + ["write:data"], + principal, + authn_access_tags, + authn_scopes, + root_tree, + session_state, + request.state.metrics, + {StructureFamily.ragged}, + getattr(request.app.state, "access_policy", None), + ) + if not hasattr(entry, "patch"): + raise HTTPException( + status_code=HTTP_405_METHOD_NOT_ALLOWED, + detail="This node cannot accept array data.", + ) + + body = await request.body() + media_type = request.headers["content-type"] + deserializer = deserialization_registry.dispatch( + entry.structure_family, media_type + ) + structure = await ensure_awaitable( + entry.patch, + shape, + offset, + extend, + media_type, + deserializer, + entry, + body, + persist, + ) + return json_or_msgpack(request, structure) + @router.get( "/table/partition/{path:path}", response_model=schemas.Response, diff --git a/tiled/server/schemas.py b/tiled/server/schemas.py index 7271dbe64..3d0bb26fa 100644 --- a/tiled/server/schemas.py +++ b/tiled/server/schemas.py @@ -10,6 +10,8 @@ from pydantic_core import PydanticCustomError from typing_extensions import Annotated, TypedDict +from tiled.structures.ragged import RaggedStructure + from ..structures.array import ArrayStructure from ..structures.awkward import AwkwardStructure from ..structures.core import STRUCTURE_TYPES, Spec, StructureFamily @@ -181,6 +183,7 @@ class NodeAttributes(pydantic.BaseModel): Union[ ArrayStructure, AwkwardStructure, + RaggedStructure, SparseStructure, NodeStructure, TableStructure, @@ -227,6 +230,12 @@ class DataFrameLinks(pydantic.BaseModel): partition: str +class RaggedLinks(pydantic.BaseModel): + self: str + full: str + block: str + + class SparseLinks(pydantic.BaseModel): self: str full: str @@ -237,6 +246,7 @@ class SparseLinks(pydantic.BaseModel): StructureFamily.array: ArrayLinks, StructureFamily.awkward: AwkwardLinks, StructureFamily.container: ContainerLinks, + StructureFamily.ragged: RaggedLinks, StructureFamily.sparse: SparseLinks, StructureFamily.table: DataFrameLinks, } diff --git a/tiled/structures/core.py b/tiled/structures/core.py index 38f3fb020..f24a60ce0 100644 --- a/tiled/structures/core.py +++ b/tiled/structures/core.py @@ -19,6 +19,7 @@ class StructureFamily(str, enum.Enum): array = "array" awkward = "awkward" container = "container" + ragged = "ragged" sparse = "sparse" table = "table" @@ -64,6 +65,9 @@ def dict(self) -> Dict[str, Optional[str]]: StructureFamily.sparse: lambda: importlib.import_module( "...structures.sparse", StructureFamily.__module__ ).SparseStructure, + StructureFamily.ragged: lambda: importlib.import_module( + "...structures.ragged", StructureFamily.__module__ + ).RaggedStructure, StructureFamily.container: lambda: importlib.import_module( "...structures.container", StructureFamily.__module__ ).ContainerStructure, diff --git a/tiled/structures/ragged.py b/tiled/structures/ragged.py new file mode 100644 index 000000000..6cd2851a0 --- /dev/null +++ b/tiled/structures/ragged.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import sys +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + +import awkward +import numpy as np +import ragged + +from tiled.structures.array import ArrayStructure, BuiltinDtype, StructDtype + +if TYPE_CHECKING: + from collections.abc import Iterable, Mapping + + +@dataclass(kw_only=True) +class 
RaggedStructure(ArrayStructure):
+    shape: tuple[int | None, ...]  # type: ignore[reportIncompatibleVariableOverride]
+    offsets: list[list[int]]
+    size: int
+
+    @classmethod
+    def from_array(
+        cls,
+        array: Iterable,
+        shape: tuple[int | None, ...] | None = None,
+        chunks: tuple[Any, ...] | None = None,
+        dims: tuple[str, ...] | None = None,
+    ) -> Self:
+        if not isinstance(array, ragged.array):
+            array = (
+                ragged.asarray(array.tolist())
+                if hasattr(array, "tolist")
+                else ragged.array(list(array))
+            )
+
+        if shape is None:
+            shape = array.shape
+        if chunks is None:
+            chunks = ("auto",) * len(shape)
+
+        if array.dtype.fields is not None:
+            data_type = StructDtype.from_numpy_dtype(array.dtype)
+        else:
+            data_type = BuiltinDtype.from_numpy_dtype(array.dtype)
+
+        content = array._impl.layout  # noqa: SLF001
+        offsets = []
+
+        while isinstance(
+            content, (awkward.contents.ListOffsetArray, awkward.contents.ListArray)
+        ):
+            if isinstance(content, awkward.contents.ListOffsetArray):
+                offsets.append(np.array(content.offsets).tolist())
+            content = content.content
+
+        # array.size is fully determined here; coerce any NumPy integer to a plain int.
+        size = int(array.size)
+
+        return cls(
+            data_type=data_type,
+            chunks=chunks,
+            shape=shape,
+            dims=dims,
+            resizable=False,
+            offsets=offsets,
+            size=size,
+        )
+
+    @classmethod
+    def from_json(cls, structure: Mapping[str, Any]) -> Self:
+        if "fields" in structure["data_type"]:
+            data_type = StructDtype.from_json(structure["data_type"])
+        else:
+            data_type = BuiltinDtype.from_json(structure["data_type"])
+        dims = structure["dims"]
+        if dims is not None:
+            dims = tuple(dims)
+        return cls(
+            data_type=data_type,
+            chunks=tuple(map(tuple, structure["chunks"])),
+            shape=tuple(structure["shape"]),
+            dims=dims,
+            resizable=structure.get("resizable", False),
+            offsets=structure.get("offsets", []),
+            size=structure["size"],
+        )
+
+    @property
+    def npartitions(self) -> int:
+        return 1
+
+    @property
+    def form(self) -> dict[str, Any]:
+        def build(depth: int) -> dict[str, Any]:
+            if depth <= 0:
+                # TODO: Handle EmptyArray, e.g. ragged.array([[], []])
+                return {
+                    "class": "NumpyArray",
+                    "primitive": self.data_type.to_numpy_dtype().name,
+                    "form_key": f"node{len(self.offsets) - depth}",
+                }
+            return {
+                "class": "ListOffsetArray",
+                "offsets": "i64",
+                "content": build(depth - 1),
+                "form_key": f"node{len(self.offsets) - depth}",
+            }
+
+        return build(len(self.offsets))