Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ dependencies = [
"numpy",
"pillow",
"typer",
"h5py",
] # Add project dependencies here, e.g. ["click", "numpy"]
dynamic = ["version"]
license.file = "LICENSE"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from fastcs.attributes import AttrRW
from fastcs.attributes import AttrR, AttrRW
from fastcs_odin.controllers.odin_data.frame_processor import (
FrameProcessorAdapterController,
)
Expand All @@ -7,3 +7,5 @@
class EigerFrameProcessorAdapterController(FrameProcessorAdapterController):
    """Frame processor adapter controller exposing Eiger-specific attributes."""

    data_compression: AttrRW[str]
    data_datatype: AttrRW[str]
    data_dims_0: AttrR[int]  # y — presumably frame height in pixels; confirm with odin-data
    data_dims_1: AttrR[int]  # x — presumably frame width in pixels; confirm with odin-data
24 changes: 23 additions & 1 deletion src/fastcs_eiger/controllers/odin/eiger_odin_controller.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import asyncio
from pathlib import Path

from fastcs.attributes import AttrRW
from fastcs.connections import IPConnectionSettings
from fastcs.datatypes import Int
from fastcs.datatypes import Bool, Int
from fastcs.methods import command

from fastcs_eiger.controllers.eiger_controller import COMMAND_GROUP, EigerController
from fastcs_eiger.controllers.odin.generate_vds import create_interleave_vds
from fastcs_eiger.controllers.odin.odin_controller import OdinController
from fastcs_eiger.eiger_parameter import EigerAPIVersion

Expand All @@ -19,6 +21,7 @@ class EigerOdinController(EigerController):
description="Timeout for start writing command",
group=COMMAND_GROUP,
)
enable_vds_creation = AttrRW(Bool())

def __init__(
self,
Expand Down Expand Up @@ -63,6 +66,10 @@ async def start_writing(self):
self.OD.FP.data_datatype.put(f"uint{self.detector.bit_depth_image.get()}"),
)

path = Path(self.OD.file_path.get())
prefix = self.OD.file_prefix.get()
frame_count = self.OD.FP.frames.get()

await self.OD.FP.start_writing()

try:
Expand All @@ -71,3 +78,18 @@ async def start_writing(self):
)
except TimeoutError as e:
raise TimeoutError("File writers failed to start") from e

if self.enable_vds_creation.get():
create_interleave_vds(
path=path,
prefix=prefix,
datasets=["data", "data2", "data3"],
frame_count=frame_count,
frames_per_block=self.OD.block_size.get(),
blocks_per_file=self.OD.FP.process_blocks_per_file.get(),
frame_shape=(
self.OD.FP.data_dims_1.get(),
self.OD.FP.data_dims_0.get(),
),
dtype=self.OD.FP.data_datatype.get(),
)
120 changes: 120 additions & 0 deletions src/fastcs_eiger/controllers/odin/generate_vds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import math
from dataclasses import dataclass
from pathlib import Path

import h5py
from fastcs.logging import logger


@dataclass
class FileFrames:
    """Frame allocation for a single output data file.

    frames: total number of frames written to the file
    start: index of the file's first frame within the whole acquisition
    frames_per_block: frames in each contiguous interleave block
    """

    frames: int
    start: int
    frames_per_block: int

    @property
    def blocks(self):
        """Number of complete blocks held by this file."""
        full_blocks, _ = divmod(self.frames, self.frames_per_block)
        return full_blocks

    @property
    def remainder_frames(self):
        """Frames left over after the last complete block."""
        _, leftover = divmod(self.frames, self.frames_per_block)
        return leftover


def _get_frames_per_file_writer(
frame_count: int, frames_per_block: int, n_file_writers: int
) -> list[int]:
n_blocks = math.ceil(frame_count / frames_per_block)
min_blocks_per_file = n_blocks // n_file_writers
remainder = n_blocks - min_blocks_per_file * n_file_writers

frames_per_file_writer = []
for i in range(n_file_writers):
blocks = min_blocks_per_file + (i < remainder)
frames_per_file_writer.append(blocks * frames_per_block)

