Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 68 additions & 8 deletions streamflow/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
import os
import posixpath
import shlex
import sys
import uuid
from collections.abc import Iterable, MutableMapping, MutableSequence
from pathlib import PurePosixPath
from typing import TYPE_CHECKING, Any

import mslex

from streamflow.core.exception import ProcessorTypeError, WorkflowExecutionException
from streamflow.core.persistence import PersistableEntity
from streamflow.log_handler import logger
Expand Down Expand Up @@ -126,6 +129,27 @@ def create_command(
)


if sys.platform == "win32":

def create_shell_command(command: MutableSequence[str], local: bool) -> list[str]:
cmd = " ".join(command)
return (
["cmd", "/C", quote(value=cmd, local=local)]
if local
else ["sh", "-c", shlex.quote(cmd)]
)

else:

def create_shell_command(command: MutableSequence[str], local: bool) -> list[str]:
cmd = " ".join(command)
return (
[os.environ.get("SHELL", "sh"), "-c", quote(value=cmd, local=local)]
if local
else ["sh", "-c", shlex.quote(cmd)]
)


def get_job_step_name(job_name: str) -> str:
return PurePosixPath(job_name).parent.as_posix()

Expand Down Expand Up @@ -211,7 +235,7 @@ async def get_local_to_remote_destination(
) -> str:
is_dst_dir, status = await dst_connector.run(
location=dst_location,
command=[f'test -d "{dst}"'],
command=["test", "-d", shlex.quote(dst)],
capture_output=True,
)
if status > 1:
Expand Down Expand Up @@ -254,21 +278,21 @@ async def get_remote_to_remote_write_command(
) -> MutableSequence[str]:
is_dst_dir, status = await dst_connector.run(
location=dst_locations[0],
command=[f'test -d "{dst}"'],
command=["test", "-d", shlex.quote(dst)],
capture_output=True,
)
if status > 1:
raise WorkflowExecutionException(is_dst_dir)
# If destination path exists and is a directory
elif status == 0:
return ["tar", "xf", "-", "-C", dst]
return ["tar", "xf", "-", "-C", shlex.quote(dst)]
# Otherwise, if destination path does not exist
else:
# If basename must be renamed during transfer
if posixpath.basename(src) != posixpath.basename(dst):
is_src_dir, status = await src_connector.run(
location=src_location,
command=[f'test -d "{src}"'],
command=["test", "-d", shlex.quote(src)],
capture_output=True,
)
if status > 1:
Expand All @@ -279,19 +303,44 @@ async def get_remote_to_remote_write_command(
*(
asyncio.create_task(
dst_connector.run(
location=dst_location, command=["mkdir", "-p", dst]
location=dst_location,
command=["mkdir", "-p", shlex.quote(dst)],
)
)
for dst_location in dst_locations
)
)
return ["tar", "xf", "-", "-C", dst, "--strip-components", "1"]
return [
"tar",
"xf",
"-",
"-C",
shlex.quote(dst),
"--strip-components",
"1",
]
# Otherwise, if source path is a file
else:
return ["tar", "xf", "-", "-O", "|", "tee", dst, ">", "/dev/null"]
return [
"tar",
"xf",
"-",
"-O",
"|",
"tee",
shlex.quote(dst),
">",
"/dev/null",
]
# Otherwise, if basename must be preserved
else:
return ["tar", "xf", "-", "-C", posixpath.dirname(dst)]
return [
"tar",
"xf",
"-",
"-C",
shlex.quote(posixpath.dirname(dst)),
]


