diff --git a/cognite/extractorutils/metrics.py b/cognite/extractorutils/metrics.py index 2313a11c..ff91581b 100644 --- a/cognite/extractorutils/metrics.py +++ b/cognite/extractorutils/metrics.py @@ -458,3 +458,6 @@ def stop(self) -> None: self._push_to_server() self.upload_queue.stop() self.cancellation_token.cancel() + + +MetricsType = TypeVar("MetricsType", bound=BaseMetrics) diff --git a/cognite/extractorutils/unstable/core/base.py b/cognite/extractorutils/unstable/core/base.py index 6b4af39e..9befd387 100644 --- a/cognite/extractorutils/unstable/core/base.py +++ b/cognite/extractorutils/unstable/core/base.py @@ -22,6 +22,11 @@ class MyConfig(ExtractorConfig): another_parameter: int schedule: ScheduleConfig + class MyMetrics(BaseMetrics): + def __init__(self, extractor_name: str, extractor_version: str): + super().__init__(extractor_name, extractor_version) + self.custom_counter = Counter("custom_counter", "A custom counter") + class MyExtractor(Extractor[MyConfig]): NAME = "My Extractor" EXTERNAL_ID = "my-extractor" @@ -30,6 +35,9 @@ class MyExtractor(Extractor[MyConfig]): CONFIG_TYPE = MyConfig + # Override metrics type annotation for IDE support + metrics: MyMetrics + def __init_tasks__(self) -> None: self.add_task( ScheduledTask( @@ -42,6 +50,8 @@ def __init_tasks__(self) -> None: def my_task_function(self, task_context: TaskContext) -> None: task_context.logger.info("Running my task") + # IDE will now autocomplete custom_counter + self.metrics.custom_counter.inc() """ import logging @@ -59,7 +69,7 @@ def my_task_function(self, task_context: TaskContext) -> None: from typing_extensions import Self, assert_never from cognite.extractorutils._inner_util import _resolve_log_level -from cognite.extractorutils.metrics import BaseMetrics +from cognite.extractorutils.metrics import BaseMetrics, MetricsType, safe_get from cognite.extractorutils.statestore import ( AbstractStateStore, LocalStateStore, @@ -117,11 +127,13 @@ def __init__( application_config: _T, current_config_revision: ConfigRevision, log_level_override: str | None = None, + metrics_class: type[MetricsType] | None = None, ) -> None: self.connection_config = connection_config self.application_config = application_config self.current_config_revision: ConfigRevision = current_config_revision self.log_level_override = log_level_override + self.metrics_class: type[MetricsType] | None = metrics_class class Extractor(Generic[ConfigType], CogniteLogger): @@ -149,9 +161,7 @@ class Extractor(Generic[ConfigType], CogniteLogger): cancellation_token: CancellationToken - def __init__( - self, config: FullConfig[ConfigType], checkin_worker: CheckinWorker, metrics: BaseMetrics | None = None - ) -> None: + def __init__(self, config: FullConfig[ConfigType], checkin_worker: CheckinWorker) -> None: self._logger = logging.getLogger(f"{self.EXTERNAL_ID}.main") self._checkin_worker = checkin_worker @@ -175,7 +185,8 @@ def __init__( self._tasks: list[Task] = [] self._start_time: datetime - self._metrics: BaseMetrics | None = metrics + + self.metrics: BaseMetrics = self._load_metrics(config.metrics_class) self.metrics_push_manager = ( self.metrics_config.create_manager(self.cognite_client, cancellation_token=self.cancellation_token) @@ -262,6 +273,16 @@ def _setup_logging(self) -> None: "Defaulted to console logging." ) + def _load_metrics(self, metrics_class: type[MetricsType] | None = None) -> MetricsType | BaseMetrics: + """ + Loads metrics based on the provided metrics class. + + Reuses existing singleton if available to avoid Prometheus registry conflicts. + """ + if metrics_class: + return safe_get(metrics_class) + return safe_get(BaseMetrics, extractor_name=self.EXTERNAL_ID, extractor_version=self.VERSION) + def _load_state_store(self) -> None: """ Searches through the config object for a StateStoreConfig. @@ -379,10 +400,8 @@ def restart(self) -> None: self.cancellation_token.cancel() @classmethod - def _init_from_runtime( - cls, config: FullConfig[ConfigType], checkin_worker: CheckinWorker, metrics: BaseMetrics - ) -> Self: - return cls(config, checkin_worker, metrics) + def _init_from_runtime(cls, config: FullConfig[ConfigType], checkin_worker: CheckinWorker) -> Self: + return cls(config, checkin_worker) def add_task(self, task: Task) -> None: """ diff --git a/cognite/extractorutils/unstable/core/runtime.py b/cognite/extractorutils/unstable/core/runtime.py index 78ec701b..2cdae9c3 100644 --- a/cognite/extractorutils/unstable/core/runtime.py +++ b/cognite/extractorutils/unstable/core/runtime.py @@ -47,7 +47,7 @@ def main() -> None: CogniteAuthError, CogniteConnectionError, ) -from cognite.extractorutils.metrics import BaseMetrics +from cognite.extractorutils.metrics import BaseMetrics, MetricsType from cognite.extractorutils.threading import CancellationToken from cognite.extractorutils.unstable.configuration.exceptions import InvalidArgumentError, InvalidConfigError from cognite.extractorutils.unstable.configuration.loaders import ( @@ -79,16 +79,13 @@ def _extractor_process_entrypoint( controls: _RuntimeControls, config: FullConfig, checkin_worker: CheckinWorker, - metrics: BaseMetrics | None = None, ) -> None: logger = logging.getLogger(f"{extractor_class.EXTERNAL_ID}.runtime") checkin_worker.active_revision = config.current_config_revision checkin_worker.set_on_fatal_error_handler(lambda _: on_fatal_error(controls)) checkin_worker.set_on_revision_change_handler(lambda _: on_revision_changed(controls)) checkin_worker.set_retry_startup(extractor_class.RETRY_STARTUP) - if not metrics: - metrics = BaseMetrics(extractor_name=extractor_class.NAME, extractor_version=extractor_class.VERSION) - extractor = extractor_class._init_from_runtime(config, checkin_worker, metrics) + extractor = extractor_class._init_from_runtime(config, checkin_worker) extractor._attach_runtime_controls( cancel_event=controls.cancel_event, message_queue=controls.message_queue, @@ -138,13 +135,13 @@ class Runtime(Generic[ExtractorType]): def __init__( self, extractor: type[ExtractorType], - metrics: BaseMetrics | None = None, + metrics: type[MetricsType] | None = None, ) -> None: self._extractor_class = extractor self._cancellation_token = CancellationToken() self._cancellation_token.cancel_on_interrupt() self._message_queue: Queue[RuntimeMessage] = Queue() - self._metrics = metrics + self._metrics_class = metrics self.logger = logging.getLogger(f"{self._extractor_class.EXTERNAL_ID}.runtime") self._setup_logging() self._cancel_event: MpEvent | None = None @@ -273,7 +270,7 @@ def _spawn_extractor( process = Process( target=_extractor_process_entrypoint, - args=(self._extractor_class, controls, config, checkin_worker, self._metrics), + args=(self._extractor_class, controls, config, checkin_worker), ) process.start() @@ -477,6 +474,14 @@ def _main_runtime(self, args: Namespace) -> None: if not args.skip_init_checks and not self._verify_connection_config(connection_config): sys.exit(1) + if self._metrics_class is not None and ( + not isinstance(self._metrics_class, type) or not issubclass(self._metrics_class, BaseMetrics) + ): + self.logger.critical( + "The provided metrics class does not inherit from BaseMetrics. Metrics will not be collected." + ) + sys.exit(1) + # This has to be Any. We don't know the type of the extractors' config at type checking since the self doesn't # exist yet, and I have not found a way to represent it in a generic way that isn't just an Any in disguise. application_config: Any @@ -507,6 +512,7 @@ def _main_runtime(self, args: Namespace) -> None: application_config=application_config, current_config_revision=current_config_revision, log_level_override=args.log_level, + metrics_class=self._metrics_class, ), checkin_worker, ) diff --git a/tests/conftest.py b/tests/conftest.py index 58e53fe4..a20f2127 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,17 +5,47 @@ from enum import Enum import pytest +from prometheus_client.core import REGISTRY from cognite.client import CogniteClient from cognite.client.config import ClientConfig from cognite.client.credentials import OAuthClientCredentials from cognite.client.data_classes.data_modeling import NodeId from cognite.client.exceptions import CogniteAPIError, CogniteNotFoundError +from cognite.extractorutils import metrics NUM_NODES = 5000 NUM_EDGES = NUM_NODES // 100 +@pytest.fixture(autouse=True) +def reset_singleton() -> Generator[None, None, None]: + """ + This fixture ensures that the _metrics_singularities + class variables are reset, and Prometheus collectors are unregistered, + providing test isolation. + """ + # Clean up before test + metrics._metrics_singularities.clear() + + # Unregister all collectors to prevent "Duplicated timeseries" errors + collectors = list(REGISTRY._collector_to_names.keys()) + for collector in collectors: + with contextlib.suppress(Exception): + REGISTRY.unregister(collector) + + yield + + # Clean up after test + metrics._metrics_singularities.clear() + + # Unregister all collectors again + collectors = list(REGISTRY._collector_to_names.keys()) + for collector in collectors: + with contextlib.suppress(Exception): + REGISTRY.unregister(collector) + + class ETestType(Enum): TIME_SERIES = "time_series" CDM_TIME_SERIES = "cdm_time_series" diff --git a/tests/test_unstable/conftest.py b/tests/test_unstable/conftest.py index 4ebf88b7..bd5001b5 100644 --- a/tests/test_unstable/conftest.py +++ b/tests/test_unstable/conftest.py @@ -1,7 +1,6 @@ import gzip import json import os -from collections import Counter from collections.abc import Callable, Generator, Iterator from threading import RLock from time import sleep, time @@ -10,6 +9,7 @@ import pytest import requests_mock +from prometheus_client.core import Counter from cognite.client import CogniteClient from cognite.client.config import ClientConfig diff --git a/tests/test_unstable/test_base.py b/tests/test_unstable/test_base.py index 1f8527d8..68348f67 100644 --- a/tests/test_unstable/test_base.py +++ b/tests/test_unstable/test_base.py @@ -303,10 +303,11 @@ def counting_push(self: CognitePusher) -> None: application_config=app_config, current_config_revision=1, log_level_override=override_level, + metrics_class=TestMetrics, ) worker = get_checkin_worker(connection_config) - extractor = TestExtractor(full_config, worker, metrics=TestMetrics) - assert isinstance(extractor._metrics, TestMetrics) or extractor._metrics == TestMetrics + extractor = TestExtractor(full_config, worker) + assert isinstance(extractor.metrics, TestMetrics) with contextlib.ExitStack() as stack: stack.enter_context(contextlib.suppress(Exception)) diff --git a/tests/test_unstable/test_runtime.py b/tests/test_unstable/test_runtime.py index 5536916a..29f7757f 100644 --- a/tests/test_unstable/test_runtime.py +++ b/tests/test_unstable/test_runtime.py @@ -15,14 +15,41 @@ from typing_extensions import Self from cognite.examples.unstable.extractors.simple_extractor.main import SimpleExtractor +from cognite.extractorutils.metrics import BaseMetrics from cognite.extractorutils.unstable.configuration.exceptions import InvalidArgumentError from cognite.extractorutils.unstable.configuration.models import ConnectionConfig from cognite.extractorutils.unstable.core.base import ConfigRevision, FullConfig from cognite.extractorutils.unstable.core.checkin_worker import CheckinWorker from cognite.extractorutils.unstable.core.runtime import Runtime +from cognite.extractorutils.unstable.core.tasks import StartupTask, TaskContext from test_unstable.conftest import TestConfig, TestExtractor, TestMetrics +class MetricsTestExtractor(SimpleExtractor): + """Custom extractor for testing metrics in multiprocessing context.""" + + def __init_tasks__(self) -> None: + super().__init_tasks__() + + def test_metrics_task(context: TaskContext) -> None: + # Increment counter twice + self.metrics.a_counter.inc() + self.metrics.a_counter.inc() + + # Log the counter value so we can verify it in output + counter_value = self.metrics.a_counter._value.get() + context.info(f"METRICS_TEST: Counter value is {counter_value}") + + # Add startup task to test metrics + self.add_task( + StartupTask( + name="test-metrics", + description="Test metrics increment", + target=test_metrics_task, + ) + ) + + @pytest.fixture def local_config_file() -> Generator[Path, None, None]: file = Path(__file__).parent.parent.parent / f"test-{randint(0, 1000000)}.yaml" @@ -396,11 +423,156 @@ def test_logging_on_windows_with_import_error( assert mock_root_logger.addHandler.call_count == 1 -def test_extractor_with_metrics() -> None: - runtime = Runtime(TestExtractor, metrics=TestMetrics) - assert isinstance(runtime._metrics, TestMetrics) or runtime._metrics == TestMetrics +def test_extractor_with_metrics( + connection_config: ConnectionConfig, tmp_path: Path, monkeypatch: MonkeyPatch, capfd: pytest.CaptureFixture[str] +) -> None: + """ + Test metrics_class is properly passed through Runtime to child process. + This test verifies multiprocessing integration with metrics and counter increments. + """ + cfg_dir = Path("cognite/examples/unstable/extractors/simple_extractor/config") + base_conn = cfg_dir / "connection_config.yaml" + base_app = cfg_dir / "config.yaml" + + conn_file = tmp_path / f"test-{randint(0, 1000000)}-connection_config.yaml" + _write_conn_from_fixture(base_conn, conn_file, connection_config) + + app_file = tmp_path / f"test-{randint(0, 1000000)}-config.yaml" + app_file.write_text(base_app.read_text(encoding="utf-8")) + + argv = [ + "simple-extractor", + "--cwd", + str(tmp_path), + "-c", + conn_file.name, + "-f", + app_file.name, + "--skip-init-checks", + "-l", + "info", + ] + + monkeypatch.setattr(sys, "argv", argv) + + runtime = Runtime(MetricsTestExtractor, metrics=TestMetrics) + + # Verify runtime stores metrics class + assert runtime._metrics_class is TestMetrics, "Runtime should store TestMetrics class" + + child_holder = {} + original_spawn = Runtime._spawn_extractor + + def spy_spawn(self: Self, config: FullConfig, checkin_worker: CheckinWorker) -> Process: + assert config.metrics_class is TestMetrics, "FullConfig should carry TestMetrics class" + + p = original_spawn( + self, + config, + checkin_worker, + ) + child_holder["proc"] = p + return p + + monkeypatch.setattr(Runtime, "_spawn_extractor", spy_spawn, raising=True) + + t = Thread(target=runtime.run, name="RuntimeMain") + t.start() + + start = time.time() + while "proc" not in child_holder and time.time() - start < 10: + time.sleep(0.05) + + assert "proc" in child_holder, "Extractor process was not spawned in time." + proc = child_holder["proc"] + + time.sleep(1.5) # Give more time for the startup task to run + + runtime._cancellation_token.cancel() + + t.join(timeout=30) + assert not t.is_alive(), "Runtime did not shut down within timeout after cancellation." + + proc.join(timeout=0) + assert not proc.is_alive(), "Extractor process is still alive" + + out, err = capfd.readouterr() + combined = (out or "") + (err or "") - # The metrics instance should be a singleton - another_runtime = Runtime(TestExtractor, metrics=TestMetrics) - assert another_runtime._metrics is runtime._metrics - assert isinstance(another_runtime._metrics, TestMetrics) or another_runtime._metrics == TestMetrics + # Verify metrics counter was incremented + assert "METRICS_TEST: Counter value is 2" in combined, ( + f"Expected metrics counter to be 2 in child process.\nCaptured output:\n{combined}" + ) + + +class InvalidMetrics: + """A dummy metrics class that does not inherit from BaseMetrics.""" + + pass + + +@pytest.mark.parametrize( + "metrics_input, should_raise", + [ + (TestMetrics, False), + (TestMetrics(), True), + (InvalidMetrics, True), + (None, False), + ], +) +def test_metrics_class_validation_parametrized( + caplog: pytest.LogCaptureFixture, metrics_input: type[BaseMetrics] | None, should_raise: bool +) -> None: + """ + Combined parameterized test for metrics class validation behavior. + For cases that should not raise, we only assert the runtime stored the value. + For invalid cases we assert _main_runtime exits with SystemExit(1). + """ + runtime = Runtime(TestExtractor, metrics=metrics_input) + + if should_raise: + mock_connection_config = MagicMock() + args = MagicMock( + connection_config=[Path("dummy.yaml")], + force_local_config=None, + cwd=None, + skip_init_checks=True, + log_level="info", + ) + + with ( + patch("cognite.extractorutils.unstable.core.runtime.load_file", return_value=mock_connection_config), + pytest.raises(SystemExit) as excinfo, + ): + runtime._main_runtime(args) + + assert excinfo.value.code == 1 + assert any( + "The provided metrics class does not inherit from BaseMetrics" in record.message + for record in caplog.records + ), f"Expected critical log not found. Captured logs: {[r.message for r in caplog.records]}" + else: + assert runtime._metrics_class is metrics_input + + +def test_type_checker_would_catch_invalid_metrics() -> None: + """ + This test validates that the type signature itself is correct and would + provide IDE/linter feedback to developers. + """ + from typing import get_type_hints + + hints = get_type_hints(Runtime.__init__) + assert "metrics" in hints + metrics_type_str = str(hints["metrics"]) + + # Verify it expects a type/class, not an instance + assert "type[" in metrics_type_str or "Type[" in metrics_type_str, ( + f"Expected metrics parameter to be type[...], got: {metrics_type_str}" + ) + + # Verify None is allowed (Optional) + # Python 3.10 uses typing.Optional[X], 3.10+ can use X | None + assert any(pattern in metrics_type_str for pattern in ["None", "| None", "Optional"]), ( + f"Expected metrics parameter to be Optional, got: {metrics_type_str}" + ) diff --git a/tests/tests_unit/test_metrics.py b/tests/tests_unit/test_metrics.py index d94439df..6d2f0505 100644 --- a/tests/tests_unit/test_metrics.py +++ b/tests/tests_unit/test_metrics.py @@ -23,13 +23,14 @@ from cognite.client import CogniteClient from cognite.client.data_classes import Asset from cognite.client.exceptions import CogniteDuplicatedError, CogniteNotFoundError -from cognite.extractorutils import metrics from cognite.extractorutils.metrics import CognitePusher, safe_get # For testing PrometheusPusher @pytest.fixture def altered_metrics() -> ModuleType: + from cognite.extractorutils import metrics + altered_metrics = metrics altered_metrics.delete_from_gateway = Mock() altered_metrics.pushadd_to_gateway = Mock() @@ -179,11 +180,12 @@ def test_init_existing_all(MockCogniteClient: Mock) -> None: @patch("cognite.client.CogniteClient") def test_push(MockCogniteClient: Mock) -> None: - init_gauge() + gauge = Gauge("gauge", "Test gauge") + client: CogniteClient = MockCogniteClient() pusher = CognitePusher(client, "pre_", push_interval=1) - GaugeSetUp.gauge.set(5) + gauge.set(5) pusher._push_to_server() client.time_series.data.insert_multiple.assert_called_once() @@ -201,7 +203,7 @@ def test_push(MockCogniteClient: Mock) -> None: @patch("cognite.client.CogniteClient") def test_push_creates_missing_timeseries(MockCogniteClient: Mock) -> None: """Test that push logic creates missing time series when enabled.""" - init_gauge() + gauge = Gauge("gauge", "Test gauge") client: CogniteClient = MockCogniteClient() # Create a mock CogniteNotFoundError with not_found and failed attributes @@ -217,7 +219,7 @@ def test_push_creates_missing_timeseries(MockCogniteClient: Mock) -> None: pusher = CognitePusher(client, "pre_", push_interval=1) - GaugeSetUp.gauge.set(5) + gauge.set(5) pusher._push_to_server() # Assert that we tried to create the timeseries