overflow = sum(frames_per_file_writer) - frame_count
frames_per_file_writer[remainder - 1] -= overflow
return frames_per_file_writer


def _calculate_frame_distribution(
    frame_count, frames_per_block, blocks_per_file, n_file_writers
) -> dict[int, FileFrames]:
    """Map 1-based file numbers to the frames each output file contains.

    Each writer's frames are split into files of at most
    ``frames_per_block * blocks_per_file`` frames (unbounded when
    ``blocks_per_file`` is falsy). File numbering interleaves across writers:
    writer w's i-th file is number ``w + i * n_file_writers + 1``.
    """
    # Capacity of a single file; 0/falsy blocks_per_file means "one file".
    if blocks_per_file:
        file_capacity = frames_per_block * blocks_per_file
    else:
        file_capacity = frame_count
    # Total frames written before any writer must roll over to a new file.
    round_capacity = n_file_writers * file_capacity

    distribution: dict[int, FileFrames] = {}
    writer_totals = _get_frames_per_file_writer(
        frame_count, frames_per_block, n_file_writers
    )
    for writer, writer_frames in enumerate(writer_totals):
        for round_idx in range(math.ceil(writer_frames / file_capacity)):
            file_number = writer + round_idx * n_file_writers + 1
            distribution[file_number] = FileFrames(
                frames=min(
                    writer_frames - round_idx * file_capacity, file_capacity
                ),
                frames_per_block=frames_per_block,
                # First frame of this file: the writer's offset within the
                # first round plus the rounds already completed.
                start=frames_per_block * writer + round_idx * round_capacity,
            )

    return distribution


def create_interleave_vds(
    path: Path,
    prefix: str,
    datasets: list[str],
    frame_count: int,
    frames_per_block: int,
    blocks_per_file: int,
    frame_shape: tuple[int, int],
    dtype: str = "float",
    n_file_writers: int = 4,
) -> None:
    """Write a virtual dataset stitching interleaved file-writer output.

    Args:
        path: Directory containing the raw files; the VDS is written here.
        prefix: File name prefix shared by raw files and the VDS.
        datasets: Dataset names to map from each raw file into the VDS.
        frame_count: Total number of frames in the acquisition.
        frames_per_block: Frames per contiguous block written by one writer.
        blocks_per_file: Max blocks per raw file (falsy means unbounded).
        frame_shape: Per-frame dimensions (axis 0, axis 1) of the datasets.
        dtype: Dataset dtype string passed to h5py.
        n_file_writers: Number of file writer processes that produced files.
    """
    frame_distribution = _calculate_frame_distribution(
        frame_count, frames_per_block, blocks_per_file, n_file_writers
    )
    # Distance between the starts of consecutive blocks from the same file:
    # one block from every writer is written in between.
    stride = n_file_writers * frames_per_block
    filepath = f"{path / prefix}_vds.h5"
    logger.info(f"Writing virtual dataset at {filepath}")
    # Reuse the path computed above rather than rebuilding the f-string, so
    # the logged location can never diverge from the file actually written.
    with h5py.File(filepath, "w", libver="latest") as f:
        for dataset_name in datasets:
            v_layout = h5py.VirtualLayout(
                shape=(frame_count, frame_shape[0], frame_shape[1]),
                dtype=dtype,
            )
            for file_number, file_frames in frame_distribution.items():
                full_block_frames = file_frames.blocks * frames_per_block
                # Raw files are siblings of the VDS, so a relative name works.
                v_source = h5py.VirtualSource(
                    f"{prefix}_{file_number:06d}.h5",
                    name=dataset_name,
                    shape=(file_frames.frames, frame_shape[0], frame_shape[1]),
                    dtype=dtype,
                )
                if file_frames.blocks:
                    source = v_source[:full_block_frames, :, :]
                    v_layout[
                        h5py.MultiBlockSlice(
                            start=file_frames.start,
                            stride=stride,
                            count=file_frames.blocks,
                            block=frames_per_block,
                        ),
                        :,
                        :,
                    ] = source
                if file_frames.remainder_frames:
                    # Last few frames that don't fit into a block; only the
                    # file holding the acquisition's final partial block has
                    # any, so they map to the tail of the layout.
                    source = v_source[full_block_frames : file_frames.frames, :, :]
                    v_layout[
                        frame_count - file_frames.remainder_frames : frame_count, :, :
                    ] = source

            f.create_virtual_dataset(dataset_name, v_layout)
