Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions cwl-conformance-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@ venv() {
if ! test -d "$1" ; then
if command -v uv > /dev/null; then
uv venv "$1"
uv sync --locked --no-dev || exit 1
source "$1/bin/activate"
uv sync --active --locked --no-dev || exit 1
elif command -v virtualenv > /dev/null; then
virtualenv -p python3 "$1"
else
python3 -m venv "$1"
fi
fi
source "$1"/bin/activate
source "$1/bin/activate"
}

# Version of the standard to test against
Expand Down
6 changes: 6 additions & 0 deletions streamflow/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,12 @@ class Executor(ABC):
def __init__(self, workflow: Workflow):
self.workflow: Workflow = workflow

@abstractmethod
async def close(self) -> None: ...

@abstractmethod
async def closed(self) -> bool: ...

@abstractmethod
async def run(self) -> MutableMapping[str, Any]: ...

Expand Down
73 changes: 48 additions & 25 deletions streamflow/workflow/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import asyncio
import time
from collections.abc import MutableMapping, MutableSequence
from collections.abc import Iterable, MutableMapping, MutableSequence
from typing import TYPE_CHECKING, overload

from streamflow.core import utils
Expand All @@ -24,7 +24,8 @@ def __init__(self, workflow: Workflow):
self.executions: MutableSequence[asyncio.Task[None]] = []
self.output_tasks: MutableMapping[str, asyncio.Task[Token]] = {}
self.received: MutableSequence[str] = []
self.closed: bool = False
self._closed: bool = False
self._closing: asyncio.Event | None = None

if TYPE_CHECKING:

Expand All @@ -33,29 +34,29 @@ async def _handle_exception(self, task: asyncio.Task[Token]) -> Token: ...
@overload
async def _handle_exception(self, task: asyncio.Task[None]) -> None: ...

async def _cancel(self, tasks: Iterable[asyncio.Task]) -> None:
if self._closed:
return
if self._closing is not None:
await self._closing.wait()
else:
# Cancel all tasks
for task in tasks:
task.cancel()
await asyncio.gather(*tasks)
# Mark the executor as closed
self._closed = True

async def _handle_exception(self, task: asyncio.Task[Token | None]) -> Token | None:
try:
return await task
except asyncio.CancelledError:
pass
except Exception as exc:
logger.exception(exc)
if not self.closed:
await self._shutdown()
await self.close()
return None

async def _shutdown(self) -> None:
# Terminate all steps
await asyncio.gather(
*(
asyncio.create_task(step.terminate(Status.CANCELLED))
for step in self.workflow.steps.values()
if not step.terminated
)
)
# Mark the executor as closed
self.closed = True

async def _wait_outputs(
self, output_consumer: str, output_tokens: MutableMapping[str, Any]
) -> MutableMapping[str, Any]:
Expand All @@ -73,15 +74,13 @@ async def _wait_outputs(
# If a TerminationToken is received, the corresponding port terminated its outputs
if isinstance(token, TerminationToken):
if token.value in (Status.CANCELLED, Status.FAILED):
self.closed = True
for t in unfinished:
t.cancel()
await self._cancel(unfinished)
return output_tokens
else:
self.received.append(task_name)
# When the last port terminates, the entire executor terminates
if len(self.received) == len(self.workflow.output_ports):
self.closed = True
await self.close()
else:
# Collect result
output_tokens[task_name] = get_token_value(token)
Expand All @@ -106,10 +105,35 @@ async def _wait_outputs(
),
name=port_name,
)
self.closed = False
# Reopen the executor if it was closed
if await self.closed():
self._closing = None
self._closed = False
# Return output tokens
return output_tokens

async def close(self) -> None:
if self._closed:
return
if self._closing is not None:
await self._closing.wait()
else:
# Terminate all steps
await asyncio.gather(
*(
asyncio.create_task(step.terminate(Status.CANCELLED))
for step in self.workflow.steps.values()
if not step.terminated
)
)
# Mark the executor as closed
self._closed = True

async def closed(self) -> bool:
if self._closing is not None:
await self._closing.wait()
return self._closed

async def run(self) -> MutableMapping[str, Any]:
try:
output_tokens = {}
Expand Down Expand Up @@ -138,7 +162,7 @@ async def run(self) -> MutableMapping[str, Any]:
),
name=port_name,
)
while not self.closed:
while not await self.closed():
output_tokens = await self._wait_outputs(
output_consumer, output_tokens
)
Expand All @@ -156,12 +180,11 @@ async def run(self) -> MutableMapping[str, Any]:
)
# Print output tokens
return output_tokens
except Exception:
except BaseException:
if self.workflow.persistent_id:
await self.workflow.context.database.update_workflow(
self.workflow.persistent_id,
{"status": Status.FAILED.value, "end_time": time.time_ns()},
)
if not self.closed:
await self._shutdown()
await self.close()
raise
Loading