diff --git a/CHANGELOG.md b/CHANGELOG.md index f3925e84..3629d51a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,11 @@ Changes are grouped as follows - `Fixed` for any bug fixes. - `Security` in case of vulnerabilities. +## 7.12.1 + +### Added +* In the `unstable` package: Added an optional `context` argument to the config loaders (`load_file`, `load_io`, `load_dict` and `load_from_cdf`). It is forwarded to Pydantic as the validation context, so validators can share data with each other and validation can be customized. + ## 7.12.0 ### Added diff --git a/cognite/extractorutils/__init__.py b/cognite/extractorutils/__init__.py index a8cec2e1..cdfe2b11 100644 --- a/cognite/extractorutils/__init__.py +++ b/cognite/extractorutils/__init__.py @@ -16,7 +16,7 @@ Cognite extractor utils is a Python package that simplifies the development of new extractors. """ -__version__ = "7.12.0" +__version__ = "7.12.1" from .base import Extractor __all__ = ["Extractor"] diff --git a/cognite/extractorutils/unstable/configuration/loaders.py b/cognite/extractorutils/unstable/configuration/loaders.py index b718fd87..2f2e849d 100644 --- a/cognite/extractorutils/unstable/configuration/loaders.py +++ b/cognite/extractorutils/unstable/configuration/loaders.py @@ -6,7 +6,7 @@ from enum import Enum from io import StringIO from pathlib import Path -from typing import TextIO, TypeVar +from typing import Any, TextIO, TypeVar from cognite.client import CogniteClient from cognite.client.exceptions import CogniteAPIError @@ -36,13 +36,14 @@ class ConfigFormat(Enum): YAML = "yaml" -def load_file(path: Path, schema: type[_T]) -> _T: +def load_file(path: Path, schema: type[_T], context: dict[str, Any] | None = None) -> _T: """ Load a configuration file from the given path and parse it into the specified schema. Args: path: Path to the configuration file. schema: The schema class to parse the configuration into. + context: Optional Pydantic validation context; see ``load_dict`` for semantics. Returns: An instance of the schema populated with the configuration data. 
@@ -58,11 +59,15 @@ def load_file(path: Path, schema: type[_T]) -> _T: raise InvalidConfigError(f"Unknown file type {path.suffix}") with open(path) as stream: - return load_io(stream, file_format, schema) + return load_io(stream, file_format, schema, context=context) def load_from_cdf( - cognite_client: CogniteClient, external_id: str, schema: type[_T], revision: int | None = None + cognite_client: CogniteClient, + external_id: str, + schema: type[_T], + revision: int | None = None, + context: dict[str, Any] | None = None, ) -> tuple[_T, int]: """ Load a configuration from a CDF integration using the provided external ID and schema. @@ -72,6 +77,7 @@ def load_from_cdf( external_id: The external ID of the integration to load configuration from. schema: The schema class to parse the configuration into. revision: the specific revision of the configuration to load, otherwise get the latest. + context: Optional Pydantic validation context; see ``load_dict`` for semantics. Returns: A tuple containing the parsed configuration instance and the revision number. @@ -97,7 +103,7 @@ def load_from_cdf( data = response.json() try: - return load_io(StringIO(data["config"]), ConfigFormat.YAML, schema), data["revision"] + return load_io(StringIO(data["config"]), ConfigFormat.YAML, schema, context=context), data["revision"] except InvalidConfigError as e: e.attempted_revision = data["revision"] @@ -108,7 +114,7 @@ def load_from_cdf( raise new_e from e -def load_io(stream: TextIO, file_format: ConfigFormat, schema: type[_T]) -> _T: +def load_io(stream: TextIO, file_format: ConfigFormat, schema: type[_T], context: dict[str, Any] | None = None) -> _T: """ Load a configuration from a stream (e.g., file or string) and parse it into the specified schema. @@ -116,6 +122,7 @@ def load_io(stream: TextIO, file_format: ConfigFormat, schema: type[_T]) -> _T: stream: A text stream containing the configuration data. file_format: The format of the configuration data. 
schema: The schema class to parse the configuration into. + context: Optional Pydantic validation context; see ``load_dict`` for semantics. Returns: An instance of the schema populated with the configuration data. @@ -134,7 +141,7 @@ def load_io(stream: TextIO, file_format: ConfigFormat, schema: type[_T]) -> _T: if "key-vault" in data: data.pop("key-vault") - return load_dict(data, schema) + return load_dict(data, schema, context=context) def _make_loc_str(loc: tuple) -> str: @@ -155,13 +162,20 @@ def _make_loc_str(loc: tuple) -> str: return loc_str -def load_dict(data: dict, schema: type[_T]) -> _T: +def load_dict(data: dict, schema: type[_T], context: dict[str, Any] | None = None) -> _T: """ Load a configuration from a dictionary and parse it into the specified schema. Args: data: A dictionary containing the configuration data. schema: The schema class to parse the configuration into. + context: Optional Pydantic validation context: forwarded to + ``schema.model_validate(..., context=...)`` and exposed to validators as + ``ValidationInfo.context``. Pydantic reuses one dict for the entire validation + run, so validators can add or change keys to pass data to validators that run + later (for example a model validator stashing derived data for nested field + validators). The dict you pass in is therefore mutated in place; pass a fresh + dict if you need the original object unchanged after load. Returns: An instance of the schema populated with the configuration data. @@ -170,7 +184,7 @@ def load_dict(data: dict, schema: type[_T]) -> _T: InvalidConfigError: If the configuration is invalid. 
""" try: - return schema.model_validate(data) + return schema.model_validate(data, context=context if context is not None else {}) except ValidationError as e: messages = [] diff --git a/pyproject.toml b/pyproject.toml index 77ece1f1..bda4ed8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "cognite-extractor-utils" -version = "7.12.0" +version = "7.12.1" description = "Utilities for easier development of extractors for CDF" authors = [ {name = "Mathias Lohne", email = "mathias.lohne@cognite.com"} diff --git a/tests/test_unstable/test_configuration.py b/tests/test_unstable/test_configuration.py index be1393b4..08bc8e5e 100644 --- a/tests/test_unstable/test_configuration.py +++ b/tests/test_unstable/test_configuration.py @@ -1,22 +1,27 @@ import os +import re from io import StringIO +from typing import Any from unittest.mock import Mock import pytest from cognite.client.credentials import OAuthClientCredentials from cognite.client.data_classes import DataSet -from pydantic import Field +from pydantic import Field, ValidationInfo, field_validator, model_validator from cognite.extractorutils.exceptions import InvalidConfigError +from cognite.extractorutils.unstable.configuration.exceptions import InvalidConfigError as UnstableInvalidConfigError from cognite.extractorutils.unstable.configuration.loaders import ConfigFormat, load_io from cognite.extractorutils.unstable.configuration.models import ( ConfigModel, ConnectionConfig, EitherIdConfig, + ExtractorConfig, FileSizeConfig, LogLevel, TimeIntervalConfig, WithDataSetId, + _ClientCertificateConfig, _ClientCredentialsConfig, ) @@ -137,6 +142,25 @@ - thumbprint1 - thumbprint2 """ +TEST_REMOTE_CONFIG = """ +--- +sources: +- name: abc + option: option1 + +- name: def + option: option2 + +tasks: +- name: task1 + source: abc + +- name: task2 + source: def + +- name: task3 + source: ghi +""" @pytest.mark.parametrize("config_str", [CONFIG_EXAMPLE_ONLY_REQUIRED, CONFIG_EXAMPLE_ONLY_REQUIRED2]) 
@@ -148,6 +172,7 @@ def test_load_from_io(config_str: str) -> None: assert config.base_url == "https://baseurl.com" assert config.integration.external_id == "test-pipeline" assert config.authentication.type == "client-credentials" + assert isinstance(config.authentication, _ClientCredentialsConfig) assert config.authentication.client_secret == "very_secret123" assert list(config.authentication.scopes) == ["scopea", "scopeb"] @@ -165,6 +190,7 @@ def test_full_config_client_credentials(config_str: str) -> None: assert config.authentication.type == "client-credentials" assert config.authentication.client_id == "testid" + assert isinstance(config.authentication, _ClientCredentialsConfig) assert config.authentication.client_secret == "very_secret123" assert config.authentication.token_url == "https://get-a-token.com/token" assert list(config.authentication.scopes) == ["scopea", "scopeb"] @@ -192,6 +218,7 @@ def test_full_config_client_certificates(config_str: str) -> None: assert config.authentication.type == "client-certificate" assert config.authentication.client_id == "testid" + assert isinstance(config.authentication, _ClientCertificateConfig) assert config.authentication.password == "very-strong-password" assert config.authentication.path.as_posix() == "/path/to/cert.pem" assert config.authentication.authority_url == "https://you-are-authorized.com" @@ -272,6 +299,50 @@ def test_file_size_config_equality() -> None: assert file_size_3 != file_size_1 +class Source(ConfigModel): + name: str + option: str + + +class TaskConfig(ConfigModel): + name: str + source: str + + @field_validator( + "source", + mode="after", + ) + @classmethod + def validate_instance(cls, value: str, info: ValidationInfo) -> str: + source_names = (info.context or {}).get("source_names", []) + if value not in source_names: + raise ValueError(f"'{value}' is not defined in the list of sources") + return value + + +class RemoteTasksConfig(ExtractorConfig): + sources: list[Source] + tasks: 
list[TaskConfig] + + @model_validator(mode="before") + @classmethod + def map_instances(cls, data: dict[str, Any], validation_info: ValidationInfo) -> dict[str, Any]: + if validation_info.context is not None: + validation_info.context.update( + {"source_names": [source["name"] for source in data.get("sources", [])]}, + ) + return data + + +def test_config_with_context() -> None: + stream = StringIO(TEST_REMOTE_CONFIG) + with pytest.raises( + UnstableInvalidConfigError, + match=re.escape("Invalid config: 'ghi' is not defined in the list of sources: tasks[2].source"), + ): + load_io(stream, ConfigFormat.YAML, RemoteTasksConfig) + + @pytest.mark.parametrize( "expression", ["12.3kbkb", "10XY", "abcMB", "5.5.5GB", "MB", "", " ", "10 M B", "10MB extra", "tenMB"] )