53 changes: 44 additions & 9 deletions tests/test_eiger_odin_controller.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from pathlib import Path

import pytest
from fastcs.attributes import AttrR, AttrRW
from fastcs.connections import IPConnectionSettings
from fastcs.datatypes import Int, String
from pytest_mock import MockerFixture

from fastcs_eiger.controllers.eiger_controller import EigerController
Expand All @@ -8,13 +12,29 @@


@pytest.fixture
def eiger_odin_controller():
def eiger_odin_controller(mocker: MockerFixture):
detector_connection_settings = IPConnectionSettings("127.0.0.1", 8000)
odin_connection_settings = IPConnectionSettings("127.0.0.1", 8001)
return EigerOdinController(
controller = EigerOdinController(
detector_connection_settings, odin_connection_settings, api_version="1.8.0"
)

controller.OD.file_path = AttrRW(String(), initial_value="/tmp/data") # pyright: ignore[reportAttributeAccessIssue]
controller.OD.file_prefix = AttrRW(String(), initial_value="test_prefix") # pyright: ignore[reportAttributeAccessIssue]
controller.OD.block_size = AttrRW(Int(), initial_value=4) # pyright: ignore[reportAttributeAccessIssue]

fp_mock = mocker.patch.object(controller.OD, "FP", create=True)
fp_mock.data_compression.put = mocker.AsyncMock()
fp_mock.data_datatype.put = mocker.AsyncMock()
fp_mock.data_datatype.get.return_value = "uint16"
fp_mock.frames = AttrRW(Int(), initial_value=100)
fp_mock.process_blocks_per_file = AttrR(Int(), initial_value=10)
fp_mock.data_dims_0 = AttrR(Int(), initial_value=512)
fp_mock.data_dims_1 = AttrR(Int(), initial_value=1024)
fp_mock.start_writing = mocker.AsyncMock()

return controller


@pytest.mark.asyncio
async def test_eiger_odin_controller(eiger_odin_controller, mocker: MockerFixture):
Expand Down Expand Up @@ -58,10 +78,9 @@ async def test_start_writing(eiger_odin_controller, mocker: MockerFixture):
detector_mock.compression.get.return_value = "lz4"
detector_mock.bit_depth_image.get.return_value = 16

fp_mock = mocker.patch.object(controller.OD, "FP", create=True)
fp_mock.data_compression.put = mocker.AsyncMock()
fp_mock.data_datatype.put = mocker.AsyncMock()
fp_mock.start_writing = mocker.AsyncMock()
create_mock = mocker.patch(
"fastcs_eiger.controllers.odin.eiger_odin_controller.create_interleave_vds"
)

writing_wait_mock = mocker.patch.object(controller.OD.writing, "wait_for_value")

Expand All @@ -72,9 +91,25 @@ async def test_start_writing(eiger_odin_controller, mocker: MockerFixture):
writing_wait_mock.side_effect = None
await controller.start_writing()

fp_mock.data_compression.put.assert_awaited_with("LZ4")
fp_mock.data_datatype.put.assert_awaited_with("uint16")
fp_mock.start_writing.assert_awaited_with()
controller.OD.FP.data_compression.put.assert_awaited_with("LZ4")
controller.OD.FP.data_datatype.put.assert_awaited_with("uint16")
controller.OD.FP.start_writing.assert_awaited_with()
writing_wait_mock.assert_awaited_with(
True, timeout=controller.start_writing_timeout.get()
)

create_mock.assert_not_called()

controller.enable_vds_creation._value = True
await controller.start_writing()

create_mock.assert_called_once_with(
path=Path("/tmp/data"),
prefix="test_prefix",
datasets=["data", "data2", "data3"],
frame_count=100,
frames_per_block=4,
blocks_per_file=10,
frame_shape=(1024, 512),
dtype="uint16",
)
Loading
Loading