From 52f1641388b3990e7643b765c8ff625d50633e5a Mon Sep 17 00:00:00 2001 From: GlassOfWhiskey Date: Sat, 21 Mar 2026 13:06:20 +0100 Subject: [PATCH] Improve `StreamFlowExecutor` termination protocol This commit improves the termination protocol of the `StreamFlowExecutor` class, adopting the same behaviour used in `BaseShell`. This new protocol should guarantee that all tasks are correctly terminated before closing the `StreamFlowContext`, allowing cleaner terminations of the StreamFlow application. --- cwl-conformance-test.sh | 5 ++- streamflow/core/workflow.py | 6 +++ streamflow/workflow/executor.py | 73 ++++++++++++++++++++++----------- 3 files changed, 57 insertions(+), 27 deletions(-) diff --git a/cwl-conformance-test.sh b/cwl-conformance-test.sh index 387b1f7b5..a7f9d7ceb 100755 --- a/cwl-conformance-test.sh +++ b/cwl-conformance-test.sh @@ -20,14 +20,15 @@ venv() { if ! test -d "$1" ; then if command -v uv > /dev/null; then uv venv "$1" - uv sync --locked --no-dev || exit 1 + source "$1/bin/activate" + uv sync --active --locked --no-dev || exit 1 elif command -v virtualenv > /dev/null; then virtualenv -p python3 "$1" else python3 -m venv "$1" fi fi - source "$1"/bin/activate + source "$1/bin/activate" } # Version of the standard to test against diff --git a/streamflow/core/workflow.py b/streamflow/core/workflow.py index 00e6f6fe1..80ab84d68 100644 --- a/streamflow/core/workflow.py +++ b/streamflow/core/workflow.py @@ -140,6 +140,12 @@ class Executor(ABC): def __init__(self, workflow: Workflow): self.workflow: Workflow = workflow + @abstractmethod + async def close(self) -> None: ... + + @abstractmethod + async def closed(self) -> bool: ... + @abstractmethod async def run(self) -> MutableMapping[str, Any]: ... diff --git a/streamflow/workflow/executor.py b/streamflow/workflow/executor.py index e89908337..7b70e626e 100644 --- a/streamflow/workflow/executor.py +++ b/streamflow/workflow/executor.py @@ -2,7 +2,7 @@ import asyncio import time -from collections.abc import MutableMapping, MutableSequence +from collections.abc import Iterable, MutableMapping, MutableSequence from typing import TYPE_CHECKING, overload from streamflow.core import utils @@ -24,7 +24,8 @@ def __init__(self, workflow: Workflow): self.executions: MutableSequence[asyncio.Task[None]] = [] self.output_tasks: MutableMapping[str, asyncio.Task[Token]] = {} self.received: MutableSequence[str] = [] - self.closed: bool = False + self._closed: bool = False + self._closing: asyncio.Event | None = None if TYPE_CHECKING: @@ -33,6 +34,19 @@ async def _handle_exception(self, task: asyncio.Task[Token]) -> Token: ... @overload async def _handle_exception(self, task: asyncio.Task[None]) -> None: ... + async def _cancel(self, tasks: Iterable[asyncio.Task]) -> None: + if self._closed: + return + if self._closing is not None: + await self._closing.wait() + else: + # Cancel all tasks + for task in tasks: + task.cancel() + await asyncio.gather(*tasks) + # Mark the executor as closed + self._closed = True + async def _handle_exception(self, task: asyncio.Task[Token | None]) -> Token | None: try: return await task @@ -40,22 +54,9 @@ async def _handle_exception(self, task: asyncio.Task[Token | None]) -> Token | N pass except Exception as exc: logger.exception(exc) - if not self.closed: - await self._shutdown() + await self.close() return None - async def _shutdown(self) -> None: - # Terminate all steps - await asyncio.gather( - *( - asyncio.create_task(step.terminate(Status.CANCELLED)) - for step in self.workflow.steps.values() - if not step.terminated - ) - ) - # Mark the executor as closed - self.closed = True - async def _wait_outputs( self, output_consumer: str, output_tokens: MutableMapping[str, Any] ) -> MutableMapping[str, Any]: @@ -73,15 +74,13 @@ async def _wait_outputs( # If a TerminationToken is received, the corresponding port terminated its outputs if isinstance(token, TerminationToken): if token.value in (Status.CANCELLED, Status.FAILED): - self.closed = True - for t in unfinished: - t.cancel() + await self._cancel(unfinished) return output_tokens else: self.received.append(task_name) # When the last port terminates, the entire executor terminates if len(self.received) == len(self.workflow.output_ports): - self.closed = True + await self.close() else: # Collect result output_tokens[task_name] = get_token_value(token) @@ -106,10 +105,35 @@ async def _wait_outputs( ), name=port_name, ) - self.closed = False + # Reopen the executor if it was closed + if await self.closed(): + self._closing = None + self._closed = False # Return output tokens return output_tokens + async def close(self) -> None: + if self._closed: + return + if self._closing is not None: + await self._closing.wait() + else: + # Terminate all steps + await asyncio.gather( + *( + asyncio.create_task(step.terminate(Status.CANCELLED)) + for step in self.workflow.steps.values() + if not step.terminated + ) + ) + # Mark the executor as closed + self._closed = True + + async def closed(self) -> bool: + if self._closing is not None: + await self._closing.wait() + return self._closed + async def run(self) -> MutableMapping[str, Any]: try: output_tokens = {} @@ -138,7 +162,7 @@ async def run(self) -> MutableMapping[str, Any]: ), name=port_name, ) - while not self.closed: + while not await self.closed(): output_tokens = await self._wait_outputs( output_consumer, output_tokens ) @@ -156,12 +180,11 @@ async def run(self) -> MutableMapping[str, Any]: ) # Print output tokens return output_tokens - except Exception: + except BaseException: if self.workflow.persistent_id: await self.workflow.context.database.update_workflow( self.workflow.persistent_id, {"status": Status.FAILED.value, "end_time": time.time_ns()}, ) - if not self.closed: - await self._shutdown() + await self.close() raise