Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion quixstreams/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
TumblingCountWindowDefinition,
TumblingTimeWindowDefinition,
)
from .windows.base import WindowOnLateCallback
from .windows.base import WindowOnLateCallback, WindowOnUpdateCallback

if typing.TYPE_CHECKING:
from quixstreams.processing import ProcessingContext
Expand Down Expand Up @@ -1085,6 +1085,7 @@ def tumbling_window(
grace_ms: Union[int, timedelta] = 0,
name: Optional[str] = None,
on_late: Optional[WindowOnLateCallback] = None,
on_update: Optional[WindowOnUpdateCallback] = None,
) -> TumblingTimeWindowDefinition:
"""
Create a time-based tumbling window transformation on this StreamingDataFrame.
Expand Down Expand Up @@ -1151,6 +1152,10 @@ def tumbling_window(
(default behavior).
Otherwise, no message will be logged.

:param on_update: an optional callback to react to window updates and
to expire windows early. It is called with the previous and the updated window value; if it returns `True`, the window will be expired.
Default - `None`.

:return: `TumblingTimeWindowDefinition` instance representing the tumbling window
configuration.
This object can be further configured with aggregation functions
Expand All @@ -1166,6 +1171,7 @@ def tumbling_window(
dataframe=self,
name=name,
on_late=on_late,
on_update=on_update,
)

def tumbling_count_window(
Expand Down Expand Up @@ -1225,6 +1231,7 @@ def hopping_window(
grace_ms: Union[int, timedelta] = 0,
name: Optional[str] = None,
on_late: Optional[WindowOnLateCallback] = None,
on_update: Optional[WindowOnUpdateCallback] = None,
) -> HoppingTimeWindowDefinition:
"""
Create a time-based hopping window transformation on this StreamingDataFrame.
Expand Down Expand Up @@ -1302,6 +1309,10 @@ def hopping_window(
(default behavior).
Otherwise, no message will be logged.

:param on_update: an optional callback to react to window updates and
to expire windows early. It is called with the previous and the updated window value; if it returns `True`, the window will be expired.
Default - `None`.

:return: `HoppingTimeWindowDefinition` instance representing the hopping
window configuration.
This object can be further configured with aggregation functions
Expand All @@ -1319,6 +1330,7 @@ def hopping_window(
dataframe=self,
name=name,
on_late=on_late,
on_update=on_update,
)

def hopping_count_window(
Expand Down
1 change: 1 addition & 0 deletions quixstreams/dataframe/windows/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
WindowResult: TypeAlias = dict[str, Any]
WindowKeyResult: TypeAlias = tuple[Any, WindowResult]
Message: TypeAlias = tuple[WindowResult, Any, int, Any]
WindowOnUpdateCallback: TypeAlias = Callable[[Any, Any], bool]

WindowAggregateFunc = Callable[[Any, Any], Any]

Expand Down
20 changes: 19 additions & 1 deletion quixstreams/dataframe/windows/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .base import (
Window,
WindowOnLateCallback,
WindowOnUpdateCallback,
)
from .count_based import (
CountWindow,
Expand Down Expand Up @@ -54,11 +55,13 @@ def __init__(
name: Optional[str],
dataframe: "StreamingDataFrame",
on_late: Optional[WindowOnLateCallback] = None,
on_update: Optional[WindowOnUpdateCallback] = None,
) -> None:
super().__init__()

self._name = name
self._on_late = on_late
self._on_update = on_update
self._dataframe = dataframe

@abstractmethod
Expand Down Expand Up @@ -239,6 +242,7 @@ def __init__(
name: Optional[str] = None,
step_ms: Optional[int] = None,
on_late: Optional[WindowOnLateCallback] = None,
on_update: Optional[WindowOnUpdateCallback] = None,
):
if not isinstance(duration_ms, int):
raise TypeError("Window size must be an integer")
Expand All @@ -253,7 +257,7 @@ def __init__(
f"got {step_ms}ms"
)

super().__init__(name, dataframe, on_late)
super().__init__(name, dataframe, on_late, on_update)

self._duration_ms = duration_ms
self._grace_ms = grace_ms
Expand Down Expand Up @@ -281,6 +285,7 @@ def __init__(
dataframe: "StreamingDataFrame",
name: Optional[str] = None,
on_late: Optional[WindowOnLateCallback] = None,
on_update: Optional[WindowOnUpdateCallback] = None,
):
super().__init__(
duration_ms=duration_ms,
Expand All @@ -289,6 +294,7 @@ def __init__(
name=name,
step_ms=step_ms,
on_late=on_late,
on_update=on_update,
)

def _get_name(self, func_name: Optional[str]) -> str:
Expand Down Expand Up @@ -320,6 +326,7 @@ def _create_window(
aggregators=aggregators or {},
collectors=collectors or {},
on_late=self._on_late,
on_update=self._on_update,
)


Expand All @@ -331,6 +338,7 @@ def __init__(
dataframe: "StreamingDataFrame",
name: Optional[str] = None,
on_late: Optional[WindowOnLateCallback] = None,
on_update: Optional[WindowOnUpdateCallback] = None,
):
super().__init__(
duration_ms=duration_ms,
Expand All @@ -339,6 +347,7 @@ def __init__(
name=name,
on_late=on_late,
)
self._on_update = on_update

def _get_name(self, func_name: Optional[str]) -> str:
prefix = f"{self._name}_tumbling_window" if self._name else "tumbling_window"
Expand Down Expand Up @@ -368,6 +377,7 @@ def _create_window(
aggregators=aggregators or {},
collectors=collectors or {},
on_late=self._on_late,
on_update=self._on_update,
)


Expand All @@ -379,13 +389,20 @@ def __init__(
dataframe: "StreamingDataFrame",
name: Optional[str] = None,
on_late: Optional[WindowOnLateCallback] = None,
on_update: Optional[WindowOnUpdateCallback] = None,
):
if on_update is not None:
raise ValueError(
"Sliding windows do not support the 'on_update' trigger callback. "
"Use tumbling or hopping windows instead."
)
super().__init__(
duration_ms=duration_ms,
grace_ms=grace_ms,
dataframe=dataframe,
name=name,
on_late=on_late,
on_update=on_update,
)

def _get_name(self, func_name: Optional[str]) -> str:
Expand Down Expand Up @@ -417,6 +434,7 @@ def _create_window(
aggregators=aggregators or {},
collectors=collectors or {},
on_late=self._on_late,
on_update=self._on_update,
)


Expand Down
48 changes: 41 additions & 7 deletions quixstreams/dataframe/windows/time_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Window,
WindowKeyResult,
WindowOnLateCallback,
WindowOnUpdateCallback,
get_window_ranges,
)

Expand Down Expand Up @@ -46,6 +47,7 @@ def __init__(
dataframe: "StreamingDataFrame",
step_ms: Optional[int] = None,
on_late: Optional[WindowOnLateCallback] = None,
on_update: Optional[WindowOnUpdateCallback] = None,
):
super().__init__(
name=name,
Expand All @@ -56,6 +58,7 @@ def __init__(
self._grace_ms = grace_ms
self._step_ms = step_ms
self._on_late = on_late
self._on_update = on_update

self._closing_strategy = ClosingStrategy.KEY

Expand Down Expand Up @@ -132,6 +135,7 @@ def process_window(
state = transaction.as_state(prefix=key)
duration_ms = self._duration_ms
grace_ms = self._grace_ms
on_update = self._on_update

collect = self.collect
aggregate = self.aggregate
Expand All @@ -152,6 +156,7 @@ def process_window(
max_expired_window_end = latest_timestamp - grace_ms
max_expired_window_start = max_expired_window_end - duration_ms
updated_windows: list[WindowKeyResult] = []
triggered_windows: list[WindowKeyResult] = []
for start, end in ranges:
if start <= max_expired_window_start:
late_by_ms = max_expired_window_end - timestamp_ms
Expand All @@ -169,18 +174,44 @@ def process_window(
# since actual values are stored separately and combined into an array
# during window expiration.
aggregated = None

if aggregate:
current_value = state.get_window(start, end)
if current_value is None:
current_value = self._initialize_value()

aggregated = self._aggregate_value(current_value, value, timestamp_ms)
updated_windows.append(
(
key,
self._results(aggregated, [], start, end),
)
)

if on_update and on_update(current_value, aggregated):
# Get collected values for the result
collected = []
if collect:
collected = state.get_from_collection(start, end)
# Add the current value that's being collected
collected.append(self._collect_value(value))

result = self._results(aggregated, collected, start, end)
triggered_windows.append((key, result))
transaction.delete_window(start, end, prefix=key)
# Note: We don't delete from collection here - normal expiration
# will handle cleanup for both tumbling and hopping windows
continue

result = self._results(aggregated, [], start, end)
updated_windows.append((key, result))
elif collect and on_update:
# For collect-only windows, get the old and new collected values
old_collected = state.get_from_collection(start, end)
new_collected = [*old_collected, self._collect_value(value)]

if on_update(old_collected, new_collected):
result = self._results(None, new_collected, start, end)
triggered_windows.append((key, result))
transaction.delete_window(start, end, prefix=key)
# Note: We don't delete from collection here - normal expiration
# will handle cleanup for both tumbling and hopping windows
continue

state.update_window(start, end, value=aggregated, timestamp_ms=timestamp_ms)

if collect:
Expand All @@ -198,7 +229,10 @@ def process_window(
key, state, max_expired_window_start, collect
)

return updated_windows, expired_windows
# Combine triggered windows with time-expired windows
all_expired_windows = triggered_windows + list(expired_windows)

return updated_windows, iter(all_expired_windows)

def expire_by_partition(
self,
Expand Down
10 changes: 10 additions & 0 deletions quixstreams/state/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,16 @@ def expire_all_windows(
"""
...

def delete_window(self, start_ms: int, end_ms: int, prefix: bytes) -> None:
"""
Delete a single window defined by start and end timestamps.

This targets exactly one window identified by its `(start_ms, end_ms)` pair
under the given key prefix, as opposed to `delete_windows`, which takes
a `max_start_time` bound.

NOTE(review): whether values collected for the window are removed as well
is not specified by this declaration - confirm against the implementations.

:param start_ms: start of the window in milliseconds
:param end_ms: end of the window in milliseconds
:param prefix: a key prefix
"""
...

def delete_windows(
self, max_start_time: int, delete_values: bool, prefix: bytes
) -> None:
Expand Down
Loading