Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ dependencies = [
"asyncssh==2.22.0",
"bcrypt==5.0.0",
"cachetools==6.2.6",
"cwl-utils==0.40",
"cwl-utils @ git+https://github.com/common-workflow-language/cwl-utils.git@main",
"importlib-metadata==8.7.1",
"Jinja2==3.1.6",
"jsonschema==4.26.0",
Expand Down Expand Up @@ -182,4 +182,4 @@ strict = true

[[tool.mypy.overrides]]
module = "streamflow.cwl.antlr.*"
ignore_errors = true
ignore_errors = true
10 changes: 3 additions & 7 deletions streamflow/core/asyncache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
__all__ = ["cached", "cachedmethod"]

from collections.abc import Callable, MutableMapping
from contextlib import AbstractAsyncContextManager
from contextlib import AbstractAsyncContextManager, suppress
from typing import Any, TypeVar

from cachetools import keys as cache_keys
Expand Down Expand Up @@ -60,10 +60,8 @@ async def wrapper(*args, **kwargs):
f"argument type {type(obj).__name__} uses identity hashing (cache miss risk)."
)
v = await func(*args, **kwargs)
try:
with suppress(ValueError):
cache[k] = v
except ValueError:
pass # value too large
return v

else:
Expand Down Expand Up @@ -131,10 +129,8 @@ async def wrapper(self, *args, **kwargs):
f"argument type {type(obj).__name__} uses identity hashing (cache miss risk)."
)
v = await method(self, *args, **kwargs)
try:
with suppress(ValueError):
c[k] = v
except ValueError:
pass # value too large
return v

else:
Expand Down
8 changes: 3 additions & 5 deletions streamflow/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
from abc import ABC, abstractmethod
from collections.abc import MutableSequence
from enum import Enum
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, AnyStr

from streamflow.core.context import SchemaEntity

if TYPE_CHECKING:
from typing import Any

from streamflow.core.context import StreamFlowContext
from streamflow.core.deployment import ExecutionLocation

Expand Down Expand Up @@ -120,7 +118,7 @@ def __init__(self, stream: Any):
async def close(self) -> None: ...

@abstractmethod
async def read(self, size: int | None = None): ...
async def read(self, size: int | None = None) -> AnyStr: ...

@abstractmethod
async def write(self, data: Any): ...
async def write(self, data: AnyStr): ...
46 changes: 30 additions & 16 deletions streamflow/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from abc import ABC, abstractmethod
from collections.abc import MutableMapping, MutableSequence
from enum import IntEnum
from typing import TYPE_CHECKING, TypeVar, cast
from typing import TYPE_CHECKING, TypeVar, cast, overload

from typing_extensions import Self

Expand Down Expand Up @@ -379,36 +379,40 @@ def add_input_port(self, name: str, port: Port) -> None:
def add_output_port(self, name: str, port: Port) -> None:
self._add_port(name, port, DependencyType.OUTPUT)

def get_input_port(self, name: str | None = None) -> Port | None:
def get_input_port(self, name: str | None = None) -> Port:
if name is None:
if len(self.input_ports) == 1:
return self.workflow.ports.get(next(iter(self.input_ports.values())))
else:
raise WorkflowExecutionException(
f"Cannot retrieve default input port as step {self.name} contains multiple input ports."
)
return (
self.workflow.ports.get(self.input_ports[name])
if name in self.input_ports
else None
)
else:
if name in self.input_ports:
return self.workflow.ports[self.input_ports[name]]
else:
raise WorkflowExecutionException(
f"Cannot retrieve input port {name} from step {self.name}"
)

def get_input_ports(self) -> MutableMapping[str, Port]:
return {k: self.workflow.ports[v] for k, v in self.input_ports.items()}

def get_output_port(self, name: str | None = None) -> Port | None:
def get_output_port(self, name: str | None = None) -> Port:
if name is None:
if len(self.output_ports) == 1:
return self.workflow.ports.get(next(iter(self.output_ports.values())))
else:
raise WorkflowExecutionException(
f"Cannot retrieve default output port as step {self.name} contains multiple output ports."
)
return (
self.workflow.ports.get(self.output_ports[name])
if name in self.output_ports
else None
)
else:
if name in self.output_ports:
return self.workflow.ports[self.output_ports[name]]
else:
raise WorkflowExecutionException(
f"Cannot retrieve output port {name} from step {self.name}"
)

def get_output_ports(self) -> MutableMapping[str, Port]:
return {k: self.workflow.ports[v] for k, v in self.output_ports.items()}
Expand Down Expand Up @@ -527,7 +531,7 @@ async def _load(
) -> Self:
return cls(tag=row["tag"], value=row["value"], recoverable=row["recoverable"])

async def _save_value(self, context: StreamFlowContext):
async def _save_value(self, context: StreamFlowContext) -> Any:
return self.value

async def get_weight(self, context: StreamFlowContext) -> int:
Expand Down Expand Up @@ -561,7 +565,7 @@ def recoverable(self, recoverable: bool) -> None:
)
self._recoverable = recoverable

def retag(self, tag: str) -> Token:
def retag(self, tag: str) -> Self:
return self.__class__(tag=tag, value=self.value, recoverable=self._recoverable)

async def save(
Expand All @@ -580,7 +584,7 @@ async def save(
except TypeError as e:
raise WorkflowExecutionException from e

def update(self, value: Any) -> Token:
def update(self, value: Any) -> Self:
return self.__class__(tag=self.tag, value=value, recoverable=self._recoverable)


Expand Down Expand Up @@ -620,6 +624,16 @@ async def _save_additional_params(
) -> MutableMapping[str, Any]:
return {"config": self.config, "output_ports": self.output_ports}

if TYPE_CHECKING:

@overload
def create_port(self) -> Port: ...

@overload
def create_port(
self, cls: type[P] = ..., name: str | None = ..., **kwargs
) -> P: ...

def create_port(self, cls: type[P] = Port, name: str | None = None, **kwargs) -> P:
if name is None:
name = str(uuid.uuid4())
Expand Down
Loading
Loading