diff --git a/quixstreams/dataframe/dataframe.py b/quixstreams/dataframe/dataframe.py index 53e90c767..7b84a9b12 100644 --- a/quixstreams/dataframe/dataframe.py +++ b/quixstreams/dataframe/dataframe.py @@ -72,7 +72,7 @@ TumblingCountWindowDefinition, TumblingTimeWindowDefinition, ) -from .windows.base import WindowOnLateCallback +from .windows.base import WindowOnLateCallback, WindowOnUpdateCallback if typing.TYPE_CHECKING: from quixstreams.processing import ProcessingContext @@ -1085,6 +1085,7 @@ def tumbling_window( grace_ms: Union[int, timedelta] = 0, name: Optional[str] = None, on_late: Optional[WindowOnLateCallback] = None, + on_update: Optional[WindowOnUpdateCallback] = None, ) -> TumblingTimeWindowDefinition: """ Create a time-based tumbling window transformation on this StreamingDataFrame. @@ -1151,6 +1152,10 @@ def tumbling_window( (default behavior). Otherwise, no message will be logged. + :param on_update: an optional callback to react on updated windows and + to expire them sooner. If the callback returns `True`, the window will be expired. + Default - `None`. + :return: `TumblingTimeWindowDefinition` instance representing the tumbling window configuration. This object can be further configured with aggregation functions @@ -1166,6 +1171,7 @@ def tumbling_window( dataframe=self, name=name, on_late=on_late, + on_update=on_update, ) def tumbling_count_window( @@ -1225,6 +1231,7 @@ def hopping_window( grace_ms: Union[int, timedelta] = 0, name: Optional[str] = None, on_late: Optional[WindowOnLateCallback] = None, + on_update: Optional[WindowOnUpdateCallback] = None, ) -> HoppingTimeWindowDefinition: """ Create a time-based hopping window transformation on this StreamingDataFrame. @@ -1302,6 +1309,10 @@ def hopping_window( (default behavior). Otherwise, no message will be logged. + :param on_update: an optional callback to react on updated windows and + to expire them sooner. If the callback returns `True`, the window will be expired. + Default - `None`. + :return: `HoppingTimeWindowDefinition` instance representing the hopping window configuration. This object can be further configured with aggregation functions @@ -1319,6 +1330,7 @@ def hopping_window( dataframe=self, name=name, on_late=on_late, + on_update=on_update, ) def hopping_count_window( diff --git a/quixstreams/dataframe/windows/base.py b/quixstreams/dataframe/windows/base.py index 8040b2774..8c261fd14 100644 --- a/quixstreams/dataframe/windows/base.py +++ b/quixstreams/dataframe/windows/base.py @@ -34,6 +34,7 @@ WindowResult: TypeAlias = dict[str, Any] WindowKeyResult: TypeAlias = tuple[Any, WindowResult] Message: TypeAlias = tuple[WindowResult, Any, int, Any] +WindowOnUpdateCallback: TypeAlias = Callable[[Any, Any], bool] WindowAggregateFunc = Callable[[Any, Any], Any] diff --git a/quixstreams/dataframe/windows/definitions.py b/quixstreams/dataframe/windows/definitions.py index 90d4d815b..e20a65cab 100644 --- a/quixstreams/dataframe/windows/definitions.py +++ b/quixstreams/dataframe/windows/definitions.py @@ -16,6 +16,7 @@ from .base import ( Window, WindowOnLateCallback, + WindowOnUpdateCallback, ) from .count_based import ( CountWindow, @@ -54,11 +55,13 @@ def __init__( name: Optional[str], dataframe: "StreamingDataFrame", on_late: Optional[WindowOnLateCallback] = None, + on_update: Optional[WindowOnUpdateCallback] = None, ) -> None: super().__init__() self._name = name self._on_late = on_late + self._on_update = on_update self._dataframe = dataframe @abstractmethod @@ -239,6 +242,7 @@ def __init__( name: Optional[str] = None, step_ms: Optional[int] = None, on_late: Optional[WindowOnLateCallback] = None, + on_update: Optional[WindowOnUpdateCallback] = None, ): if not isinstance(duration_ms, int): raise TypeError("Window size must be an integer") @@ -253,7 +257,7 @@ def __init__( f"got {step_ms}ms" ) - super().__init__(name, dataframe, on_late) + super().__init__(name, dataframe, on_late, on_update) self._duration_ms = duration_ms self._grace_ms = grace_ms @@ -281,6 +285,7 @@ def __init__( dataframe: "StreamingDataFrame", name: Optional[str] = None, on_late: Optional[WindowOnLateCallback] = None, + on_update: Optional[WindowOnUpdateCallback] = None, ): super().__init__( duration_ms=duration_ms, @@ -289,6 +294,7 @@ def __init__( name=name, step_ms=step_ms, on_late=on_late, + on_update=on_update, ) def _get_name(self, func_name: Optional[str]) -> str: @@ -320,6 +326,7 @@ def _create_window( aggregators=aggregators or {}, collectors=collectors or {}, on_late=self._on_late, + on_update=self._on_update, ) @@ -331,6 +338,7 @@ def __init__( dataframe: "StreamingDataFrame", name: Optional[str] = None, on_late: Optional[WindowOnLateCallback] = None, + on_update: Optional[WindowOnUpdateCallback] = None, ): super().__init__( duration_ms=duration_ms, @@ -339,6 +347,7 @@ def __init__( name=name, on_late=on_late, ) + self._on_update = on_update def _get_name(self, func_name: Optional[str]) -> str: prefix = f"{self._name}_tumbling_window" if self._name else "tumbling_window" @@ -368,6 +377,7 @@ def _create_window( aggregators=aggregators or {}, collectors=collectors or {}, on_late=self._on_late, + on_update=self._on_update, ) @@ -379,13 +389,20 @@ def __init__( dataframe: "StreamingDataFrame", name: Optional[str] = None, on_late: Optional[WindowOnLateCallback] = None, + on_update: Optional[WindowOnUpdateCallback] = None, ): + if on_update is not None: + raise ValueError( + "Sliding windows do not support the 'on_update' trigger callback. " + "Use tumbling or hopping windows instead." + ) super().__init__( duration_ms=duration_ms, grace_ms=grace_ms, dataframe=dataframe, name=name, on_late=on_late, + on_update=on_update, ) def _get_name(self, func_name: Optional[str]) -> str: @@ -417,6 +434,7 @@ def _create_window( aggregators=aggregators or {}, collectors=collectors or {}, on_late=self._on_late, + on_update=self._on_update, ) diff --git a/quixstreams/dataframe/windows/time_based.py b/quixstreams/dataframe/windows/time_based.py index c403cfdfa..38c209b57 100644 --- a/quixstreams/dataframe/windows/time_based.py +++ b/quixstreams/dataframe/windows/time_based.py @@ -11,6 +11,7 @@ Window, WindowKeyResult, WindowOnLateCallback, + WindowOnUpdateCallback, get_window_ranges, ) @@ -46,6 +47,7 @@ def __init__( dataframe: "StreamingDataFrame", step_ms: Optional[int] = None, on_late: Optional[WindowOnLateCallback] = None, + on_update: Optional[WindowOnUpdateCallback] = None, ): super().__init__( name=name, @@ -56,6 +58,7 @@ def __init__( self._grace_ms = grace_ms self._step_ms = step_ms self._on_late = on_late + self._on_update = on_update self._closing_strategy = ClosingStrategy.KEY @@ -132,6 +135,7 @@ def process_window( state = transaction.as_state(prefix=key) duration_ms = self._duration_ms grace_ms = self._grace_ms + on_update = self._on_update collect = self.collect aggregate = self.aggregate @@ -152,6 +156,7 @@ def process_window( max_expired_window_end = latest_timestamp - grace_ms max_expired_window_start = max_expired_window_end - duration_ms updated_windows: list[WindowKeyResult] = [] + triggered_windows: list[WindowKeyResult] = [] for start, end in ranges: if start <= max_expired_window_start: late_by_ms = max_expired_window_end - timestamp_ms @@ -169,18 +174,44 @@ def process_window( # since actual values are stored separately and combined into an array # during window expiration. aggregated = None + if aggregate: current_value = state.get_window(start, end) if current_value is None: current_value = self._initialize_value() aggregated = self._aggregate_value(current_value, value, timestamp_ms) - updated_windows.append( - ( - key, - self._results(aggregated, [], start, end), - ) - ) + + if on_update and on_update(current_value, aggregated): + # Get collected values for the result + collected = [] + if collect: + collected = state.get_from_collection(start, end) + # Add the current value that's being collected + collected.append(self._collect_value(value)) + + result = self._results(aggregated, collected, start, end) + triggered_windows.append((key, result)) + transaction.delete_window(start, end, prefix=key) + # Note: We don't delete from collection here - normal expiration + # will handle cleanup for both tumbling and hopping windows + continue + + result = self._results(aggregated, [], start, end) + updated_windows.append((key, result)) + elif collect and on_update: + # For collect-only windows, get the old and new collected values + old_collected = state.get_from_collection(start, end) + new_collected = [*old_collected, self._collect_value(value)] + + if on_update(old_collected, new_collected): + result = self._results(None, new_collected, start, end) + triggered_windows.append((key, result)) + transaction.delete_window(start, end, prefix=key) + # Note: We don't delete from collection here - normal expiration + # will handle cleanup for both tumbling and hopping windows + continue + state.update_window(start, end, value=aggregated, timestamp_ms=timestamp_ms) if collect: @@ -198,7 +229,10 @@ def process_window( key, state, max_expired_window_start, collect ) - return updated_windows, expired_windows + # Combine triggered windows with time-expired windows + all_expired_windows = triggered_windows + list(expired_windows) + + return updated_windows, iter(all_expired_windows) def expire_by_partition( self, diff --git a/quixstreams/state/types.py b/quixstreams/state/types.py index c80c9e2ad..2764651b5 100644 --- a/quixstreams/state/types.py +++ b/quixstreams/state/types.py @@ -391,6 +391,16 @@ def expire_all_windows( """ ... + def delete_window(self, start_ms: int, end_ms: int, prefix: bytes) -> None: + """ + Delete a single window defined by start and end timestamps. + + :param start_ms: start of the window in milliseconds + :param end_ms: end of the window in milliseconds + :param prefix: a key prefix + """ + ... + def delete_windows( self, max_start_time: int, delete_values: bool, prefix: bytes ) -> None: diff --git a/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py b/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py index 6a0b1fd5f..fa8d7e8fe 100644 --- a/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py +++ b/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py @@ -1,3 +1,5 @@ +import functools + import pytest import quixstreams.dataframe.windows.aggregations as agg @@ -12,13 +14,17 @@ @pytest.fixture() def hopping_window_definition_factory(state_manager, dataframe_factory): def factory( - duration_ms: int, step_ms: int, grace_ms: int = 0 + duration_ms: int, step_ms: int, grace_ms: int = 0, on_update=None ) -> HoppingTimeWindowDefinition: sdf = dataframe_factory( state_manager=state_manager, registry=DataFrameRegistry() ) window_def = HoppingTimeWindowDefinition( - duration_ms=duration_ms, step_ms=step_ms, grace_ms=grace_ms, dataframe=sdf + duration_ms=duration_ms, + step_ms=step_ms, + grace_ms=grace_ms, + dataframe=sdf, + on_update=on_update, ) return window_def @@ -33,6 +39,136 @@ def process(window, value, key, transaction, timestamp_ms): class TestHoppingWindow: + def test_hopping_window_with_trigger( + self, hopping_window_definition_factory, state_manager + ): + # Define a trigger that expires windows when the sum reaches 100 or more + def trigger_on_sum_100(old_value, new_value) -> bool: + return new_value >= 100 + + window_def = hopping_window_definition_factory( + duration_ms=100, step_ms=50, grace_ms=100, on_update=trigger_on_sum_100 + ) + window = window_def.sum() + window.final(closing_strategy="key") + + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + + with store.start_partition_transaction(0) as tx: + _process = functools.partial( + process, window=window, key=key, transaction=tx + ) + + # Step 1: Add value=90 at timestamp 50ms + # Creates windows [0, 100) and [50, 150) with sum 90 each + updated, expired = _process(value=90, timestamp_ms=50) + assert len(updated) == 2 + assert updated[0][1]["value"] == 90 + assert updated[0][1]["start"] == 0 + assert updated[0][1]["end"] == 100 + assert updated[1][1]["value"] == 90 + assert updated[1][1]["start"] == 50 + assert updated[1][1]["end"] == 150 + assert not expired + + # Step 2: Add value=5 at timestamp 110ms + # With grace_ms=100, [0, 100) does NOT expire naturally yet + # [0, 100): stays 90 (timestamp 110 is outside [0, 100), not updated) + # [50, 150): 90 -> 95 (< 100, NOT TRIGGERED) + # [100, 200): newly created with sum 5 + updated, expired = _process(value=5, timestamp_ms=110) + assert len(updated) == 2 + assert updated[0][1]["value"] == 95 + assert updated[0][1]["start"] == 50 + assert updated[0][1]["end"] == 150 + assert updated[1][1]["value"] == 5 + assert updated[1][1]["start"] == 100 + assert updated[1][1]["end"] == 200 + # No windows expired (grace period keeps [0, 100) alive) + assert not expired + + # Step 3: Add value=5 at timestamp 90ms (late message) + # Timestamp 90 belongs to BOTH [0, 100) and [50, 150) + # [0, 100): 90 -> 95 (< 100, NOT TRIGGERED) + # [50, 150): 95 -> 100 (>= 100, TRIGGERED!) + updated, expired = _process(value=5, timestamp_ms=90) + # Only [0, 100) remains in updated (not triggered, 95 < 100) + # Only [50, 150) was triggered (100 >= 100) + assert len(updated) == 1 + assert updated[0][1]["value"] == 95 + assert updated[0][1]["start"] == 0 + assert updated[0][1]["end"] == 100 + assert len(expired) == 1 + assert expired[0][1]["value"] == 100 + assert expired[0][1]["start"] == 50 + assert expired[0][1]["end"] == 150 + + def test_hopping_window_collect_with_trigger( + self, hopping_window_definition_factory, state_manager + ): + """Test that on_update callback works with collect for hopping windows.""" + + # Define a trigger that expires windows when we collect 3 or more items + def trigger_on_count_3(old_value, new_value) -> bool: + return len(new_value) >= 3 + + window_def = hopping_window_definition_factory( + duration_ms=100, step_ms=50, grace_ms=100, on_update=trigger_on_count_3 + ) + window = window_def.collect() + window.final(closing_strategy="key") + + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + + with store.start_partition_transaction(0) as tx: + _process = functools.partial( + process, window=window, key=key, transaction=tx + ) + + # Step 1: Add first value at timestamp 50ms + # Creates windows [0, 100) and [50, 150) with 1 item each + updated, expired = _process(value=1, timestamp_ms=50) + assert not updated # collect doesn't emit on updates + assert not expired + + # Step 2: Add second value at timestamp 60ms + # Both windows now have 2 items + updated, expired = _process(value=2, timestamp_ms=60) + assert not updated + assert not expired + + # Step 3: Add third value at timestamp 70ms + # Both windows now have 3 items - BOTH SHOULD TRIGGER + updated, expired = _process(value=3, timestamp_ms=70) + assert not updated + assert len(expired) == 2 + # Window [0, 100) triggered + assert expired[0][1]["value"] == [1, 2, 3] + assert expired[0][1]["start"] == 0 + assert expired[0][1]["end"] == 100 + # Window [50, 150) triggered + assert expired[1][1]["value"] == [1, 2, 3] + assert expired[1][1]["start"] == 50 + assert expired[1][1]["end"] == 150 + + # Step 4: Add fourth value at timestamp 110ms + # Timestamp 110 belongs to windows [50, 150) and [100, 200) + # Window [50, 150) is "resurrected" because collection values weren't deleted + # (for hopping windows, we don't delete collection on trigger to preserve + # values for overlapping windows) + # Window [50, 150) now has [1, 2, 3, 4] = 4 items - TRIGGERS AGAIN! + # Window [100, 200) has [4] = 1 item - doesn't trigger + updated, expired = _process(value=4, timestamp_ms=110) + assert not updated + assert len(expired) == 1 + assert expired[0][1]["value"] == [1, 2, 3, 4] + assert expired[0][1]["start"] == 50 + assert expired[0][1]["end"] == 150 + @pytest.mark.parametrize( "duration, grace, step, provided_name, func_name, expected_name", [ diff --git a/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py b/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py index 98d9f56c1..94950e1e5 100644 --- a/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py +++ b/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py @@ -11,12 +11,17 @@ @pytest.fixture() def tumbling_window_definition_factory(state_manager, dataframe_factory): - def factory(duration_ms: int, grace_ms: int = 0) -> TumblingTimeWindowDefinition: + def factory( + duration_ms: int, grace_ms: int = 0, on_update=None + ) -> TumblingTimeWindowDefinition: sdf = dataframe_factory( state_manager=state_manager, registry=DataFrameRegistry() ) window_def = TumblingTimeWindowDefinition( - duration_ms=duration_ms, grace_ms=grace_ms, dataframe=sdf + duration_ms=duration_ms, + grace_ms=grace_ms, + dataframe=sdf, + on_update=on_update, ) return window_def @@ -31,6 +36,116 @@ def process(window, value, key, transaction, timestamp_ms): class TestTumblingWindow: + def test_tumbling_window_with_trigger( + self, tumbling_window_definition_factory, state_manager + ): + # Define a trigger that expires the window when the sum increases by 5 or more + def trigger_on_delta_5(old_value, new_value) -> bool: + return (new_value - old_value) >= 5 + + window_def = tumbling_window_definition_factory( + duration_ms=100, grace_ms=0, on_update=trigger_on_delta_5 + ) + window = window_def.sum() + window.final(closing_strategy="key") + + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + + with store.start_partition_transaction(0) as tx: + # Add value=2, sum becomes 2, delta from 0 is 2, should not trigger + updated, expired = process( + window, value=2, key=key, transaction=tx, timestamp_ms=50 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 2 + assert not expired + + # Add value=2, sum becomes 4, delta from 2 is 2, should not trigger + updated, expired = process( + window, value=2, key=key, transaction=tx, timestamp_ms=60 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 4 + assert not expired + + # Add value=5, sum becomes 9, delta from 4 is 5, should trigger (>= 5) + updated, expired = process( + window, value=5, key=key, transaction=tx, timestamp_ms=70 + ) + assert not updated # Window was triggered + assert len(expired) == 1 + assert expired[0][1]["value"] == 9 + assert expired[0][1]["start"] == 0 + assert expired[0][1]["end"] == 100 + + # Next value should start a new window + updated, expired = process( + window, value=3, key=key, transaction=tx, timestamp_ms=80 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 3 + assert not expired + + def test_tumbling_window_collect_with_trigger( + self, tumbling_window_definition_factory, state_manager + ): + """Test that on_update callback works with collect.""" + + # Define a trigger that expires the window when we collect 3 or more items + def trigger_on_count_3(old_value, new_value) -> bool: + # For collect, old_value and new_value are lists + return len(new_value) >= 3 + + window_def = tumbling_window_definition_factory( + duration_ms=100, grace_ms=0, on_update=trigger_on_count_3 + ) + window = window_def.collect() + window.final(closing_strategy="key") + + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + + with store.start_partition_transaction(0) as tx: + # Add first value - should not trigger (count=1) + updated, expired = process( + window, value=1, key=key, transaction=tx, timestamp_ms=50 + ) + assert not updated # collect doesn't emit on updates + assert not expired + + # Add second value - should not trigger (count=2) + updated, expired = process( + window, value=2, key=key, transaction=tx, timestamp_ms=60 + ) + assert not updated + assert not expired + + # Add third value - should trigger (count=3) + updated, expired = process( + window, value=3, key=key, transaction=tx, timestamp_ms=70 + ) + assert not updated + assert len(expired) == 1 + assert expired[0][1]["value"] == [1, 2, 3] + assert expired[0][1]["start"] == 0 + assert expired[0][1]["end"] == 100 + + # Next value at t=80 still belongs to window [0, 100) + # Window is "resurrected" because collection values weren't deleted + # (we let normal expiration handle cleanup for simplicity) + # Window [0, 100) now has [1, 2, 3, 4] = 4 items - TRIGGERS AGAIN + updated, expired = process( + window, value=4, key=key, transaction=tx, timestamp_ms=80 + ) + assert not updated + assert len(expired) == 1 + assert expired[0][1]["value"] == [1, 2, 3, 4] + assert expired[0][1]["start"] == 0 + assert expired[0][1]["end"] == 100 + @pytest.mark.parametrize( "duration, grace, provided_name, func_name, expected_name", [