def get_tag(tokens: Iterable[Token]) -> str:
Expand Down Expand Up @@ -363,5 +412,16 @@ async def run_in_subprocess(
return None


if sys.platform == "win32":

def quote(value: str, local: bool) -> str:
return mslex.quote(value) if local else shlex.quote(value)

else:

def quote(value: str, local: bool) -> str:
return shlex.quote(value)


def wrap_command(command: str) -> list[str]:
return ["/bin/sh", "-c", f"{command}"]
67 changes: 30 additions & 37 deletions streamflow/cwl/command.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from __future__ import annotations

import asyncio
import base64
import json
import logging
import posixpath
import shlex
import time
from asyncio.subprocess import STDOUT
from collections.abc import MutableMapping, MutableSequence
Expand All @@ -30,7 +28,7 @@
MapCommandOutputProcessor,
UnionCommandOutputProcessor,
)
from streamflow.core.utils import flatten_list
from streamflow.core.utils import create_shell_command, flatten_list, quote
from streamflow.core.workflow import (
Command,
CommandOptions,
Expand Down Expand Up @@ -226,11 +224,11 @@ def _build_command_output_processor(
)


def _escape_value(value: Any) -> Any:
def _escape_value(value: Any, local: bool) -> Any:
if isinstance(value, MutableSequence):
return [_escape_value(v) for v in value]
return [_escape_value(value=v, local=local) for v in value]
else:
return shlex.quote(_get_value_repr(value))
return quote(value=_get_value_repr(value), local=local)


async def _get_source_location(
Expand Down Expand Up @@ -710,17 +708,23 @@ async def _load(
)

def _get_executable_command(
self, context: MutableMapping[str, Any], inputs: MutableMapping[str, Token]
self,
context: MutableMapping[str, Any],
inputs: MutableMapping[str, Token],
local: bool,
) -> MutableSequence[str]:
command = []
options = CWLCommandOptions(
context=context,
expression_lib=self.expression_lib,
full_js=self.full_js,
local=local,
)
# Process baseCommand
if self.base_command:
command.append(shlex.join(self.base_command))
command = (
[quote(cmd, local=options.local) for cmd in self.base_command]
if self.base_command
else []
)
# Process tokens
bindings = ListCommandToken(name=None, position=None, value=[])
for processor in self.processors:
Expand Down Expand Up @@ -807,8 +811,14 @@ async def execute(self, job: Job) -> CWLCommandOutput:
)
else:
inputs = job.inputs
# Get execution target
connector = self.step.workflow.context.scheduler.get_connector(job.name)
locations = self.step.workflow.context.scheduler.get_locations(job.name)
local = all(loc.local for loc in locations)
# Build command string
cmd = self._get_executable_command(context, inputs)
cmd = self._get_executable_command(context=context, inputs=inputs, local=local)
if self.is_shell_command:
cmd = create_shell_command(cmd, local=local)
# Build environment variables
parsed_env = {
k: str(
Expand All @@ -825,24 +835,14 @@ async def execute(self, job: Job) -> CWLCommandOutput:
parsed_env["HOME"] = job.output_directory
if "TMPDIR" not in parsed_env:
parsed_env["TMPDIR"] = job.tmp_directory
# Get execution target
connector = self.step.workflow.context.scheduler.get_connector(job.name)
locations = self.step.workflow.context.scheduler.get_locations(job.name)
cmd_string = " \\\n\t".join(
["/bin/sh", "-c", '"{cmd}"'.format(cmd=" ".join(cmd))]
if self.is_shell_command
else cmd
)
# Log and persist command
cmd_string = " \\\n\t".join(cmd)
if logger.isEnabledFor(logging.INFO):
logger.info(
"EXECUTING step {step} (job {job}) {location} into directory {outdir}:\n{command}".format(
step=self.step.name,
job=job.name,
location=(
"locally"
if locations[0].local
else f"on location {locations[0]}"
),
location=("locally" if local else f"on location {locations[0]}"),
outdir=job.output_directory,
command=cmd_string,
)
Expand All @@ -856,17 +856,6 @@ async def execute(self, job: Job) -> CWLCommandOutput:
job_token_id=job_token.persistent_id,
cmd=cmd_string,
)
# Escape shell command when needed
if self.is_shell_command:
cmd = [
"/bin/sh",
"-c",
'"$(echo {command} | base64 -d)"'.format(
command=base64.b64encode(" ".join(cmd).encode("utf-8")).decode(
"utf-8"
)
),
]
# If step is assigned to multiple locations, add the STREAMFLOW_HOSTS environment variable
if len(locations) > 1 and (
hostnames := [loc.hostname for loc in locations if loc.hostname is not None]
Expand Down Expand Up @@ -979,17 +968,19 @@ async def execute(self, job: Job) -> CWLCommandOutput:


class CWLCommandOptions(CommandOptions):
__slots__ = ("context", "expression_lib", "full_js")
__slots__ = ("context", "expression_lib", "full_js", "local")

def __init__(
self,
context: MutableMapping[str, Any],
expression_lib: MutableSequence[str] | None = None,
full_js: bool = False,
local: bool = False,
):
self.context: MutableMapping[str, Any] = context
self.expression_lib: MutableSequence[str] | None = expression_lib
self.full_js: bool = full_js
self.local: bool = local


class CWLCommandTokenProcessor(CommandTokenProcessor):
Expand Down Expand Up @@ -1075,7 +1066,7 @@ def bind(
value = [value]
# Process shell escape only on the single command token
if not self.is_shell_command or self.shell_quote:
value = [_escape_value(v) for v in value]
value = [_escape_value(value=v, local=options.local) for v in value]
# Obtain token position
if isinstance(self.position, str) and not self.position.isnumeric():
position = utils.eval_expression(
Expand Down Expand Up @@ -1219,6 +1210,7 @@ def _update_options(
| {"inputs": {self.name: get_token_value(token)}},
expression_lib=options.expression_lib,
full_js=options.full_js,
local=options.local,
)


Expand All @@ -1236,6 +1228,7 @@ def _update_options(
| {"inputs": {self.name: value}, "self": value},
expression_lib=options.expression_lib,
full_js=options.full_js,
local=options.local,
)


Expand Down
2 changes: 1 addition & 1 deletion streamflow/cwl/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ async def _process_command_output(
if self.target
else job.tmp_directory
),
path=cast(str, path),
path=path,
)
)
for path in globpaths
Expand Down
Loading
Loading