Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 48 additions & 6 deletions src/pyproject_runner/_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,18 @@ def from_pyproject(cls, project: PyProject) -> Workspace | None:


class Task:
__slots__ = "cmd", "cwd", "env", "env_file", "executable", "help", "post", "pre"
__slots__ = (
"cmd",
"cwd",
"env",
"env_file",
"executable",
"help",
"ignore_returncode",
"post",
"pre",
"use_shell",
)

@overload
def __init__(self, cmd: str, *, executable: Path) -> None: ...
Expand All @@ -307,23 +318,30 @@ def __init__(self, cmd: str | Sequence[str] | None, *, cwd: str | None = None,
env: str | Mapping[str, str] | None = None,
env_file: str | Sequence[str] | None = None,
help: str | None = None,
use_shell: bool | None = None,
ignore_returncode: bool | None = None,
pre: Sequence[Sequence[str]] | None = None,
post: Sequence[Sequence[str]] | None = None) -> None: ...

def __init__(self, cmd: str | Sequence[str] | None, *, cwd: str | None = None,
env: str | Mapping[str, str] | None = None,
env_file: str | Sequence[str] | None = None,
help: str | None = None, # noqa: A002
use_shell: bool | None = None,
ignore_returncode: bool | None = None,
executable: Path | None = None,
pre: Sequence[Sequence[str]] | None = None,
post: Sequence[Sequence[str]] | None = None) -> None:

if isinstance(cmd, str):
cmd = shlex.split(cmd)
self.cmd: Final = tuple(cmd) if cmd else None
self.cwd: Final = cwd
self.env: Final = env
self.env_file: Final = env_file
self.help: Final = help
self.use_shell: Final = use_shell
self.ignore_returncode: Final = ignore_returncode
self.executable: Final = executable
self.pre: Final = tuple(tuple(i) for i in pre) if pre else None
self.post: Final = tuple(tuple(i) for i in post) if post else None
Expand Down Expand Up @@ -424,6 +442,8 @@ def run(self, project: PyProject, name: str, args: Sequence[str]) -> int:
returncode = self._run_tasks(project, pre_tasks)
if not returncode and self.cmd:
returncode = self._run(project, name, args)
if self.ignore_returncode:
returncode = 0
if post_tasks and not returncode:
returncode = self._run_tasks(project, post_tasks)
return returncode
Expand Down Expand Up @@ -453,13 +473,18 @@ def _run(self, project: PyProject, name: str, args: Sequence[str]) -> int:
# See the warning about explicitly passing executable:
# https://docs.python.org/3/library/subprocess.html#subprocess.Popen
executable = which(path, env["PATH"])
return subprocess.run(args, cwd=cwd, env=env, executable=executable).returncode # noqa: S603
if self.use_shell:
args[0] = str(executable)
command_string = " ".join(args)
return subprocess.run(command_string, cwd=cwd, env=env, shell=True).returncode # noqa: S602
return subprocess.run(args, cwd=cwd, env=env, executable=executable).returncode # noqa: S603


@staticmethod
def _run_tasks(project: PyProject, tasks: Sequence[tuple[str, Task, Sequence[str]]]) -> int:
for name, task, args in tasks:
returncode = task.run(project, name, args)
if returncode:
if returncode and not task.ignore_returncode:
return returncode
return 0

Expand Down Expand Up @@ -515,8 +540,25 @@ def parse(cls, entry: str | Sequence[str] | Mapping[str, Any]) -> Task:
pass
case value:
raise ValueError(f"Invalid 'env-file' value: {value!r}")

match entry.get("use-shell"):
case bool(use_shell) if use_shell:
pass
case None as use_shell:
pass
case value:
raise ValueError(f"Invalid 'use-shell' value: {value!r}")

match entry.get("ignore-returncode"):
case bool(ignore_returncode) if ignore_returncode:
pass
case None as ignore_returncode:
pass
case value:
raise ValueError(f"Invalid 'ignore-returncode' value: {value!r}")

else:
cwd = env = env_file = None
cwd = env = env_file = use_shell = ignore_returncode = None

match entry.get("help"):
case str(help) if help and not help.isspace():
Expand Down Expand Up @@ -551,8 +593,8 @@ def parse(cls, entry: str | Sequence[str] | Mapping[str, Any]) -> Task:
if not (cmd or pre_tasks or post_tasks):
raise ValueError("Task must define at least one of 'cmd', 'pre', or 'post'")

return cls(cmd, cwd=cwd, env=env, env_file=env_file, help=help,
pre=pre_tasks, post=post_tasks)
return cls(cmd, cwd=cwd, env=env, env_file=env_file, help=help, use_shell = use_shell,
ignore_returncode=ignore_returncode, pre=pre_tasks, post=post_tasks)

@staticmethod
def _parse_tasks(tasks: list[Any]) -> list[list[str]] | None:
Expand Down