diff --git a/config.example.yml b/config.example.yml index 64aed0b..bb71f4d 100644 --- a/config.example.yml +++ b/config.example.yml @@ -64,13 +64,20 @@ permissions: # - package_selector: # origin: # local: true - # namespace: "qpy" + # namespace: qpy # auto_grant_permissions: # memory: 1 GiB # override_permissions: # cpus: 2 #packages: +environment_variables: + # Environment variables that are passed to every worker + #global: + + # Package- and request-specific environment variables + #packages: + cache: # Maximum cache size #size: 1 GiB diff --git a/docs/qppe-server.yaml b/docs/qppe-server.yaml index 56ce83c..3df4338 100644 --- a/docs/qppe-server.yaml +++ b/docs/qppe-server.yaml @@ -1165,6 +1165,7 @@ components: type: string enum: - PACKAGE_PERMISSION_ERROR + - PACKAGE_ENVIRONMENT_VARIABLES_ERROR - QUEUE_WAITING_TIMEOUT - WORKER_TIMEOUT - OUT_OF_MEMORY @@ -1179,6 +1180,7 @@ components: - SERVER_ERROR description: > * `PACKAGE_PERMISSION_ERROR` - The package requested more permissions than allowed. + * `PACKAGE_ENVIRONMENT_VARIABLES_ERROR` - The package requires unprovided environment variables. * `QUEUE_WAITING_TIMEOUT` - The request has been waiting too long in a job queue. Try again later. * `WORKER_TIMEOUT` - Question package did not answer in a reasonable amount of time. * `OUT_OF_MEMORY` - Question package reached its memory limit. diff --git a/questionpy_common/constants.py b/questionpy_common/constants.py index d19c025..47ce7f2 100644 --- a/questionpy_common/constants.py +++ b/questionpy_common/constants.py @@ -2,9 +2,9 @@ # QuestionPy is free software released under terms of the MIT license. See LICENSE.md. # (c) Technische Universität Berlin, innoCampus import re -from typing import Final +from typing import Annotated, Final -from pydantic import ByteSize +from pydantic import ByteSize, Field # General. KiB: Final[int] = 1024 @@ -26,3 +26,6 @@ FORM_REFERENCE_PATTERN: Final[re.Pattern[str]] = re.compile( r"^([a-zA-Z_][a-zA-Z0-9_]*|\.\.)(\[([a-zA-Z_][a-zA-Z0-9_]*|\.\.)?])*$" ) + +ENVIRONMENT_VARIABLE_REGEX: Final[str] = r"[a-zA-Z_][a-zA-Z0-9_]*" +ENVIRONMENT_VARIABLE = Annotated[str, Field(pattern=f"^{ENVIRONMENT_VARIABLE_REGEX}$")] diff --git a/questionpy_common/manifest.py b/questionpy_common/manifest.py index f0d06ab..ccb9a79 100644 --- a/questionpy_common/manifest.py +++ b/questionpy_common/manifest.py @@ -10,6 +10,8 @@ from pydantic import BaseModel, ByteSize, PositiveInt, StringConstraints, conset, field_validator from pydantic.fields import Field +from questionpy_common.constants import ENVIRONMENT_VARIABLE + class PackageType(StrEnum): LIBRARY = "LIBRARY" @@ -111,6 +113,7 @@ class SourceManifest(BaseModel): type: PackageType = DEFAULT_PACKAGETYPE license: str | None = None permissions: PartialPackagePermissions | None = None + environment_variables: set[ENVIRONMENT_VARIABLE] | None = None tags: set[str] = set() requirements: str | list[str] | None = None diff --git a/questionpy_server/models.py b/questionpy_server/models.py index f831414..6746d20 100644 --- a/questionpy_server/models.py +++ b/questionpy_server/models.py @@ -117,6 +117,7 @@ class AttemptScoredResponse(AttemptScoredModel, PackageDependenciesModel): class RequestErrorCode(Enum): PACKAGE_PERMISSION_ERROR = "PACKAGE_PERMISSION_ERROR" + PACKAGE_ENVIRONMENT_VARIABLES_ERROR = "PACKAGE_ENVIRONMENT_VARIABLES_ERROR" QUEUE_WAITING_TIMEOUT = "QUEUE_WAITING_TIMEOUT" WORKER_TIMEOUT = "WORKER_TIMEOUT" OUT_OF_MEMORY = "OUT_OF_MEMORY" diff --git a/questionpy_server/settings.py b/questionpy_server/settings.py index 64005c8..ae4aeba 100644 --- a/questionpy_server/settings.py +++ b/questionpy_server/settings.py @@ -3,15 +3,27 @@ # (c) Technische Universität Berlin, innoCampus import builtins import logging +import os +import re from datetime import timedelta from pathlib import Path from pydoc import locate -from typing import Any, ClassVar, Final, Literal +from typing import Any, ClassVar, Final, Literal, Self import semver import yaml -from pydantic import BaseModel, ByteSize, DirectoryPath, HttpUrl, PositiveInt, conset, field_validator -from pydantic.fields import FieldInfo +from pydantic import ( + BaseModel, + ByteSize, + DirectoryPath, + HttpUrl, + PositiveInt, + RootModel, + conset, + field_validator, + model_validator, +) +from pydantic.fields import Field, FieldInfo from pydantic_settings import ( BaseSettings, EnvSettingsSource, @@ -20,7 +32,7 @@ SettingsConfigDict, ) -from questionpy_common.constants import MAX_PACKAGE_SIZE, GiB, MiB +from questionpy_common.constants import ENVIRONMENT_VARIABLE, ENVIRONMENT_VARIABLE_REGEX, MAX_PACKAGE_SIZE, GiB, MiB from questionpy_common.manifest import PartialPackagePermissions, ensure_is_valid_name from questionpy_server.worker import Worker from questionpy_server.worker.impl.subprocess import SubprocessWorker @@ -155,8 +167,11 @@ def validate_name(cls, value: str) -> str: MainProcessExecutionModeValues = {"container", "trusted"} -class SpecificPackagePermissions(BaseModel): +class Selectable(BaseModel): package_selector: PackageSelector = PackageSelector() + + +class SpecificPackagePermissions(Selectable): auto_grant_permissions: PartialPackagePermissions | None = None override_permissions: PartialPackagePermissions | None = None @@ -195,6 +210,53 @@ class PackagePermissionsSettings(BaseModel): packages: list[SpecificPackagePermissions] = [] +class EnvironmentVariables(RootModel[dict[ENVIRONMENT_VARIABLE, str]]): + interpolation_pattern: ClassVar[re.Pattern] = re.compile(rf"^\$\{{({ENVIRONMENT_VARIABLE_REGEX})}}$") + """Matches environment variable interpolation syntax for replacement. + + Identifies strings in the format ${VAR_NAME} that should be replaced with the + corresponding environment variable value. + + Example: + "${MY_VAR}" matches and becomes the value of the environment variable "MY_VAR" + + """ + + escaped_interpolation_pattern: ClassVar[re.Pattern] = re.compile(rf"^\$(\$+\{{{ENVIRONMENT_VARIABLE_REGEX}}})$") + """Matches escaped interpolation syntax to prevent variable replacement. + + Identifies strings with escaped dollar signs that should be unescaped rather + than treated as environment variable references. + + Example: + "$${MY_VAR}" matches and becomes "${MY_VAR}" (literal string) + """ + + @model_validator(mode="after") + def check_environment_variables(self) -> Self: + for key, value in self.root.items(): + if match := self.interpolation_pattern.match(value): + interpolated_key = match.group(1) + if interpolated_key not in os.environ: + msg = f"Environment variable '{interpolated_key}' not found." + raise ValueError(msg) + self.root[key] = os.environ[interpolated_key] + _log.debug("Interpolated environment variable: %s=%s.", key, self.root[key]) + elif match := self.escaped_interpolation_pattern.match(value): + self.root[key] = match.group(1) + _log.debug("Escaped environment variable: %s=%s.", key, self.root[key]) + return self + + +class SpecificPackageEnvironmentVariables(Selectable): + environment_variables: EnvironmentVariables | None = None + + +class EnvironmentVariablesSettings(BaseModel): + global_: EnvironmentVariables = Field(alias="global", default=EnvironmentVariables({})) + packages: list[SpecificPackageEnvironmentVariables] = [] + + class CacheSettings(BaseModel): size: ByteSize = ByteSize(1 * GiB) directory: DirectoryPath = Path("cache").resolve() @@ -274,6 +336,7 @@ class Settings(BaseSettings): webservice: WebserviceSettings worker_pool: WorkerPoolSettings permissions: PackagePermissionsSettings + environment_variables: EnvironmentVariablesSettings cache: CacheSettings collector: CollectorSettings auth: AuthSettings diff --git a/questionpy_server/web/_routes/_files.py b/questionpy_server/web/_routes/_files.py index 17703fc..62d4c10 100644 --- a/questionpy_server/web/_routes/_files.py +++ b/questionpy_server/web/_routes/_files.py @@ -9,6 +9,7 @@ from questionpy_server.web import CURRENT_USER_KEY from questionpy_server.web._decorators import ensure_package from questionpy_server.web.app import QPyServer +from questionpy_server.worker.selector import SelectorQuery file_routes = web.RouteTableDef() @@ -26,10 +27,14 @@ async def serve_static_file(request: web.Request, package: Package) -> web.Respo raise HTTPNotImplemented(text="Static file retrieval from non-main packages is not supported yet.") current_user = request.get(CURRENT_USER_KEY) - permissions = qpy_server.package_permissions.get_effective_permissions(package, current_user, "files") + selector_query = SelectorQuery(package, current_user, "files") + permissions = qpy_server.package_permissions.get(selector_query) + environment_variables = qpy_server.environment_variables.get(selector_query) location = await package.get_zip_package_location() - async with qpy_server.worker_pool.get_worker(location, current_user, "files", permissions) as worker: + async with qpy_server.worker_pool.get_worker( + location, current_user, "files", permissions, environment_variables + ) as worker: try: file = await worker.get_static_file(path) except FileNotFoundError as e: diff --git a/questionpy_server/web/_worker_context.py b/questionpy_server/web/_worker_context.py index 408fbab..1b31c5c 100644 --- a/questionpy_server/web/_worker_context.py +++ b/questionpy_server/web/_worker_context.py @@ -12,6 +12,7 @@ from questionpy_server.web import CURRENT_USER_KEY from questionpy_server.web.app import QPyServer from questionpy_server.worker import Worker +from questionpy_server.worker.selector import SelectorQuery def get_request_info( @@ -36,14 +37,20 @@ async def worker_context(request: web.Request, package: Package, data: RequestBa """Returns the worker context for the given request.""" qpyserver = request.app[QPyServer.APP_KEY] current_user = request.get(CURRENT_USER_KEY) - permissions = qpyserver.package_permissions.get_effective_permissions(package, current_user, data.context) + + selector_query = SelectorQuery(package, current_user, data.context) + permissions = qpyserver.package_permissions.get(selector_query) + environment_variables = qpyserver.environment_variables.get(selector_query) + location = await package.get_zip_package_location() lms_provided_attributes = None if isinstance(data, LmsProvidedAttributesModel): lms_provided_attributes = data.lms_provided_attributes - async with qpyserver.worker_pool.get_worker(location, current_user, data.context, permissions) as worker: + async with qpyserver.worker_pool.get_worker( + location, current_user, data.context, permissions, environment_variables + ) as worker: yield WorkerContext( worker, get_request_info(request, lms_provided_attributes=lms_provided_attributes), diff --git a/questionpy_server/web/app.py b/questionpy_server/web/app.py index 9b81a81..915d5b4 100644 --- a/questionpy_server/web/app.py +++ b/questionpy_server/web/app.py @@ -14,8 +14,9 @@ from questionpy_server.collector import PackageCollection from questionpy_server.settings import Settings from questionpy_server.web.middlewares import middlewares -from questionpy_server.worker.permissions import PackagePermissionsHandler from questionpy_server.worker.pool import WorkerPool +from questionpy_server.worker.selector.environment_variables import EnvironmentVariablesHandler +from questionpy_server.worker.selector.permissions import PackagePermissionsHandler _log = logging.getLogger(__name__) @@ -36,6 +37,7 @@ def __init__(self, settings: Settings): settings.worker_pool.max_cpus, settings.worker_pool.max_memory, worker_type=settings.worker_pool.type ) self.package_permissions = PackagePermissionsHandler(settings.permissions) + self.environment_variables = EnvironmentVariablesHandler(settings.environment_variables) cache_supervisor = LRUCacheSupervisor(settings.cache.directory, settings.cache.size) self.package_cache = LRUCache(cache_supervisor, Path("packages"), extension=".qpy") diff --git a/questionpy_server/web/errors.py b/questionpy_server/web/errors.py index bf3fbf2..7a8eb62 100644 --- a/questionpy_server/web/errors.py +++ b/questionpy_server/web/errors.py @@ -23,6 +23,18 @@ def __init__(self, msg: str, body: RequestError) -> None: web_logger.info(msg) +class PackageEnvironmentVariablesError(web.HTTPForbidden, _ExceptionMixin): + def __init__(self, *, reason: str | None, temporary: bool) -> None: + super().__init__( + msg="Question package requires environment variables that are not provided by the server", + body=RequestError( + error_code=RequestErrorCode.PACKAGE_ENVIRONMENT_VARIABLES_ERROR, + reason=reason, + temporary=temporary, + ), + ) + + class PackagePermissionError(web.HTTPForbidden, _ExceptionMixin): def __init__(self, *, reason: str | None, temporary: bool) -> None: super().__init__( @@ -156,6 +168,7 @@ def __init__(self, *, reason: str | None, temporary: bool) -> None: QpyWebError = ( PackagePermissionError + | PackageEnvironmentVariablesError | WorkerTimeoutError | OutOfMemoryError | InvalidAttemptStateError diff --git a/questionpy_server/web/middlewares/_error.py b/questionpy_server/web/middlewares/_error.py index 2ef23c2..332ee49 100644 --- a/questionpy_server/web/middlewares/_error.py +++ b/questionpy_server/web/middlewares/_error.py @@ -18,13 +18,15 @@ WorkerRealTimeLimitExceededError, WorkerStartError, ) -from questionpy_server.worker.permissions import PackagePermissionError from questionpy_server.worker.runtime.messages import WorkerMemoryLimitExceededError, WorkerUnknownError +from questionpy_server.worker.selector.environment_variables import PackageEnvironmentVariablesError +from questionpy_server.worker.selector.permissions import PackagePermissionError exception_map: dict[type[QPyBaseError], type[web_error.QpyWebError]] = { InvalidAttemptStateError: web_error.InvalidAttemptStateError, InvalidQuestionStateError: web_error.InvalidQuestionStateError, ManifestError: web_error.InvalidPackageError, + PackageEnvironmentVariablesError: web_error.PackageEnvironmentVariablesError, PackagePermissionError: web_error.PackagePermissionError, StaticFileSizeMismatchError: web_error.InvalidPackageError, WorkerCPUTimeLimitExceededError: web_error.WorkerTimeoutError, diff --git a/questionpy_server/worker/__init__.py b/questionpy_server/worker/__init__.py index db3f4c3..ece7f18 100644 --- a/questionpy_server/worker/__init__.py +++ b/questionpy_server/worker/__init__.py @@ -61,6 +61,8 @@ class WorkerArgs(TypedDict): """An existing directory owned by the worker, with the same lifetime as the worker.""" permissions: PackagePermissions """The package permissions.""" + environment_variables: dict[str, str] + """Environment variables to be set in the worker.""" class Worker(ABC): @@ -72,6 +74,7 @@ def __init__(self, **kwargs: Unpack[WorkerArgs]) -> None: self.package = kwargs["package"] self.worker_home = kwargs["worker_home"] self.permissions = kwargs["permissions"] + self.environment_variables = kwargs["environment_variables"] self.state = WorkerState.NOT_RUNNING self.loaded_packages: list[LoadedPackage] = [] diff --git a/questionpy_server/worker/impl/subprocess.py b/questionpy_server/worker/impl/subprocess.py index 44f48d5..fd67a19 100644 --- a/questionpy_server/worker/impl/subprocess.py +++ b/questionpy_server/worker/impl/subprocess.py @@ -108,7 +108,7 @@ class SubprocessWorker(BaseWorker, LimitTimeUsageMixin): _worker_type = "process" - # Allows to use a patched runtime in tests. + # Allows using a patched runtime in tests. _runtime_main = ["-m", "questionpy_server.worker.runtime"] def __init__(self, **kwargs: Unpack[WorkerArgs]): @@ -117,18 +117,17 @@ def __init__(self, **kwargs: Unpack[WorkerArgs]): self._proc: Process | None = None self._stderr_buffer: _StderrBuffer | None = None + if "OPENBLAS_NUM_THREADS" not in self.environment_variables: + # OpenBLAS is used by NumPy and creates a number of threads on import. + # Each thread allocates a bunch of virtual memory, so more than 2 threads break the default memory limit. + # By default, the number of threads is proportional to the available CPUs. + self.environment_variables["OPENBLAS_NUM_THREADS"] = "2" + async def start(self) -> None: """Start the worker process.""" # Turn off the worker's __debug__ flag unless ours is set as well. python_flags = [] if __debug__ else ["-O"] - env = { - # OpenBLAS is used by NumPy and creates a number of threads on import. - # Each thread allocates a bunch of virtual memory, so more than 2 threads breaks the default memory limit. - # By default, the number of threads is proportional to the available CPUs. - "OPENBLAS_NUM_THREADS": "2" - } - self._proc = await asyncio.create_subprocess_exec( sys.executable, *python_flags, @@ -136,7 +135,7 @@ async def start(self) -> None: stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, - env=env, + env=self.environment_variables, cwd=self.worker_home, start_new_session=True, ) diff --git a/questionpy_server/worker/pool.py b/questionpy_server/worker/pool.py index 81b22af..37ce09b 100644 --- a/questionpy_server/worker/pool.py +++ b/questionpy_server/worker/pool.py @@ -100,7 +100,12 @@ def _memory_available(self, required_memory: int) -> bool: @asynccontextmanager async def get_worker( - self, package: PackageLocation, user: str | None, context: str, permissions: PackagePermissions + self, + package: PackageLocation, + user: str | None, + context: str, + permissions: PackagePermissions, + environment_variables: dict[str, str], ) -> AsyncIterator[Worker]: """Get a (new) worker executing a QuestionPy package. @@ -111,6 +116,7 @@ async def get_worker( user: the user requesting the worker context: context within the lms permissions: package permissions + environment_variables: environment variables to be set in the worker Returns: A worker @@ -132,7 +138,9 @@ async def get_worker( # `Lock.acquire` is not explicitly documented as fair. This ensures that no starvation occurs. async with self._lock, self._condition: await self._condition.wait_for(lambda: self._memory_available(permissions.memory)) - worker = await self._create_or_reuse_worker(package, user, context, permissions) + worker = await self._create_or_reuse_worker( + package, user, context, permissions, environment_variables + ) self._workers_in_use += 1 yield worker @@ -200,7 +208,12 @@ def _generate_worker_name(self, package: PackageLocation) -> str: return f"{package_part}-{index}" async def _create_or_reuse_worker( - self, package: PackageLocation, user: str | None, context: str, permissions: PackagePermissions + self, + package: PackageLocation, + user: str | None, + context: str, + permissions: PackagePermissions, + environment_variables: dict[str, str], ) -> Worker: """If possible, get an idle worker or create a new one.""" # Since the `PackagePermissions` only dependent on the `user` and `context` the worker @@ -223,7 +236,13 @@ async def _create_or_reuse_worker( worker_home = self._working_dir / f"worker-{name}" await asyncio.to_thread(worker_home.mkdir) - worker = self._worker_type(name=name, package=package, permissions=permissions, worker_home=worker_home) + worker = self._worker_type( + name=name, + package=package, + permissions=permissions, + worker_home=worker_home, + environment_variables=environment_variables, + ) await worker.start() # Reserve the memory. diff --git a/questionpy_server/worker/selector/__init__.py b/questionpy_server/worker/selector/__init__.py new file mode 100644 index 0000000..d576caf --- /dev/null +++ b/questionpy_server/worker/selector/__init__.py @@ -0,0 +1,67 @@ +# This file is part of the QuestionPy Server. (https://questionpy.org) +# The QuestionPy Server is free software released under terms of the MIT license. See LICENSE.md. +# (c) Technische Universität Berlin, innoCampus +from abc import ABC, abstractmethod +from typing import NamedTuple + +from questionpy_server.cache import LRUCacheMemory +from questionpy_server.package import Package +from questionpy_server.settings import PackageSelector, Selectable + + +class SelectorQuery(NamedTuple): + package: Package + user: str | None + context: str + + +def _is_wildcard_matching(selector_value: str, package_value: str) -> bool: + return selector_value in {package_value, "*"} + + +def _is_matching(selector: PackageSelector, query: SelectorQuery) -> bool: + return ( + # Package data. + _is_wildcard_matching(selector.hash, query.package.hash) + and _is_wildcard_matching(selector.namespace, query.package.manifest.namespace) + and _is_wildcard_matching(selector.short_name, query.package.manifest.short_name) + and (selector.version == "*" or query.package.manifest.version.match(selector.version)) + # Package origin. + and _is_wildcard_matching(selector.origin.repositories, "*") # TODO: handle repositories + and (selector.origin.local is None or selector.origin.local == query.package.sources.is_local()) + and _is_wildcard_matching(selector.origin.users, "*") # TODO: handle users + # Request data. + and _is_wildcard_matching(selector.request_user, str(query.user) if query.user else "") + and _is_wildcard_matching(selector.request_context, query.context) + ) + + +class Selector[T: Selectable, V](ABC): + """Provides helpful methods for getting package and request specific data.""" + + def __init__(self, selectables: list[T]): + self._cache: LRUCacheMemory[SelectorQuery, V] = LRUCacheMemory(max_size=128) + self._selectables = selectables + + def _get_matching(self, query: SelectorQuery) -> T | None: + """Gets the first matching selectable, if any. + + It assumes that the selectables are ordered from least specific to most specific. + """ + for selectable in reversed(self._selectables): + if _is_matching(selectable.package_selector, query): + return selectable + return None + + def get(self, query: SelectorQuery) -> V: + """Gets the result of the query, which may be cached.""" + if cached := self._cache.get(query): + return cached + + result = self._get(query) + self._cache.put(query, result) + + return result + + @abstractmethod + def _get(self, query: SelectorQuery) -> V: ... diff --git a/questionpy_server/worker/selector/environment_variables.py b/questionpy_server/worker/selector/environment_variables.py new file mode 100644 index 0000000..a9b29b9 --- /dev/null +++ b/questionpy_server/worker/selector/environment_variables.py @@ -0,0 +1,37 @@ +# This file is part of the QuestionPy Server. (https://questionpy.org) +# The QuestionPy Server is free software released under terms of the MIT license. See LICENSE.md. +# (c) Technische Universität Berlin, innoCampus +from questionpy_common.error import QPyBaseError +from questionpy_server.settings import EnvironmentVariablesSettings, SpecificPackageEnvironmentVariables +from questionpy_server.worker.selector import Selector, SelectorQuery + + +class PackageEnvironmentVariablesError(QPyBaseError): + pass + + +class EnvironmentVariablesHandler(Selector[SpecificPackageEnvironmentVariables, dict[str, str]]): + """Handles environment variables for a request.""" + + def __init__(self, settings: EnvironmentVariablesSettings): + super().__init__(settings.packages) + + self._global_environment_variables = settings.global_.root + + def _get(self, query: SelectorQuery) -> dict[str, str]: + environment_variables = self._global_environment_variables.copy() + + if (specific := self._get_matching(query)) and specific.environment_variables: + environment_variables.update(specific.environment_variables.root) + + requested_environment_variables = query.package.manifest.environment_variables + if requested_environment_variables and not requested_environment_variables.issubset( + environment_variables.keys() + ): + msg = ( + f"The package '{query.package.hash}' requested environment variables that are not provided by the " + "server." + ) + raise PackageEnvironmentVariablesError(msg) + + return environment_variables diff --git a/questionpy_server/worker/permissions.py b/questionpy_server/worker/selector/permissions.py similarity index 58% rename from questionpy_server/worker/permissions.py rename to questionpy_server/worker/selector/permissions.py index 8772721..ab7da61 100644 --- a/questionpy_server/worker/permissions.py +++ b/questionpy_server/worker/selector/permissions.py @@ -2,19 +2,17 @@ # The QuestionPy Server is free software released under terms of the MIT license. See LICENSE.md. # (c) Technische Universität Berlin, innoCampus import logging -from typing import NamedTuple from questionpy_common.environment import PackagePermissions as EnvironmentPackagePermissions from questionpy_common.error import QPyBaseError -from questionpy_server.cache import LRUCacheMemory from questionpy_server.package import Package from questionpy_server.settings import ( CompletePackagePermissions, MainProcessExecutionModeValues, PackagePermissionsSettings, - PackageSelector, SpecificPackagePermissions, ) +from questionpy_server.worker.selector import Selector, SelectorQuery _log = logging.getLogger(__name__) @@ -33,45 +31,14 @@ def _has_enough_permissions(allowed: CompletePackagePermissions, requested: Comp ) -def _is_wildcard_matching(selector_value: str, package_value: str) -> bool: - return selector_value in {package_value, "*"} - - -def _is_selector_matching(selector: PackageSelector, package: Package, user: str | None, context: str) -> bool: - return ( - # Package data. - _is_wildcard_matching(selector.hash, package.hash) - and _is_wildcard_matching(selector.namespace, package.manifest.namespace) - and _is_wildcard_matching(selector.short_name, package.manifest.short_name) - and (selector.version == "*" or package.manifest.version.match(selector.version)) - # Package origin. - and _is_wildcard_matching(selector.origin.repositories, "*") # TODO: handle repositories - and (selector.origin.local is None or selector.origin.local == package.sources.is_local()) - and _is_wildcard_matching(selector.origin.users, "*") # TODO: handle users - # Request data. - and _is_wildcard_matching(selector.request_user, str(user) if user else "") - and _is_wildcard_matching(selector.request_context, context) - ) - - -class _PackagePermissionIdentifier(NamedTuple): - package: Package - user: str | None - context: str - - -class PackagePermissionsHandler: +class PackagePermissionsHandler(Selector[SpecificPackagePermissions, EnvironmentPackagePermissions]): """Handles package permissions for a request.""" def __init__(self, settings: PackagePermissionsSettings): - self._default_permissions = CompletePackagePermissions() + super().__init__(settings.packages) + self._default_permissions = CompletePackagePermissions() self._auto_grant_permissions = settings.auto_grant_permissions - self._specific_package_permissions = settings.packages - - self._cache: LRUCacheMemory[_PackagePermissionIdentifier, EnvironmentPackagePermissions] = LRUCacheMemory( - max_size=128 - ) def _get_requested_permissions(self, package: Package) -> CompletePackagePermissions: requested_permissions = package.manifest.permissions @@ -102,31 +69,11 @@ def _get_actual_auto_grant_permissions(self, permissions: SpecificPackagePermiss specific_auto_grant_permissions = permissions.auto_grant_permissions.model_dump(exclude_none=True) return self._auto_grant_permissions.model_copy(update=specific_auto_grant_permissions) - def _get_specific_permissions( - self, package: Package, user: str | None, context: str - ) -> SpecificPackagePermissions | None: - # We want to select the last defined one if multiple selectors match. - for permissions in reversed(self._specific_package_permissions): - if _is_selector_matching(permissions.package_selector, package, user, context): - return permissions - return None - - def get_effective_permissions( - self, package: Package, user: str | None, context: str - ) -> EnvironmentPackagePermissions: - """Gets the effective permissions for a package. - - Raises: - PackagePermissionError: If the package does not have enough permissions. - """ - key = _PackagePermissionIdentifier(package, user, context) - if cached_permissions := self._cache.get(key): - return cached_permissions - + def _get(self, query: SelectorQuery) -> EnvironmentPackagePermissions: auto_grant_permissions = self._auto_grant_permissions - requested_permissions = self._get_requested_permissions(package) + requested_permissions = self._get_requested_permissions(query.package) - if specific_permissions := self._get_specific_permissions(package, user, context): + if specific_permissions := self._get_matching(query): auto_grant_permissions = self._get_actual_auto_grant_permissions(specific_permissions) if specific_permissions.override_permissions: @@ -136,12 +83,10 @@ def get_effective_permissions( requested_permissions = requested_permissions.model_copy(update=overrides) if not _has_enough_permissions(auto_grant_permissions, requested_permissions): - msg = f"The package '{package.hash}' requested more permissions than allowed." + msg = f"The package '{query.package.hash}' requested permissions that are not granted by the server." raise PackagePermissionError(msg) # Only keep explicitly allowed lms attributes. requested_permissions.lms_attributes.intersection_update(auto_grant_permissions.lms_attributes) - effective_permissions = EnvironmentPackagePermissions(**requested_permissions.model_dump()) - self._cache.put(key, effective_permissions) - return effective_permissions + return EnvironmentPackagePermissions(**requested_permissions.model_dump()) diff --git a/tests/conftest.py b/tests/conftest.py index c579cda..40e99ea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,6 +22,7 @@ CacheSettings, CollectorSettings, CompletePackagePermissions, + EnvironmentVariablesSettings, GeneralSettings, PackagePermissionsSettings, Settings, @@ -131,6 +132,7 @@ def qpy_server(tmp_path_factory: pytest.TempPathFactory) -> QPyServer: webservice=WebserviceSettings(listen_address="127.0.0.1", listen_port=0), worker_pool=WorkerPoolSettings(type=ThreadWorker), permissions=PackagePermissionsSettings(), + environment_variables=EnvironmentVariablesSettings(), cache=CacheSettings(directory=tmp_path_factory.mktemp("qpy_cache")), collector=CollectorSettings(), auth=AuthSettings(enabled=False), diff --git a/tests/questionpy_server/test_settings.py b/tests/questionpy_server/test_settings.py index 8d94df0..9cb318a 100644 --- a/tests/questionpy_server/test_settings.py +++ b/tests/questionpy_server/test_settings.py @@ -30,6 +30,7 @@ def path_with_empty_config_file(tmp_path: Path) -> Path: "webservice": None, "worker_pool": None, "permissions": None, + "environment_variables": None, "cache": None, "collector": None, "auth": None, diff --git a/tests/questionpy_server/worker/impl/test_base.py b/tests/questionpy_server/worker/impl/test_base.py index 8846507..529a489 100644 --- a/tests/questionpy_server/worker/impl/test_base.py +++ b/tests/questionpy_server/worker/impl/test_base.py @@ -23,7 +23,7 @@ async def test_should_get_manifest(worker_pool: WorkerPool) -> None: - async with worker_pool.get_worker(PACKAGE, "tester", "tests", DEFAULT_PACKAGE_PERMISSIONS) as worker: + async with worker_pool.get_worker(PACKAGE, "tester", "tests", DEFAULT_PACKAGE_PERMISSIONS, {}) as worker: manifest = await worker.get_manifest() assert manifest == PACKAGE.manifest @@ -41,7 +41,7 @@ async def test_should_get_static_file( package: PackageLocation = dir_package if package_type == "dir" else package_factory.to_zip_package(dir_package) - async with worker_pool.get_worker(package, "tester", "tests", DEFAULT_PACKAGE_PERMISSIONS) as worker: + async with worker_pool.get_worker(package, "tester", "tests", DEFAULT_PACKAGE_PERMISSIONS, {}) as worker: static_file = await worker.get_static_file(_STATIC_FILE_NAME) assert static_file.data == _STATIC_FILE_CONTENT.encode() @@ -58,7 +58,7 @@ async def test_should_raise_file_not_found_error_when_not_in_manifest( package: PackageLocation = dir_package if package_type == "dir" else package_factory.to_zip_package(dir_package) - async with worker_pool.get_worker(package, "tester", "tests", DEFAULT_PACKAGE_PERMISSIONS) as worker: + async with worker_pool.get_worker(package, "tester", "tests", DEFAULT_PACKAGE_PERMISSIONS, {}) as worker: with pytest.raises(FileNotFoundError): await worker.get_static_file(_STATIC_FILE_NAME) @@ -83,7 +83,7 @@ async def test_should_raise_file_not_found_error_when_file_is_outside( else package_factory.to_zip_package(dir_package, include_siblings=("my_secret_file",)) ) - async with worker_pool.get_worker(package, "tester", "tests", DEFAULT_PACKAGE_PERMISSIONS) as worker: + async with worker_pool.get_worker(package, "tester", "tests", DEFAULT_PACKAGE_PERMISSIONS, {}) as worker: with caplog.at_level(logging.INFO), pytest.raises(FileNotFoundError): await worker.get_static_file("static/../../my_secret_file") @@ -103,7 +103,7 @@ async def test_should_raise_file_not_found_error_when_symlink_target_is_outside( package.inject_static_file_into_manifest("static/my_secret_file", len(content), "text/plain") - async with worker_pool.get_worker(package, "tester", "tests", DEFAULT_PACKAGE_PERMISSIONS) as worker: + async with worker_pool.get_worker(package, "tester", "tests", DEFAULT_PACKAGE_PERMISSIONS, {}) as worker: with caplog.at_level(logging.INFO), pytest.raises(FileNotFoundError): await worker.get_static_file("static/my_secret_file") @@ -119,7 +119,7 @@ async def test_should_raise_file_not_found_error_when_not_on_disk( package: PackageLocation = dir_package if package_type == "dir" else package_factory.to_zip_package(dir_package) - async with worker_pool.get_worker(package, "tester", "tests", DEFAULT_PACKAGE_PERMISSIONS) as worker: + async with worker_pool.get_worker(package, "tester", "tests", DEFAULT_PACKAGE_PERMISSIONS, {}) as worker: with pytest.raises(FileNotFoundError): await worker.get_static_file(_STATIC_FILE_NAME) @@ -134,7 +134,7 @@ async def test_should_raise_static_file_size_mismatch_error_when_sizes_dont_matc package: PackageLocation = dir_package if package_type == "dir" else package_factory.to_zip_package(dir_package) - async with worker_pool.get_worker(package, "tester", "tests", DEFAULT_PACKAGE_PERMISSIONS) as worker: + async with worker_pool.get_worker(package, "tester", "tests", DEFAULT_PACKAGE_PERMISSIONS, {}) as worker: with pytest.raises(StaticFileSizeMismatchError): await worker.get_static_file(_STATIC_FILE_NAME) @@ -163,13 +163,13 @@ def _make_get_manifest_raise() -> Iterator[None]: @pytest.mark.filterwarnings("ignore:Exception in thread worker-") async def test_should_gracefully_handle_error_in_bootstrap(worker_pool: WorkerPool) -> None: with patch_worker_pool(worker_pool, _make_bootstrap_raise), pytest.raises(WorkerStartError): - async with worker_pool.get_worker(PACKAGE, "tester", "tests", DEFAULT_PACKAGE_PERMISSIONS): + async with worker_pool.get_worker(PACKAGE, "tester", "tests", DEFAULT_PACKAGE_PERMISSIONS, {}): pass async def test_should_gracefully_handle_error_in_loop(worker_pool: WorkerPool) -> None: with patch_worker_pool(worker_pool, _make_get_manifest_raise): - async with worker_pool.get_worker(PACKAGE, "tester", "tests", DEFAULT_PACKAGE_PERMISSIONS) as worker: + async with worker_pool.get_worker(PACKAGE, "tester", "tests", DEFAULT_PACKAGE_PERMISSIONS, {}) as worker: with pytest.raises(WorkerUnknownError, match="some custom error"): await worker.get_manifest() diff --git a/tests/questionpy_server/worker/impl/test_subprocess.py b/tests/questionpy_server/worker/impl/test_subprocess.py index 94dfc20..b6fbe6a 100644 --- a/tests/questionpy_server/worker/impl/test_subprocess.py +++ b/tests/questionpy_server/worker/impl/test_subprocess.py @@ -25,7 +25,7 @@ @pytest.mark.parametrize("worker_pool", [SubprocessWorker], indirect=True) async def test_should_apply_limits(worker_pool: WorkerPool) -> None: - async with worker_pool.get_worker(PACKAGE, "tester", "tests", DEFAULT_PACKAGE_PERMISSIONS) as worker: + async with worker_pool.get_worker(PACKAGE, "tester", "tests", DEFAULT_PACKAGE_PERMISSIONS, {}) as worker: assert isinstance(worker, SubprocessWorker) assert worker._proc # Python's resource package can only get the rlimit of other processes on Linux, so we use psutil. @@ -52,7 +52,7 @@ async def test_should_raise_cpu_timout_error(worker_pool: WorkerPool) -> None: start_time = time() # Change the timeout for faster testing. with pytest.raises(WorkerStartError) as exc_info, patch.object(BaseWorker, "_init_worker_timeout", 0.05): - async with worker_pool.get_worker(PACKAGE, "tester", "tests", DEFAULT_PACKAGE_PERMISSIONS): + async with worker_pool.get_worker(PACKAGE, "tester", "tests", DEFAULT_PACKAGE_PERMISSIONS, {}): pass assert isinstance(exc_info.value.__cause__, WorkerCPUTimeLimitExceededError) assert 0.05 < (time() - start_time) < 0.5 @@ -79,7 +79,7 @@ async def test_should_raise_real_timout_error(worker_pool: WorkerPool) -> None: patch.object(BaseWorker, "_init_worker_timeout", 0.6), patch.object(LimitTimeUsageMixin, "_real_time_limit_factor", 1.0), ): - async with worker_pool.get_worker(PACKAGE, "tester", "tests", DEFAULT_PACKAGE_PERMISSIONS) as worker: + async with worker_pool.get_worker(PACKAGE, "tester", "tests", DEFAULT_PACKAGE_PERMISSIONS, {}) as worker: await worker.get_manifest() assert isinstance(exc_info.value.__cause__, WorkerRealTimeLimitExceededError) assert 0.6 < (time() - start_time) < 2.0 diff --git a/tests/questionpy_server/worker/impl/test_thread.py b/tests/questionpy_server/worker/impl/test_thread.py index 36f0d6c..4ea66fd 100644 --- a/tests/questionpy_server/worker/impl/test_thread.py +++ b/tests/questionpy_server/worker/impl/test_thread.py @@ -15,7 +15,7 @@ @pytest.mark.parametrize("worker_pool", [ThreadWorker], indirect=True) async def test_should_ignore_limits(worker_pool: WorkerPool) -> None: with patch.object(resource, "setrlimit") as mock: - async with worker_pool.get_worker(PACKAGE, "tester", "tests", DEFAULT_PACKAGE_PERMISSIONS): + async with worker_pool.get_worker(PACKAGE, "tester", "tests", DEFAULT_PACKAGE_PERMISSIONS, {}): pass mock.assert_not_called()