src/fd5/create.py (41 additions, 0 deletions)
@@ -82,12 +82,34 @@ def __getattr__(self, name: str) -> Any:
    def __setattr__(self, name: str, value: Any) -> None:
        setattr(self._group, name, value)

    def __setitem__(self, key: str, value: Any) -> None:
        self._group[key] = value

    def __contains__(self, item: str) -> bool:
        return item in self._group

    def __getitem__(self, key: str) -> Any:
        return self._group[key]

    def __iter__(self):
        return iter(self._group)

    def __len__(self) -> int:
        return len(self._group)

    def keys(self):  # noqa: D102 (delegates to h5py.Group)
        return self._group.keys()

    def values(self):  # noqa: D102
        return self._group.values()

    def items(self):  # noqa: D102
        return self._group.items()

    def require_group(self, name: str) -> "_HashTrackingGroup":
        grp = self._group.require_group(name)
        return _HashTrackingGroup(grp, self._data_hash_cache, self._chunk_digest_cache)

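The new dunder methods make the proxy behave like a `Mapping` over the wrapped `h5py.Group`, while the `require_group` override re-wraps its result so nested writes keep their hash tracking. A minimal usage sketch, assuming `h5file` is an open, writable `h5py.File` (in the tests below, the builder's `b.file` plays this role):

```python
import numpy as np

# Assumption: h5file is an open, writable h5py.File.
tracking = _HashTrackingGroup(h5file, {}, {})

tracking["flags"] = np.zeros(4, dtype=np.uint8)  # __setitem__ delegates to the group
assert "flags" in tracking                       # __contains__
assert len(tracking) == len(h5file)              # __len__ mirrors the wrapped group
for name in tracking:                            # __iter__ yields member names
    obj = tracking[name]                         # __getitem__ returns h5py objects

nested = tracking.require_group("calib")         # returned proxy carries the caches,
assert isinstance(nested, _HashTrackingGroup)    # so deeper writes stay tracked
```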

def _iter_chunks(
    arr: np.ndarray, chunk_shape: tuple[int, ...], hasher: ChunkHasher
@@ -189,6 +211,25 @@ def write_extra(self, data: dict[str, Any]) -> None:
        )
        dict_to_h5(grp, data)

    def write_dataset(self, path: str, data: Any, **kwargs: Any) -> h5py.Dataset:
        """Write a dataset with inline hash tracking.

        Convenience method for writing individual datasets outside of
        ``ProductSchema.write()``. The data is hashed inline so the
        dataset participates in the file's Merkle tree.
        """
        tracking = _HashTrackingGroup(
            self._file, self._data_hash_cache, self._chunk_digest_cache
        )
        parts = path.strip("/").split("/")
        target = tracking
        for part in parts[:-1]:
            # require_group creates the group if missing and returns the
            # existing one, re-wrapped so hash tracking is preserved.
            target = target.require_group(part)
        return target.create_dataset(parts[-1], data=data, **kwargs)

    def write_product(self, data: Any) -> None:
        """Delegate product-specific writes to the registered ProductSchema.
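For reference, a hedged sketch of how the new `write_dataset()` convenience is meant to be called, mirroring the tests below (`out_dir` is assumed to be an existing `pathlib.Path` output directory; the dataset path and values are illustrative):

```python
import numpy as np
from fd5.create import create

# Assumption: out_dir is an existing pathlib.Path output directory.
with create(
    out_dir,
    product="test/product",
    name="sketch",
    description="write_dataset usage sketch",
    timestamp="2026-01-01T00:00:00Z",
) as b:
    b.write_product(np.arange(10, dtype=np.float32))
    # Intermediate groups are created (or reused) on demand; the return
    # value is a plain h5py.Dataset.
    ds = b.write_dataset(
        "aux/stats/counts",
        data=np.zeros((8,), dtype=np.int64),
        chunks=(4,),
    )
    assert ds.shape == (8,)
```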
tests/test_create.py (174 additions, 1 deletion)
@@ -101,7 +101,12 @@ def out_dir(tmp_path: Path) -> Path:
# ---------------------------------------------------------------------------


from fd5.create import Fd5Builder, Fd5ValidationError, create  # noqa: E402
from fd5.create import (  # noqa: E402
    Fd5Builder,
    Fd5ValidationError,
    _HashTrackingGroup,
    create,
)


# ---------------------------------------------------------------------------
@@ -790,6 +795,174 @@ def test_mixed_chunked_and_non_chunked(self, out_dir: Path):
        assert stored == second_pass


# ---------------------------------------------------------------------------
# write_dataset() convenience method
# ---------------------------------------------------------------------------


class TestWriteDataset:
    def test_write_dataset_creates_dataset(self, out_dir: Path):
        """write_dataset() creates datasets with inline hashing."""
        with create(
            out_dir,
            product="test/product",
            name="conv",
            description="Test",
            timestamp="2026-01-01T00:00:00Z",
        ) as b:
            b.write_product(np.arange(10, dtype=np.float32))
            ds = b.write_dataset(
                "extra_data",
                data=np.ones((50, 50), dtype=np.float32),
                chunks=(10, 10),
            )
            assert ds.shape == (50, 50)

        files = list(out_dir.glob("*.h5"))
        assert len(files) == 1
        assert verify(str(files[0])) is True

    def test_write_dataset_nested_path(self, out_dir: Path):
        """write_dataset() creates intermediate groups for nested paths."""
        with create(
            out_dir,
            product="test/product",
            name="nested",
            description="Test nested",
            timestamp="2026-01-01T00:00:00Z",
        ) as b:
            b.write_product(np.arange(10, dtype=np.float32))
            ds = b.write_dataset(
                "group_a/group_b/values",
                data=np.zeros((20,), dtype=np.float64),
                chunks=(5,),
            )
            assert ds.shape == (20,)
            assert "group_a" in b.file
            assert "group_b" in b.file["group_a"]
            assert "values" in b.file["group_a/group_b"]

        files = list(out_dir.glob("*.h5"))
        assert len(files) == 1
        assert verify(str(files[0])) is True

    def test_write_dataset_into_existing_group(self, out_dir: Path):
        """write_dataset() reuses existing intermediate groups."""
        with create(
            out_dir,
            product="test/product",
            name="existing",
            description="Test",
            timestamp="2026-01-01T00:00:00Z",
        ) as b:
            b.write_product(np.arange(10, dtype=np.float32))
            b.file.create_group("mygroup")
            ds = b.write_dataset(
                "mygroup/data",
                data=np.array([1, 2, 3], dtype=np.int32),
                chunks=(3,),
            )
            assert ds.shape == (3,)

        files = list(out_dir.glob("*.h5"))
        assert len(files) == 1
        assert verify(str(files[0])) is True


# ---------------------------------------------------------------------------
# Full round-trip: create -> verify -> validate
# ---------------------------------------------------------------------------


class TestCreateVerifyValidateRoundtrip:
    def test_create_verify_validate_roundtrip(self, out_dir: Path):
        """Full round-trip: create -> verify -> validate."""
        import fd5

        with fd5.create(
            out_dir,
            product="test/product",
            name="roundtrip",
            description="Integration test",
            timestamp="2026-01-01T00:00:00Z",
        ) as b:
            b.write_product(np.arange(100, dtype=np.float32))

        files = list(out_dir.glob("*.h5"))
        assert len(files) == 1
        path = files[0]

        # Verify integrity
        assert fd5.verify(str(path))

        # Validate schema
        errors = fd5.validate(str(path))
        assert errors == []


# ---------------------------------------------------------------------------
# _HashTrackingGroup proxy methods
# ---------------------------------------------------------------------------


class TestHashTrackingGroupProxy:
    def test_setitem(self, out_dir: Path):
        with create(
            out_dir,
            product="test/product",
            name="proxy",
            description="Test",
            timestamp="2026-01-01T00:00:00Z",
        ) as b:
            b.write_product(np.arange(10, dtype=np.float32))
            tracking = _HashTrackingGroup(b.file, {}, {})
            tracking["scalar_ds"] = np.float32(42.0)
            assert "scalar_ds" in b.file

    def test_iter_and_len(self, out_dir: Path):
        with create(
            out_dir,
            product="test/product",
            name="proxy",
            description="Test",
            timestamp="2026-01-01T00:00:00Z",
        ) as b:
            tracking = _HashTrackingGroup(b.file, {}, {})
            initial_len = len(tracking)
            tracking.create_group("test_g")
            assert len(tracking) == initial_len + 1
            assert "test_g" in list(tracking)

    def test_keys_values_items(self, out_dir: Path):
        with create(
            out_dir,
            product="test/product",
            name="proxy",
            description="Test",
            timestamp="2026-01-01T00:00:00Z",
        ) as b:
            tracking = _HashTrackingGroup(b.file, {}, {})
            tracking.create_group("grp1")
            assert "grp1" in tracking.keys()
            assert len(list(tracking.values())) > 0
            assert len(list(tracking.items())) > 0

    def test_require_group(self, out_dir: Path):
        with create(
            out_dir,
            product="test/product",
            name="proxy",
            description="Test",
            timestamp="2026-01-01T00:00:00Z",
        ) as b:
            tracking = _HashTrackingGroup(b.file, {}, {})
            grp = tracking.require_group("new_grp")
            assert isinstance(grp, _HashTrackingGroup)
            # Calling again should return the same group, not raise
            grp2 = tracking.require_group("new_grp")
            assert isinstance(grp2, _HashTrackingGroup)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------