From e383e5a41f5eb861ed6ecbb2a5afe200500805ee Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Thu, 20 Nov 2025 19:45:18 +0000 Subject: [PATCH 01/11] feat: Add singleton record fetch functionality to PyAirbyte Implements Source.get_record() and DeclarativeExecutor.fetch_record() methods to enable fetching single records by primary key value from declarative sources. Key features: - Source.get_record(stream_name, pk_value) - Public API for fetching records - DeclarativeExecutor.fetch_record() - Internal implementation using CDK components - Primary key validation and normalization (supports string, int, dict formats) - Composite primary key detection (raises NotImplementedError) - New AirbyteRecordNotFoundError exception for missing records - Comprehensive unit tests with proper mocking This implementation reuses existing CDK components (SimpleRetriever, HttpClient, RecordSelector) without monkey-patching or pinning CDK versions, providing a hybrid approach that works with the current CDK release. 
Related to CDK PR airbytehq/airbyte-python-cdk#846 Co-Authored-By: AJ Steers --- airbyte/_executors/declarative.py | 155 ++++++++++++++++++ airbyte/exceptions.py | 8 + airbyte/sources/base.py | 112 +++++++++++++ tests/unit_tests/test_get_record.py | 233 ++++++++++++++++++++++++++++ 4 files changed, 508 insertions(+) create mode 100644 tests/unit_tests/test_get_record.py diff --git a/airbyte/_executors/declarative.py b/airbyte/_executors/declarative.py index e227eca34..ec5cd06ad 100644 --- a/airbyte/_executors/declarative.py +++ b/airbyte/_executors/declarative.py @@ -15,7 +15,13 @@ from airbyte_cdk.sources.declarative.concurrent_declarative_source import ( ConcurrentDeclarativeSource, ) +from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ( + ModelToComponentFactory, +) +from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever +from airbyte_cdk.sources.types import StreamSlice +from airbyte import exceptions as exc from airbyte._executors.base import Executor @@ -140,3 +146,152 @@ def install(self) -> None: def uninstall(self) -> None: """No-op. The declarative source is included with PyAirbyte.""" pass + + def fetch_record( # noqa: PLR0914 + self, + stream_name: str, + primary_key_value: str, + config: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Fetch a single record by primary key from a declarative stream. + + This method constructs an HTTP GET request to fetch a single record by appending + the primary key value to the stream's base path (e.g., /users/123). + + Args: + stream_name: The name of the stream to fetch from. + primary_key_value: The primary key value as a string. + config: Optional config overrides to merge with the executor's config. + + Returns: + The fetched record as a dictionary. + + Raises: + exc.AirbyteStreamNotFoundError: If the stream is not found in the manifest. + exc.AirbyteRecordNotFoundError: If the record is not found (empty response). 
+ NotImplementedError: If the stream does not use SimpleRetriever. + """ + merged_config = {**self._config_dict, **(config or {})} + + stream_configs = self._manifest_dict.get("streams", []) + stream_config = None + for config_item in stream_configs: + if config_item.get("name") == stream_name: + stream_config = config_item + break + + if stream_config is None: + available_streams = [s.get("name") for s in stream_configs] + raise exc.AirbyteStreamNotFoundError( + stream_name=stream_name, + connector_name=self.name, + available_streams=available_streams, + message=f"Stream '{stream_name}' not found in manifest.", + ) + + factory = ModelToComponentFactory() + + retriever_config = stream_config.get("retriever") + if retriever_config is None: + raise NotImplementedError( + f"Stream '{stream_name}' does not have a retriever configuration. " + "fetch_record() is only supported for streams with retrievers." + ) + + try: + retriever = factory.create_component( + model_type=type(retriever_config), # type: ignore[arg-type] + component_definition=retriever_config, + config=merged_config, + ) + except Exception as e: + raise NotImplementedError( + f"Failed to create retriever for stream '{stream_name}': {e}" + ) from e + + if not isinstance(retriever, SimpleRetriever): + raise NotImplementedError( + f"Stream '{stream_name}' uses {type(retriever).__name__}, but fetch_record() " + "only supports SimpleRetriever." 
+ ) + + empty_slice = StreamSlice(partition={}, cursor_slice={}) + base_path = retriever.requester.get_path( + stream_state={}, + stream_slice=empty_slice, + next_page_token=None, + ) + + fetch_path = f"{base_path}/{primary_key_value}".lstrip("/") + + response = retriever.requester.send_request( + path=fetch_path, + stream_state={}, + stream_slice=empty_slice, + next_page_token=None, + request_headers=retriever._request_headers( # noqa: SLF001 + stream_slice=empty_slice, + next_page_token=None, + ), + request_params=retriever._request_params( # noqa: SLF001 + stream_slice=empty_slice, + next_page_token=None, + ), + request_body_data=retriever._request_body_data( # noqa: SLF001 + stream_slice=empty_slice, + next_page_token=None, + ), + request_body_json=retriever._request_body_json( # noqa: SLF001 + stream_slice=empty_slice, + next_page_token=None, + ), + ) + + if response is None: + msg = ( + f"No response received when fetching record with primary key " + f"'{primary_key_value}' from stream '{stream_name}'." + ) + raise exc.AirbyteRecordNotFoundError( + stream_name=stream_name, + primary_key_value=primary_key_value, + connector_name=self.name, + message=msg, + ) + + schema = stream_config.get("schema_loader", {}) + records_schema = schema if isinstance(schema, dict) else {} + + records = list( + retriever.record_selector.select_records( + response=response, + stream_state={}, + records_schema=records_schema, + stream_slice=empty_slice, + next_page_token=None, + ) + ) + + if not records: + try: + response_json = response.json() + if isinstance(response_json, dict) and response_json: + return response_json + except Exception: + pass + + msg = ( + f"Record with primary key '{primary_key_value}' " + f"not found in stream '{stream_name}'." 
+ ) + raise exc.AirbyteRecordNotFoundError( + stream_name=stream_name, + primary_key_value=primary_key_value, + connector_name=self.name, + message=msg, + ) + + first_record = records[0] + if hasattr(first_record, "data"): + return dict(first_record.data) # type: ignore[arg-type] + return dict(first_record) # type: ignore[arg-type] diff --git a/airbyte/exceptions.py b/airbyte/exceptions.py index e082f7a8d..f7cb1d133 100644 --- a/airbyte/exceptions.py +++ b/airbyte/exceptions.py @@ -412,6 +412,14 @@ class AirbyteStateNotFoundError(AirbyteConnectorError, KeyError): available_streams: list[str] | None = None +@dataclass +class AirbyteRecordNotFoundError(AirbyteConnectorError): + """Record not found in stream.""" + + stream_name: str | None = None + primary_key_value: str | None = None + + @dataclass class PyAirbyteSecretNotFoundError(PyAirbyteError): """Secret not found.""" diff --git a/airbyte/sources/base.py b/airbyte/sources/base.py index 7fd5093c5..19c8f5555 100644 --- a/airbyte/sources/base.py +++ b/airbyte/sources/base.py @@ -30,6 +30,7 @@ from airbyte import exceptions as exc from airbyte._connector_base import ConnectorBase +from airbyte._executors.declarative import DeclarativeExecutor from airbyte._message_iterators import AirbyteMessageIterator from airbyte._util.temp_files import as_temp_files from airbyte.caches.util import get_default_cache @@ -601,6 +602,117 @@ def get_documents( render_metadata=render_metadata, ) + def _get_stream_primary_key(self, stream_name: str) -> list[str]: + """Get the primary key for a stream. + + Returns the primary key as a flat list of field names. + Handles the Airbyte protocol's nested list structure. 
+ """ + catalog = self.configured_catalog + for configured_stream in catalog.streams: + if configured_stream.stream.name == stream_name: + pk = configured_stream.primary_key + if not pk: + return [] + if isinstance(pk, list) and len(pk) > 0: + if isinstance(pk[0], list): + return [field[0] if isinstance(field, list) else field for field in pk] + return list(pk) # type: ignore[arg-type] + return [] + raise exc.AirbyteStreamNotFoundError( + stream_name=stream_name, + connector_name=self.name, + available_streams=self.get_available_streams(), + ) + + def _normalize_and_validate_pk_value( + self, + stream_name: str, + pk_value: Any, # noqa: ANN401 + ) -> str: + """Normalize and validate a primary key value. + + Accepts: + - A string or int (converted to string) + - A dict with a single entry matching the stream's primary key field + + Returns the PK value as a string. + """ + primary_key_fields = self._get_stream_primary_key(stream_name) + + if not primary_key_fields: + raise exc.PyAirbyteInputError( + message=f"Stream '{stream_name}' does not have a primary key defined.", + input_value=str(pk_value), + ) + + if len(primary_key_fields) > 1: + raise NotImplementedError( + f"Stream '{stream_name}' has a composite primary key {primary_key_fields}. " + "Fetching by composite primary key is not yet supported." + ) + + pk_field = primary_key_fields[0] + + if isinstance(pk_value, dict): + if len(pk_value) != 1: + raise exc.PyAirbyteInputError( + message="When providing pk_value as a dict, it must contain exactly one entry.", + input_value=str(pk_value), + ) + provided_key = next(iter(pk_value.keys())) + if provided_key != pk_field: + msg = ( + f"Primary key field '{provided_key}' does not match " + f"stream's primary key '{pk_field}'." 
+ ) + raise exc.PyAirbyteInputError( + message=msg, + input_value=str(pk_value), + ) + return str(pk_value[provided_key]) + + return str(pk_value) + + def get_record( + self, + stream_name: str, + *, + pk_value: Any, # noqa: ANN401 + ) -> dict[str, Any]: + """Fetch a single record by primary key value. + + This method is currently only supported for declarative (YAML-based) sources. + + Args: + stream_name: The name of the stream to fetch from. + pk_value: The primary key value. Can be: + - A string or integer value (e.g., "123" or 123) + - A dict with a single entry (e.g., {"id": "123"}) + + Returns: + The fetched record as a dictionary. + + Raises: + exc.AirbyteStreamNotFoundError: If the stream does not exist. + exc.AirbyteRecordNotFoundError: If the record is not found. + exc.PyAirbyteInputError: If the pk_value format is invalid. + NotImplementedError: If the source is not declarative or uses composite keys. + """ + if not isinstance(self.executor, DeclarativeExecutor): + raise NotImplementedError( + f"get_record() is only supported for declarative sources. " + f"This source uses {type(self.executor).__name__}." + ) + + pk_value_str = self._normalize_and_validate_pk_value(stream_name, pk_value) + + return self.executor.fetch_record( + stream_name=stream_name, + primary_key_value=pk_value_str, + config=self._config_dict, + ) + def get_samples( self, streams: list[str] | Literal["*"] | None = None, diff --git a/tests/unit_tests/test_get_record.py b/tests/unit_tests/test_get_record.py new file mode 100644 index 000000000..34969b257 --- /dev/null +++ b/tests/unit_tests/test_get_record.py @@ -0,0 +1,233 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+"""Unit tests for Source.get_record() and DeclarativeExecutor.fetch_record().""" + +from __future__ import annotations + +from unittest.mock import Mock, PropertyMock, patch + +import pytest + +from airbyte import exceptions as exc +from airbyte._executors.declarative import DeclarativeExecutor +from airbyte.sources.base import Source + + +@pytest.mark.parametrize( + "stream_name,pk_value,expected_error", + [ + pytest.param("users", "123", None, id="valid_stream_and_pk"), + pytest.param( + "nonexistent", "123", exc.AirbyteStreamNotFoundError, id="stream_not_found" + ), + ], +) +def test_declarative_executor_fetch_record_stream_validation( + stream_name: str, + pk_value: str, + expected_error: type[Exception] | None, +) -> None: + """Test stream validation in DeclarativeExecutor.fetch_record().""" + manifest = { + "streams": [ + { + "name": "users", + "retriever": { + "type": "SimpleRetriever", + "requester": { + "url_base": "https://api.example.com", + "path": "/users", + }, + "record_selector": {"extractor": {"field_path": []}}, + }, + } + ] + } + + executor = DeclarativeExecutor( + name="test-source", + manifest=manifest, + ) + + if expected_error: + with pytest.raises(expected_error): + executor.fetch_record(stream_name, pk_value) + else: + with patch.object(executor, "_manifest_dict", manifest): + with pytest.raises((NotImplementedError, AttributeError, KeyError)): + executor.fetch_record(stream_name, pk_value) + + +@pytest.mark.parametrize( + "primary_key,expected_result", + [ + pytest.param([["id"]], ["id"], id="nested_single_field"), + pytest.param(["id"], ["id"], id="flat_single_field"), + pytest.param([["id"], ["org_id"]], ["id", "org_id"], id="nested_composite"), + pytest.param([], [], id="no_primary_key"), + pytest.param(None, [], id="none_primary_key"), + ], +) +def test_source_get_stream_primary_key( + primary_key: list | None, + expected_result: list[str], +) -> None: + """Test _get_stream_primary_key() handles various PK formats.""" + mock_executor 
= Mock() + source = Source( + executor=mock_executor, + name="test-source", + config={"api_key": "test"}, + ) + + mock_stream = Mock() + mock_stream.stream.name = "test_stream" + mock_stream.primary_key = primary_key + + mock_catalog = Mock() + mock_catalog.streams = [mock_stream] + + with patch.object(type(source), "configured_catalog", new_callable=PropertyMock) as mock_prop: + mock_prop.return_value = mock_catalog + result = source._get_stream_primary_key("test_stream") + assert result == expected_result + + +def test_source_get_stream_primary_key_stream_not_found() -> None: + """Test _get_stream_primary_key() raises error for nonexistent stream.""" + mock_executor = Mock() + source = Source( + executor=mock_executor, + name="test-source", + config={"api_key": "test"}, + ) + + mock_catalog = Mock() + mock_catalog.streams = [] + + with patch.object(type(source), "configured_catalog", new_callable=PropertyMock) as mock_prop: + mock_prop.return_value = mock_catalog + with patch.object(source, "get_available_streams", return_value=[]): + with pytest.raises(exc.AirbyteStreamNotFoundError): + source._get_stream_primary_key("nonexistent_stream") + + +@pytest.mark.parametrize( + "pk_value,primary_key_fields,expected_result,expected_error", + [ + pytest.param("123", ["id"], "123", None, id="string_value"), + pytest.param(123, ["id"], "123", None, id="int_value"), + pytest.param({"id": "123"}, ["id"], "123", None, id="dict_with_correct_key"), + pytest.param( + {"wrong_key": "123"}, + ["id"], + None, + exc.PyAirbyteInputError, + id="dict_with_wrong_key", + ), + pytest.param( + {"id": "123", "extra": "456"}, + ["id"], + None, + exc.PyAirbyteInputError, + id="dict_with_multiple_entries", + ), + pytest.param( + "123", + ["id", "org_id"], + None, + NotImplementedError, + id="composite_primary_key", + ), + pytest.param("123", [], None, exc.PyAirbyteInputError, id="no_primary_key"), + ], +) +def test_source_normalize_and_validate_pk_value( + pk_value: any, + primary_key_fields: 
list[str], + expected_result: str | None, + expected_error: type[Exception] | None, +) -> None: + """Test _normalize_and_validate_pk_value() handles various input formats.""" + mock_executor = Mock() + source = Source( + executor=mock_executor, + name="test-source", + config={"api_key": "test"}, + ) + + with patch.object( + source, "_get_stream_primary_key", return_value=primary_key_fields + ): + if expected_error: + with pytest.raises(expected_error): + source._normalize_and_validate_pk_value("test_stream", pk_value) + else: + result = source._normalize_and_validate_pk_value("test_stream", pk_value) + assert result == expected_result + + +def test_source_get_record_requires_declarative_executor() -> None: + """Test get_record() raises NotImplementedError for non-declarative executors.""" + from airbyte._executors.python import VenvExecutor + + mock_executor = Mock(spec=VenvExecutor) + source = Source( + executor=mock_executor, + name="test-source", + config={"api_key": "test"}, + ) + + with pytest.raises( + NotImplementedError, match="only supported for declarative sources" + ): + source.get_record("test_stream", pk_value="123") + + +def test_source_get_record_calls_executor_fetch_record() -> None: + """Test get_record() calls executor.fetch_record() with correct parameters.""" + mock_executor = Mock(spec=DeclarativeExecutor) + mock_executor.fetch_record.return_value = {"id": "123", "name": "Test"} + + source = Source( + executor=mock_executor, + name="test-source", + config={"api_key": "test"}, + ) + source._config_dict = {"api_key": "test"} + + with patch.object(source, "_normalize_and_validate_pk_value", return_value="123"): + result = source.get_record("test_stream", pk_value="123") + + assert result == {"id": "123", "name": "Test"} + mock_executor.fetch_record.assert_called_once_with( + stream_name="test_stream", + primary_key_value="123", + config={"api_key": "test"}, + ) + + +@pytest.mark.parametrize( + "pk_value", + [ + pytest.param("123", id="string_pk"), 
+ pytest.param(123, id="int_pk"), + pytest.param({"id": "123"}, id="dict_pk"), + ], +) +def test_source_get_record_accepts_various_pk_formats(pk_value: any) -> None: + """Test get_record() accepts various PK value formats.""" + mock_executor = Mock(spec=DeclarativeExecutor) + mock_executor.fetch_record.return_value = {"id": "123", "name": "Test"} + + source = Source( + executor=mock_executor, + name="test-source", + config={"api_key": "test"}, + ) + source._config_dict = {"api_key": "test"} + + with patch.object(source, "_normalize_and_validate_pk_value", return_value="123"): + result = source.get_record("test_stream", pk_value=pk_value) + + assert result == {"id": "123", "name": "Test"} + mock_executor.fetch_record.assert_called_once() From 1d28249ac9753a2959cdc812223a233f953e2488 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Thu, 20 Nov 2025 19:48:05 +0000 Subject: [PATCH 02/11] style: Fix ruff formatting in test_get_record.py Co-Authored-By: AJ Steers --- tests/unit_tests/test_get_record.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/test_get_record.py b/tests/unit_tests/test_get_record.py index 34969b257..5fbe87afb 100644 --- a/tests/unit_tests/test_get_record.py +++ b/tests/unit_tests/test_get_record.py @@ -86,7 +86,9 @@ def test_source_get_stream_primary_key( mock_catalog = Mock() mock_catalog.streams = [mock_stream] - with patch.object(type(source), "configured_catalog", new_callable=PropertyMock) as mock_prop: + with patch.object( + type(source), "configured_catalog", new_callable=PropertyMock + ) as mock_prop: mock_prop.return_value = mock_catalog result = source._get_stream_primary_key("test_stream") assert result == expected_result @@ -104,7 +106,9 @@ def test_source_get_stream_primary_key_stream_not_found() -> None: mock_catalog = Mock() mock_catalog.streams = [] - with patch.object(type(source), "configured_catalog", new_callable=PropertyMock) 
as mock_prop: + with patch.object( + type(source), "configured_catalog", new_callable=PropertyMock + ) as mock_prop: mock_prop.return_value = mock_catalog with patch.object(source, "get_available_streams", return_value=[]): with pytest.raises(exc.AirbyteStreamNotFoundError): From eadcfb23ed998297d5f38fd9ec9d64e91758ae0c Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Thu, 20 Nov 2025 20:24:59 +0000 Subject: [PATCH 03/11] refactor: Use already-instantiated streams in DeclarativeExecutor.fetch_record() - Remove ModelToComponentFactory usage in favor of accessing existing streams - Add _unwrap_to_declarative_stream() helper to navigate concurrent wrappers - Update fetch_record() to call declarative_source.streams() for existing components - Fix unit tests to mock declarative_source property correctly - Add type ignore comments for duck-typed attribute access Co-Authored-By: AJ Steers --- airbyte/_executors/declarative.py | 122 ++++++++++++++++++++-------- tests/unit_tests/test_get_record.py | 42 ++++++++-- 2 files changed, 123 insertions(+), 41 deletions(-) diff --git a/airbyte/_executors/declarative.py b/airbyte/_executors/declarative.py index ec5cd06ad..e453ee2eb 100644 --- a/airbyte/_executors/declarative.py +++ b/airbyte/_executors/declarative.py @@ -3,6 +3,7 @@ from __future__ import annotations +import contextlib import hashlib import warnings from pathlib import Path @@ -15,9 +16,6 @@ from airbyte_cdk.sources.declarative.concurrent_declarative_source import ( ConcurrentDeclarativeSource, ) -from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ( - ModelToComponentFactory, -) from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever from airbyte_cdk.sources.types import StreamSlice @@ -44,6 +42,53 @@ def _suppress_cdk_pydantic_deprecation_warnings() -> None: ) +def _unwrap_to_declarative_stream(stream: object) -> object: + """Unwrap a concurrent 
stream wrapper to access the underlying declarative stream. + + This function uses duck-typing to navigate through various wrapper layers that may + exist around declarative streams, depending on the CDK version. It tries common + wrapper attribute names and returns the first object that has a 'retriever' attribute. + + Args: + stream: A stream object that may be wrapped (e.g., AbstractStream wrapper). + + Returns: + The underlying declarative stream object with a retriever attribute. + + Raises: + NotImplementedError: If unable to locate a declarative stream with a retriever. + """ + if hasattr(stream, "retriever"): + return stream + + wrapper_attrs = [ + "declarative_stream", + "wrapped_stream", + "stream", + "_stream", + "underlying_stream", + "inner", + ] + + for attr_name in wrapper_attrs: + if hasattr(stream, attr_name): + unwrapped = getattr(stream, attr_name) + if unwrapped is not None and hasattr(unwrapped, "retriever"): + return unwrapped + + for branch_attr in ["full_refresh_stream", "incremental_stream"]: + if hasattr(stream, branch_attr): + branch_stream = getattr(stream, branch_attr) + if branch_stream is not None and hasattr(branch_stream, "retriever"): + return branch_stream + + stream_type = type(stream).__name__ + raise NotImplementedError( + f"Unable to locate declarative stream with retriever from {stream_type}. " + f"fetch_record() requires access to the stream's retriever component." + ) + + class DeclarativeExecutor(Executor): """An executor for declarative sources.""" @@ -147,7 +192,7 @@ def uninstall(self) -> None: """No-op. The declarative source is included with PyAirbyte.""" pass - def fetch_record( # noqa: PLR0914 + def fetch_record( # noqa: PLR0914, PLR0912, PLR0915 self, stream_name: str, primary_key_value: str, @@ -155,7 +200,8 @@ def fetch_record( # noqa: PLR0914 ) -> dict[str, Any]: """Fetch a single record by primary key from a declarative stream. 
- This method constructs an HTTP GET request to fetch a single record by appending + This method uses the already-instantiated streams from the declarative source + to access the stream's retriever and make an HTTP GET request by appending the primary key value to the stream's base path (e.g., /users/123). Args: @@ -167,47 +213,44 @@ def fetch_record( # noqa: PLR0914 The fetched record as a dictionary. Raises: - exc.AirbyteStreamNotFoundError: If the stream is not found in the manifest. + exc.AirbyteStreamNotFoundError: If the stream is not found. exc.AirbyteRecordNotFoundError: If the record is not found (empty response). NotImplementedError: If the stream does not use SimpleRetriever. """ merged_config = {**self._config_dict, **(config or {})} - stream_configs = self._manifest_dict.get("streams", []) - stream_config = None - for config_item in stream_configs: - if config_item.get("name") == stream_name: - stream_config = config_item - break + streams = self.declarative_source.streams(merged_config) - if stream_config is None: - available_streams = [s.get("name") for s in stream_configs] + target_stream = None + for stream in streams: + stream_name_attr = getattr(stream, "name", None) + if stream_name_attr == stream_name: + target_stream = stream + break + try: + unwrapped = _unwrap_to_declarative_stream(stream) + if getattr(unwrapped, "name", None) == stream_name: + target_stream = stream + break + except NotImplementedError: + continue + + if target_stream is None: + available_streams = [] + for s in streams: + name = getattr(s, "name", None) + if name: + available_streams.append(name) raise exc.AirbyteStreamNotFoundError( stream_name=stream_name, connector_name=self.name, available_streams=available_streams, - message=f"Stream '{stream_name}' not found in manifest.", + message=f"Stream '{stream_name}' not found in source.", ) - factory = ModelToComponentFactory() - - retriever_config = stream_config.get("retriever") - if retriever_config is None: - raise 
NotImplementedError( - f"Stream '{stream_name}' does not have a retriever configuration. " - "fetch_record() is only supported for streams with retrievers." - ) + declarative_stream = _unwrap_to_declarative_stream(target_stream) - try: - retriever = factory.create_component( - model_type=type(retriever_config), # type: ignore[arg-type] - component_definition=retriever_config, - config=merged_config, - ) - except Exception as e: - raise NotImplementedError( - f"Failed to create retriever for stream '{stream_name}': {e}" - ) from e + retriever = declarative_stream.retriever # type: ignore[attr-defined] if not isinstance(retriever, SimpleRetriever): raise NotImplementedError( @@ -222,7 +265,10 @@ def fetch_record( # noqa: PLR0914 next_page_token=None, ) - fetch_path = f"{base_path}/{primary_key_value}".lstrip("/") + if base_path: + fetch_path = f"{base_path.rstrip('/')}/{primary_key_value}" + else: + fetch_path = primary_key_value response = retriever.requester.send_request( path=fetch_path, @@ -259,8 +305,12 @@ def fetch_record( # noqa: PLR0914 message=msg, ) - schema = stream_config.get("schema_loader", {}) - records_schema = schema if isinstance(schema, dict) else {} + records_schema = {} + if hasattr(declarative_stream, "schema_loader"): + schema_loader = declarative_stream.schema_loader + if hasattr(schema_loader, "get_json_schema"): + with contextlib.suppress(Exception): + records_schema = schema_loader.get_json_schema() records = list( retriever.record_selector.select_records( diff --git a/tests/unit_tests/test_get_record.py b/tests/unit_tests/test_get_record.py index 5fbe87afb..92bcd4d63 100644 --- a/tests/unit_tests/test_get_record.py +++ b/tests/unit_tests/test_get_record.py @@ -7,6 +7,8 @@ import pytest +from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever + from airbyte import exceptions as exc from airbyte._executors.declarative import DeclarativeExecutor from airbyte.sources.base import Source @@ -48,13 +50,43 @@ def 
test_declarative_executor_fetch_record_stream_validation( manifest=manifest, ) + mock_stream = Mock() + mock_stream.name = "users" + + mock_retriever = Mock(spec=SimpleRetriever) + mock_retriever.requester = Mock() + mock_retriever.requester.get_path = Mock(return_value="/users") + mock_retriever.requester.send_request = Mock( + return_value=Mock(json=lambda: {"id": "123"}) + ) + mock_retriever._request_headers = Mock(return_value={}) + mock_retriever._request_params = Mock(return_value={}) + mock_retriever._request_body_data = Mock(return_value=None) + mock_retriever._request_body_json = Mock(return_value=None) + mock_retriever.record_selector = Mock() + mock_retriever.record_selector.select_records = Mock(return_value=[{"id": "123"}]) + + mock_stream.retriever = mock_retriever + + mock_streams = [mock_stream] if stream_name == "users" else [] + + mock_declarative_source = Mock() + mock_declarative_source.streams = Mock(return_value=mock_streams) + if expected_error: - with pytest.raises(expected_error): - executor.fetch_record(stream_name, pk_value) - else: - with patch.object(executor, "_manifest_dict", manifest): - with pytest.raises((NotImplementedError, AttributeError, KeyError)): + with patch.object( + type(executor), "declarative_source", new_callable=PropertyMock + ) as mock_prop: + mock_prop.return_value = mock_declarative_source + with pytest.raises(expected_error): executor.fetch_record(stream_name, pk_value) + else: + with patch.object( + type(executor), "declarative_source", new_callable=PropertyMock + ) as mock_prop: + mock_prop.return_value = mock_declarative_source + result = executor.fetch_record(stream_name, pk_value) + assert result == {"id": "123"} @pytest.mark.parametrize( From 17f6c9ee085f39cff07fb1fce3d3ae18e6f33e9c Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Thu, 20 Nov 2025 20:36:16 +0000 Subject: [PATCH 04/11] refactor: Remove unnecessary config parameter from fetch_record() 
- Remove config parameter from DeclarativeExecutor.fetch_record() - Remove config argument from Source.get_record() call to fetch_record() - Executor already has full config in self._config_dict, no need to pass it Co-Authored-By: AJ Steers --- airbyte/_executors/declarative.py | 6 +----- airbyte/sources/base.py | 1 - 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/airbyte/_executors/declarative.py b/airbyte/_executors/declarative.py index e453ee2eb..687dd1bf9 100644 --- a/airbyte/_executors/declarative.py +++ b/airbyte/_executors/declarative.py @@ -196,7 +196,6 @@ def fetch_record( # noqa: PLR0914, PLR0912, PLR0915 self, stream_name: str, primary_key_value: str, - config: dict[str, Any] | None = None, ) -> dict[str, Any]: """Fetch a single record by primary key from a declarative stream. @@ -207,7 +206,6 @@ def fetch_record( # noqa: PLR0914, PLR0912, PLR0915 Args: stream_name: The name of the stream to fetch from. primary_key_value: The primary key value as a string. - config: Optional config overrides to merge with the executor's config. Returns: The fetched record as a dictionary. @@ -217,9 +215,7 @@ def fetch_record( # noqa: PLR0914, PLR0912, PLR0915 exc.AirbyteRecordNotFoundError: If the record is not found (empty response). NotImplementedError: If the stream does not use SimpleRetriever. 
""" - merged_config = {**self._config_dict, **(config or {})} - - streams = self.declarative_source.streams(merged_config) + streams = self.declarative_source.streams(self._config_dict) target_stream = None for stream in streams: diff --git a/airbyte/sources/base.py b/airbyte/sources/base.py index 19c8f5555..9facef9ed 100644 --- a/airbyte/sources/base.py +++ b/airbyte/sources/base.py @@ -710,7 +710,6 @@ def get_record( return self.executor.fetch_record( stream_name=stream_name, primary_key_value=pk_value_str, - config=self._config_dict, ) def get_samples( From 42bfd1fbf7f82c73090ead09caf38f9c62894dab Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Thu, 20 Nov 2025 20:42:33 +0000 Subject: [PATCH 05/11] test: Update test to match fetch_record() signature without config parameter The test was expecting the config parameter that was removed in the previous commit. Updated the assertion to match the new signature. Co-Authored-By: AJ Steers --- tests/unit_tests/test_get_record.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit_tests/test_get_record.py b/tests/unit_tests/test_get_record.py index 92bcd4d63..dec6b45c8 100644 --- a/tests/unit_tests/test_get_record.py +++ b/tests/unit_tests/test_get_record.py @@ -238,7 +238,6 @@ def test_source_get_record_calls_executor_fetch_record() -> None: mock_executor.fetch_record.assert_called_once_with( stream_name="test_stream", primary_key_value="123", - config={"api_key": "test"}, ) From cf94b877354d6e50030febe56f5835cec363417f Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Thu, 20 Nov 2025 20:53:04 +0000 Subject: [PATCH 06/11] refactor: Use isinstance() checks and guard statements in fetch_record() - Remove _unwrap_to_declarative_stream() function entirely - Validate streams directly with isinstance(stream, AbstractStream) - Replace getattr() with direct attribute access after validation - Use 
guard statements instead of graceful error handling - Remove contextlib.suppress() around schema loading - Remove JSON fallback when select_records() returns empty - Update test to use Mock(spec=AbstractStream) for proper type checking This addresses review feedback to use strong typing and fail-fast error handling instead of duck-typing and graceful degradation. Co-Authored-By: AJ Steers --- airbyte/_executors/declarative.py | 108 ++++++---------------------- tests/unit_tests/test_get_record.py | 3 +- 2 files changed, 24 insertions(+), 87 deletions(-) diff --git a/airbyte/_executors/declarative.py b/airbyte/_executors/declarative.py index 687dd1bf9..ee38cfe8b 100644 --- a/airbyte/_executors/declarative.py +++ b/airbyte/_executors/declarative.py @@ -3,7 +3,6 @@ from __future__ import annotations -import contextlib import hashlib import warnings from pathlib import Path @@ -17,6 +16,7 @@ ConcurrentDeclarativeSource, ) from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever +from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream from airbyte_cdk.sources.types import StreamSlice from airbyte import exceptions as exc @@ -42,53 +42,6 @@ def _suppress_cdk_pydantic_deprecation_warnings() -> None: ) -def _unwrap_to_declarative_stream(stream: object) -> object: - """Unwrap a concurrent stream wrapper to access the underlying declarative stream. - - This function uses duck-typing to navigate through various wrapper layers that may - exist around declarative streams, depending on the CDK version. It tries common - wrapper attribute names and returns the first object that has a 'retriever' attribute. - - Args: - stream: A stream object that may be wrapped (e.g., AbstractStream wrapper). - - Returns: - The underlying declarative stream object with a retriever attribute. - - Raises: - NotImplementedError: If unable to locate a declarative stream with a retriever. 
- """ - if hasattr(stream, "retriever"): - return stream - - wrapper_attrs = [ - "declarative_stream", - "wrapped_stream", - "stream", - "_stream", - "underlying_stream", - "inner", - ] - - for attr_name in wrapper_attrs: - if hasattr(stream, attr_name): - unwrapped = getattr(stream, attr_name) - if unwrapped is not None and hasattr(unwrapped, "retriever"): - return unwrapped - - for branch_attr in ["full_refresh_stream", "incremental_stream"]: - if hasattr(stream, branch_attr): - branch_stream = getattr(stream, branch_attr) - if branch_stream is not None and hasattr(branch_stream, "retriever"): - return branch_stream - - stream_type = type(stream).__name__ - raise NotImplementedError( - f"Unable to locate declarative stream with retriever from {stream_type}. " - f"fetch_record() requires access to the stream's retriever component." - ) - - class DeclarativeExecutor(Executor): """An executor for declarative sources.""" @@ -192,7 +145,7 @@ def uninstall(self) -> None: """No-op. The declarative source is included with PyAirbyte.""" pass - def fetch_record( # noqa: PLR0914, PLR0912, PLR0915 + def fetch_record( self, stream_name: str, primary_key_value: str, @@ -219,24 +172,14 @@ def fetch_record( # noqa: PLR0914, PLR0912, PLR0915 target_stream = None for stream in streams: - stream_name_attr = getattr(stream, "name", None) - if stream_name_attr == stream_name: + if not isinstance(stream, AbstractStream): + continue + if stream.name == stream_name: target_stream = stream break - try: - unwrapped = _unwrap_to_declarative_stream(stream) - if getattr(unwrapped, "name", None) == stream_name: - target_stream = stream - break - except NotImplementedError: - continue if target_stream is None: - available_streams = [] - for s in streams: - name = getattr(s, "name", None) - if name: - available_streams.append(name) + available_streams = [s.name for s in streams if isinstance(s, AbstractStream)] raise exc.AirbyteStreamNotFoundError( stream_name=stream_name, 
connector_name=self.name, @@ -244,10 +187,15 @@ def fetch_record( # noqa: PLR0914, PLR0912, PLR0915 message=f"Stream '{stream_name}' not found in source.", ) - declarative_stream = _unwrap_to_declarative_stream(target_stream) + if not hasattr(target_stream, "retriever"): + raise NotImplementedError( + f"Stream '{stream_name}' does not have a retriever attribute. " + f"fetch_record() requires access to the stream's retriever component." + ) - retriever = declarative_stream.retriever # type: ignore[attr-defined] + retriever = target_stream.retriever + # Guard: Retriever must be SimpleRetriever if not isinstance(retriever, SimpleRetriever): raise NotImplementedError( f"Stream '{stream_name}' uses {type(retriever).__name__}, but fetch_record() " @@ -289,24 +237,21 @@ def fetch_record( # noqa: PLR0914, PLR0912, PLR0915 ), ) + # Guard: Response must not be None if response is None: - msg = ( - f"No response received when fetching record with primary key " - f"'{primary_key_value}' from stream '{stream_name}'." 
- ) raise exc.AirbyteRecordNotFoundError( stream_name=stream_name, primary_key_value=primary_key_value, connector_name=self.name, - message=msg, + message=f"No response received when fetching record with primary key " + f"'{primary_key_value}' from stream '{stream_name}'.", ) records_schema = {} - if hasattr(declarative_stream, "schema_loader"): - schema_loader = declarative_stream.schema_loader + if hasattr(target_stream, "schema_loader"): + schema_loader = target_stream.schema_loader if hasattr(schema_loader, "get_json_schema"): - with contextlib.suppress(Exception): - records_schema = schema_loader.get_json_schema() + records_schema = schema_loader.get_json_schema() records = list( retriever.record_selector.select_records( @@ -318,23 +263,14 @@ def fetch_record( # noqa: PLR0914, PLR0912, PLR0915 ) ) + # Guard: Records must not be empty if not records: - try: - response_json = response.json() - if isinstance(response_json, dict) and response_json: - return response_json - except Exception: - pass - - msg = ( - f"Record with primary key '{primary_key_value}' " - f"not found in stream '{stream_name}'." 
- ) raise exc.AirbyteRecordNotFoundError( stream_name=stream_name, primary_key_value=primary_key_value, connector_name=self.name, - message=msg, + message=f"Record with primary key '{primary_key_value}' " + f"not found in stream '{stream_name}'.", ) first_record = records[0] diff --git a/tests/unit_tests/test_get_record.py b/tests/unit_tests/test_get_record.py index dec6b45c8..1c1880568 100644 --- a/tests/unit_tests/test_get_record.py +++ b/tests/unit_tests/test_get_record.py @@ -8,6 +8,7 @@ import pytest from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever +from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream from airbyte import exceptions as exc from airbyte._executors.declarative import DeclarativeExecutor @@ -50,7 +51,7 @@ def test_declarative_executor_fetch_record_stream_validation( manifest=manifest, ) - mock_stream = Mock() + mock_stream = Mock(spec=AbstractStream) mock_stream.name = "users" mock_retriever = Mock(spec=SimpleRetriever) From 8028397fa94c9bbdc19ed8f35853c0f1c4cef2a7 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Thu, 20 Nov 2025 20:58:31 +0000 Subject: [PATCH 07/11] fix: Check stream name before type validation in fetch_record() Address review feedback to validate stream name match first, then check type compatibility. This ensures we raise NotImplementedError for found streams of incompatible types rather than silently skipping them. 
- Match stream by name first - Validate AbstractStream type after name match - Raise NotImplementedError with clear message for incompatible types - Remove isinstance() filter from available_streams list Co-Authored-By: AJ Steers --- airbyte/_executors/declarative.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/airbyte/_executors/declarative.py b/airbyte/_executors/declarative.py index ee38cfe8b..9240f9f51 100644 --- a/airbyte/_executors/declarative.py +++ b/airbyte/_executors/declarative.py @@ -172,14 +172,17 @@ def fetch_record( target_stream = None for stream in streams: - if not isinstance(stream, AbstractStream): - continue if stream.name == stream_name: + if not isinstance(stream, AbstractStream): + raise NotImplementedError( + f"Stream '{stream_name}' is type {type(stream).__name__}; " + "fetch_record() supports only AbstractStream." + ) target_stream = stream break if target_stream is None: - available_streams = [s.name for s in streams if isinstance(s, AbstractStream)] + available_streams = [s.name for s in streams] raise exc.AirbyteStreamNotFoundError( stream_name=stream_name, connector_name=self.name, From a2c4b42731d9be5ce9d6d38c6fd0a0747fb22734 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Thu, 20 Nov 2025 21:30:28 +0000 Subject: [PATCH 08/11] refactor: Use CatalogProvider for primary keys and add scanning fallback to get_record() - Refactor _get_stream_primary_key() to use CatalogProvider.get_primary_keys() - Handle both flat and nested primary key formats - Add allow_scanning and scan_timeout_seconds parameters to get_record() - Implement scanning fallback when direct fetch fails or for non-declarative sources - Normalize dict keys with LowerCaseNormalizer for case-insensitive comparison - Defer PK validation to avoid early catalog access on non-declarative executors Addresses GitHub review comments from @aaronsteers on PR #872 Co-Authored-By: AJ Steers --- 
airbyte/sources/base.py | 79 ++++++++++++++++++++++++++++++++++------- 1 file changed, 67 insertions(+), 12 deletions(-) diff --git a/airbyte/sources/base.py b/airbyte/sources/base.py index 9facef9ed..e8151e6b6 100644 --- a/airbyte/sources/base.py +++ b/airbyte/sources/base.py @@ -32,6 +32,7 @@ from airbyte._connector_base import ConnectorBase from airbyte._executors.declarative import DeclarativeExecutor from airbyte._message_iterators import AirbyteMessageIterator +from airbyte._util.name_normalizers import LowerCaseNormalizer from airbyte._util.temp_files import as_temp_files from airbyte.caches.util import get_default_cache from airbyte.datasets._lazy import LazyDataset @@ -606,7 +607,7 @@ def _get_stream_primary_key(self, stream_name: str) -> list[str]: """Get the primary key for a stream. Returns the primary key as a flat list of field names. - Handles the Airbyte protocol's nested list structure. + Uses CatalogProvider to handle the Airbyte protocol's nested list structure. """ catalog = self.configured_catalog for configured_stream in catalog.streams: @@ -614,11 +615,22 @@ def _get_stream_primary_key(self, stream_name: str) -> list[str]: pk = configured_stream.primary_key if not pk: return [] - if isinstance(pk, list) and len(pk) > 0: - if isinstance(pk[0], list): - return [field[0] if isinstance(field, list) else field for field in pk] - return list(pk) # type: ignore[arg-type] - return [] + + # Normalize flat format to nested format for CatalogProvider + if isinstance(pk, list) and len(pk) > 0 and not isinstance(pk[0], list): + pk = [[field] for field in pk] + + temp_stream = type(configured_stream)( + stream=configured_stream.stream, + sync_mode=configured_stream.sync_mode, + destination_sync_mode=configured_stream.destination_sync_mode, + primary_key=pk, + cursor_field=configured_stream.cursor_field, + ) + temp_catalog = type(catalog)(streams=[temp_stream]) + catalog_provider = CatalogProvider(temp_catalog) + return 
catalog_provider.get_primary_keys(stream_name) + raise exc.AirbyteStreamNotFoundError( stream_name=stream_name, connector_name=self.name, @@ -661,7 +673,8 @@ def _normalize_and_validate_pk_value( input_value=str(pk_value), ) provided_key = next(iter(pk_value.keys())) - if provided_key != pk_field: + normalized_provided_key = LowerCaseNormalizer.normalize(provided_key) + if normalized_provided_key != pk_field: msg = ( f"Primary key field '{provided_key}' does not match " f"stream's primary key '{pk_field}'." @@ -679,6 +692,8 @@ def get_record( stream_name: str, *, pk_value: Any, # noqa: ANN401 + allow_scanning: bool = False, + scan_timeout_seconds: int = 5, ) -> dict[str, Any]: """Fetch a single record by primary key value. @@ -689,6 +704,8 @@ def get_record( pk_value: The primary key value. Can be: - A string or integer value (e.g., "123" or 123) - A dict with a single entry (e.g., {"id": "123"}) + allow_scanning: If True, fall back to scanning the stream if direct fetch fails. + scan_timeout_seconds: Maximum time to spend scanning for the record. Returns: The fetched record as a dictionary. @@ -699,17 +716,55 @@ def get_record( exc.PyAirbyteInputError: If the pk_value format is invalid. NotImplementedError: If the source is not declarative or uses composite keys. """ - if not isinstance(self.executor, DeclarativeExecutor): + if isinstance(self.executor, DeclarativeExecutor): + pk_value_str = self._normalize_and_validate_pk_value(stream_name, pk_value) + try: + return self.executor.fetch_record( + stream_name=stream_name, + primary_key_value=pk_value_str, + ) + except (NotImplementedError, exc.AirbyteRecordNotFoundError) as e: + if not allow_scanning: + raise + scan_reason = type(e).__name__ + + elif not allow_scanning: raise NotImplementedError( - f"get_record() is only supported for declarative sources. " - f"This source uses {type(self.executor).__name__}." + f"get_record() direct fetch is only supported for declarative sources. 
" + f"This source uses {type(self.executor).__name__}. " + f"Set allow_scanning=True to enable scanning fallback." ) + else: + scan_reason = "non-declarative source" pk_value_str = self._normalize_and_validate_pk_value(stream_name, pk_value) + primary_key_fields = self._get_stream_primary_key(stream_name) + pk_field = primary_key_fields[0] + + start_time = time.monotonic() + for record in self.get_records(stream_name): + if time.monotonic() - start_time > scan_timeout_seconds: + raise exc.AirbyteRecordNotFoundError( + stream_name=stream_name, + context={ + "primary_key_field": pk_field, + "primary_key_value": pk_value_str, + "scan_timeout_seconds": scan_timeout_seconds, + "scan_reason": scan_reason, + }, + ) + + record_data = record if isinstance(record, dict) else record.data + if str(record_data.get(pk_field)) == pk_value_str: + return record_data - return self.executor.fetch_record( + raise exc.AirbyteRecordNotFoundError( stream_name=stream_name, - primary_key_value=pk_value_str, + context={ + "primary_key_field": pk_field, + "primary_key_value": pk_value_str, + "scan_reason": scan_reason, + }, ) def get_samples( From 8cdfb84bab1385a99b8461a049ed30504945b439 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Thu, 20 Nov 2025 22:20:49 +0000 Subject: [PATCH 09/11] fix: Change lowercase 'any' to 'Any' in test type annotations - Add 'from typing import Any' import - Fix type annotation on line 183: pk_value: any -> pk_value: Any - Fix type annotation on line 254: pk_value: any -> pk_value: Any Addresses GitHub review comment #2547624128 Co-Authored-By: AJ Steers --- tests/unit_tests/test_get_record.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/test_get_record.py b/tests/unit_tests/test_get_record.py index 1c1880568..6c403e955 100644 --- a/tests/unit_tests/test_get_record.py +++ b/tests/unit_tests/test_get_record.py @@ -3,6 +3,7 @@ from __future__ import annotations 
+from typing import Any from unittest.mock import Mock, PropertyMock, patch import pytest @@ -179,7 +180,7 @@ def test_source_get_stream_primary_key_stream_not_found() -> None: ], ) def test_source_normalize_and_validate_pk_value( - pk_value: any, + pk_value: Any, primary_key_fields: list[str], expected_result: str | None, expected_error: type[Exception] | None, @@ -250,7 +251,7 @@ def test_source_get_record_calls_executor_fetch_record() -> None: pytest.param({"id": "123"}, id="dict_pk"), ], ) -def test_source_get_record_accepts_various_pk_formats(pk_value: any) -> None: +def test_source_get_record_accepts_various_pk_formats(pk_value: Any) -> None: """Test get_record() accepts various PK value formats.""" mock_executor = Mock(spec=DeclarativeExecutor) mock_executor.fetch_record.return_value = {"id": "123", "name": "Test"} From d418b16df7c5fe2963e660c8bd726c2f0edf13ce Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Thu, 20 Nov 2025 22:27:29 +0000 Subject: [PATCH 10/11] refactor: Replace _get_stream_primary_key with catalog_provider property - Add catalog_provider property to Source class that returns CatalogProvider instance - Remove _get_stream_primary_key() helper method (lines 606-638) - Replace both usages with self.catalog_provider.get_primary_keys() - Update test to patch catalog_provider property instead of private method - Remove tests for deleted private method (test_source_get_stream_primary_key) This simplifies the code by using the existing CatalogProvider utility directly instead of maintaining duplicate primary key extraction logic. 
Co-Authored-By: AJ Steers --- airbyte/sources/base.py | 43 +++--------------- tests/unit_tests/test_get_record.py | 69 ++++------------------------- 2 files changed, 16 insertions(+), 96 deletions(-) diff --git a/airbyte/sources/base.py b/airbyte/sources/base.py index e8151e6b6..82891797e 100644 --- a/airbyte/sources/base.py +++ b/airbyte/sources/base.py @@ -419,6 +419,11 @@ def configured_catalog(self) -> ConfiguredAirbyteCatalog: streams_filter: list[str] = self._selected_stream_names or self.get_available_streams() return self.get_configured_catalog(streams=streams_filter) + @property + def catalog_provider(self) -> CatalogProvider: + """Return a catalog provider for this source.""" + return CatalogProvider(self.configured_catalog) + def get_configured_catalog( self, streams: Literal["*"] | list[str] | None = None, @@ -603,40 +608,6 @@ def get_documents( render_metadata=render_metadata, ) - def _get_stream_primary_key(self, stream_name: str) -> list[str]: - """Get the primary key for a stream. - - Returns the primary key as a flat list of field names. - Uses CatalogProvider to handle the Airbyte protocol's nested list structure. 
- """ - catalog = self.configured_catalog - for configured_stream in catalog.streams: - if configured_stream.stream.name == stream_name: - pk = configured_stream.primary_key - if not pk: - return [] - - # Normalize flat format to nested format for CatalogProvider - if isinstance(pk, list) and len(pk) > 0 and not isinstance(pk[0], list): - pk = [[field] for field in pk] - - temp_stream = type(configured_stream)( - stream=configured_stream.stream, - sync_mode=configured_stream.sync_mode, - destination_sync_mode=configured_stream.destination_sync_mode, - primary_key=pk, - cursor_field=configured_stream.cursor_field, - ) - temp_catalog = type(catalog)(streams=[temp_stream]) - catalog_provider = CatalogProvider(temp_catalog) - return catalog_provider.get_primary_keys(stream_name) - - raise exc.AirbyteStreamNotFoundError( - stream_name=stream_name, - connector_name=self.name, - available_streams=self.get_available_streams(), - ) - def _normalize_and_validate_pk_value( self, stream_name: str, @@ -650,7 +621,7 @@ def _normalize_and_validate_pk_value( Returns the PK value as a string. 
""" - primary_key_fields = self._get_stream_primary_key(stream_name) + primary_key_fields = self.catalog_provider.get_primary_keys(stream_name) if not primary_key_fields: raise exc.PyAirbyteInputError( @@ -738,7 +709,7 @@ def get_record( scan_reason = "non-declarative source" pk_value_str = self._normalize_and_validate_pk_value(stream_name, pk_value) - primary_key_fields = self._get_stream_primary_key(stream_name) + primary_key_fields = self.catalog_provider.get_primary_keys(stream_name) pk_field = primary_key_fields[0] start_time = time.monotonic() diff --git a/tests/unit_tests/test_get_record.py b/tests/unit_tests/test_get_record.py index 6c403e955..277088c60 100644 --- a/tests/unit_tests/test_get_record.py +++ b/tests/unit_tests/test_get_record.py @@ -91,64 +91,6 @@ def test_declarative_executor_fetch_record_stream_validation( assert result == {"id": "123"} -@pytest.mark.parametrize( - "primary_key,expected_result", - [ - pytest.param([["id"]], ["id"], id="nested_single_field"), - pytest.param(["id"], ["id"], id="flat_single_field"), - pytest.param([["id"], ["org_id"]], ["id", "org_id"], id="nested_composite"), - pytest.param([], [], id="no_primary_key"), - pytest.param(None, [], id="none_primary_key"), - ], -) -def test_source_get_stream_primary_key( - primary_key: list | None, - expected_result: list[str], -) -> None: - """Test _get_stream_primary_key() handles various PK formats.""" - mock_executor = Mock() - source = Source( - executor=mock_executor, - name="test-source", - config={"api_key": "test"}, - ) - - mock_stream = Mock() - mock_stream.stream.name = "test_stream" - mock_stream.primary_key = primary_key - - mock_catalog = Mock() - mock_catalog.streams = [mock_stream] - - with patch.object( - type(source), "configured_catalog", new_callable=PropertyMock - ) as mock_prop: - mock_prop.return_value = mock_catalog - result = source._get_stream_primary_key("test_stream") - assert result == expected_result - - -def 
test_source_get_stream_primary_key_stream_not_found() -> None: - """Test _get_stream_primary_key() raises error for nonexistent stream.""" - mock_executor = Mock() - source = Source( - executor=mock_executor, - name="test-source", - config={"api_key": "test"}, - ) - - mock_catalog = Mock() - mock_catalog.streams = [] - - with patch.object( - type(source), "configured_catalog", new_callable=PropertyMock - ) as mock_prop: - mock_prop.return_value = mock_catalog - with patch.object(source, "get_available_streams", return_value=[]): - with pytest.raises(exc.AirbyteStreamNotFoundError): - source._get_stream_primary_key("nonexistent_stream") - - @pytest.mark.parametrize( "pk_value,primary_key_fields,expected_result,expected_error", [ @@ -186,6 +128,8 @@ def test_source_normalize_and_validate_pk_value( expected_error: type[Exception] | None, ) -> None: """Test _normalize_and_validate_pk_value() handles various input formats.""" + from airbyte.shared.catalog_providers import CatalogProvider + mock_executor = Mock() source = Source( executor=mock_executor, @@ -193,9 +137,14 @@ def test_source_normalize_and_validate_pk_value( config={"api_key": "test"}, ) + mock_catalog_provider = Mock(spec=CatalogProvider) + mock_catalog_provider.get_primary_keys.return_value = primary_key_fields + with patch.object( - source, "_get_stream_primary_key", return_value=primary_key_fields - ): + type(source), "catalog_provider", new_callable=PropertyMock + ) as mock_provider_prop: + mock_provider_prop.return_value = mock_catalog_provider + if expected_error: with pytest.raises(expected_error): source._normalize_and_validate_pk_value("test_stream", pk_value) From ba85aaf0b9d8787a9526a539592f4396a9f69dd5 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Fri, 5 Dec 2025 22:21:31 +0000 Subject: [PATCH 11/11] feat: Add get_source_record MCP tool for singleton record fetching Co-Authored-By: AJ Steers --- airbyte/mcp/local_ops.py | 118 
+++++++++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) diff --git a/airbyte/mcp/local_ops.py b/airbyte/mcp/local_ops.py index df018a2b4..7ba355690 100644 --- a/airbyte/mcp/local_ops.py +++ b/airbyte/mcp/local_ops.py @@ -461,6 +461,124 @@ def read_source_stream_records( return records +@mcp_tool( + domain="local", + read_only=True, + idempotent=True, + extra_help_text=_CONFIG_HELP, +) +def get_source_record( # noqa: PLR0913, PLR0917 + source_connector_name: Annotated[ + str, + Field(description="The name of the source connector."), + ], + stream_name: Annotated[ + str, + Field(description="The name of the stream to fetch the record from."), + ], + pk_value: Annotated[ + str | int | dict[str, Any], + Field( + description=( + "The primary key value to fetch. " + "Can be a string, int, or dict with PK field name(s) as keys." + ) + ), + ], + config: Annotated[ + dict | str | None, + Field( + description="The configuration for the source connector as a dict or JSON string.", + default=None, + ), + ], + config_file: Annotated[ + str | Path | None, + Field( + description="Path to a YAML or JSON file containing the source connector config.", + default=None, + ), + ], + config_secret_name: Annotated[ + str | None, + Field( + description="The name of the secret containing the configuration.", + default=None, + ), + ], + override_execution_mode: Annotated[ + Literal["docker", "python", "yaml", "auto"], + Field( + description="Optionally override the execution method to use for the connector. 
" + "This parameter is ignored if manifest_path is provided (yaml mode will be used).", + default="auto", + ), + ], + manifest_path: Annotated[ + str | Path | None, + Field( + description="Path to a local YAML manifest file for declarative connectors.", + default=None, + ), + ], + allow_scanning: Annotated[ + bool, + Field( + description="If True, fall back to scanning stream records if direct fetch fails.", + default=False, + ), + ], + scan_timeout_seconds: Annotated[ + int, + Field( + description="Maximum time in seconds to spend scanning for the record.", + default=60, + ), + ], +) -> dict[str, Any] | str: + """Fetch a single record from a source connector by primary key value. + + This operation requires a valid configuration. Direct fetch is only + supported for declarative (YAML-based) sources with SimpleRetriever-based + streams, where it constructs the appropriate API request. If + allow_scanning is True and direct fetch fails or is unsupported (including + non-declarative sources), it falls back to scanning stream records. + """ + try: + source: Source = _get_mcp_source( + connector_name=source_connector_name, + override_execution_mode=override_execution_mode, + manifest_path=manifest_path, + ) + config_dict = resolve_config( + config=config, + config_file=config_file, + config_secret_name=config_secret_name, + config_spec_jsonschema=source.config_spec, + ) + source.set_config(config_dict) + + record = source.get_record( + stream_name=stream_name, + pk_value=pk_value, + allow_scanning=allow_scanning, + scan_timeout_seconds=scan_timeout_seconds, + ) + + print( + f"Retrieved record from stream '{stream_name}' with pk_value={pk_value!r}", + file=sys.stderr, + ) + + except Exception as ex: + tb_str = traceback.format_exc() + return ( + f"Error fetching record from source '{source_connector_name}': {ex!r}, {ex!s}\n{tb_str}" + ) + else: + return record + + @mcp_tool( domain="local", read_only=True,