diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py index 2bd7d268d..fbecb5a3a 100644 --- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py @@ -535,6 +535,7 @@ ) from airbyte_cdk.sources.declarative.retrievers import ( AsyncRetriever, + ClientSideIncrementalRetrieverDecorator, LazySimpleRetriever, SimpleRetriever, ) @@ -2077,6 +2078,7 @@ def create_default_stream( else concurrent_cursor ) + is_client_side_incremental = self._is_client_side_filtering_enabled(model) retriever = self._create_component_from_model( model=model.retriever, config=config, @@ -2086,7 +2088,7 @@ def create_default_stream( stream_slicer=stream_slicer, partition_router=partition_router, has_stop_condition_cursor=self._is_stop_condition_on_cursor(model), - is_client_side_incremental_sync=self._is_client_side_filtering_enabled(model), + is_client_side_incremental_sync=is_client_side_incremental, cursor=concurrent_cursor, transformations=transformations, file_uploader=file_uploader, @@ -2094,6 +2096,15 @@ def create_default_stream( ) if isinstance(retriever, AsyncRetriever): stream_slicer = retriever.stream_slicer + elif ( + is_client_side_incremental + and not isinstance(retriever, SimpleRetriever) + and not isinstance(concurrent_cursor, FinalStateCursor) + ): + retriever = ClientSideIncrementalRetrieverDecorator( + retriever=retriever, + cursor=concurrent_cursor, + ) schema_loader: SchemaLoader if model.schema_loader and isinstance(model.schema_loader, list): diff --git a/airbyte_cdk/sources/declarative/retrievers/__init__.py b/airbyte_cdk/sources/declarative/retrievers/__init__.py index 7349efcd5..77acdd747 100644 --- a/airbyte_cdk/sources/declarative/retrievers/__init__.py +++ b/airbyte_cdk/sources/declarative/retrievers/__init__.py @@ -3,6 +3,9 @@ # from 
class ClientSideIncrementalRetrieverDecorator(Retriever):
    """
    Retriever wrapper that drops records the cursor says should not be synced.

    Every item produced by the wrapped retriever is inspected:

    * ``Record`` instances are handed to ``cursor.should_be_synced`` as-is.
    * Other ``Mapping`` items are wrapped in a temporary ``Record`` (bound to the
      requested slice, with an empty stream name) purely for the cursor check;
      the original mapping object is what gets emitted.
    * Anything else is forwarded untouched.

    Attributes:
        retriever: The underlying retriever whose output is filtered
        cursor: The cursor consulted through ``should_be_synced``
    """

    def __init__(
        self,
        retriever: Retriever,
        cursor: Cursor,
    ):
        self._cursor = cursor
        self._retriever = retriever

    def read_records(
        self,
        records_schema: Mapping[str, Any],
        stream_slice: Optional[StreamSlice] = None,
    ) -> Iterable[StreamData]:
        """
        Lazily read from the wrapped retriever and yield only the items that pass
        the cursor check (non-mapping items are always yielded).

        Args:
            records_schema: Forwarded unchanged to the wrapped retriever.
            stream_slice: Forwarded unchanged to the wrapped retriever; also used as
                the associated slice of the temporary ``Record`` built for plain mappings.
        """
        upstream = self._retriever.read_records(
            records_schema=records_schema,
            stream_slice=stream_slice,
        )
        for item in upstream:
            # Items that are neither Record nor Mapping carry no cursor value to
            # evaluate, so they bypass filtering entirely.
            if not isinstance(item, (Record, Mapping)):
                yield item
                continue

            if isinstance(item, Record):
                candidate = item
            else:
                # The wrapper exists only so the cursor can evaluate the mapping;
                # the caller still receives the original object.
                candidate = Record(data=item, associated_slice=stream_slice, stream_name="")

            if self._cursor.should_be_synced(candidate):
                yield item
DATE_FORMAT = "%Y-%m-%d"


def _build_cursor(stream_state: dict[str, Any], end: datetime) -> ConcurrentCursor:
    """Build a ConcurrentCursor on ``created_at`` starting 2021-01-01 and ending at ``end``."""
    return ConcurrentCursor(
        stream_name="test_stream",
        stream_namespace=None,
        stream_state=stream_state,
        message_repository=Mock(),
        connector_state_manager=Mock(),
        connector_state_converter=CustomFormatConcurrentStreamStateConverter(
            datetime_format=DATE_FORMAT
        ),
        cursor_field=CursorField("created_at"),
        slice_boundary_fields=("start", "end"),
        start=datetime(2021, 1, 1, tzinfo=timezone.utc),
        end_provider=lambda: end,
        slice_range=timedelta(days=365 * 10),
    )


class MockRetriever(Retriever):
    """Retriever stub that replays a fixed list of records."""

    def __init__(self, records: list[dict[str, Any]]):
        self._records = records

    def read_records(
        self,
        records_schema: dict[str, Any],
        stream_slice: StreamSlice | None = None,
    ):
        yield from self._records


@pytest.fixture
def cursor_with_state():
    """Cursor whose stream state is set to 2021-01-03."""
    return _build_cursor(
        {"created_at": "2021-01-03"},
        datetime(2021, 1, 10, tzinfo=timezone.utc),
    )


@pytest.fixture
def cursor_without_state():
    """Cursor created with an empty stream state."""
    return _build_cursor({}, datetime(2021, 1, 10, tzinfo=timezone.utc))


@pytest.mark.parametrize(
    "records,cursor_state,expected_ids",
    [
        pytest.param(
            [
                {"id": 1, "created_at": "2020-01-03"},
                {"id": 2, "created_at": "2021-01-03"},
                {"id": 3, "created_at": "2021-01-04"},
                {"id": 4, "created_at": "2021-02-01"},
            ],
            {"created_at": "2021-01-03"},
            [2, 3, 4],
            id="filters_records_older_than_cursor_state",
        ),
        pytest.param(
            [
                {"id": 1, "created_at": "2020-01-03"},
                {"id": 2, "created_at": "2021-01-03"},
                {"id": 3, "created_at": "2021-01-04"},
            ],
            {},
            [2, 3],
            id="no_state_uses_start_date_for_filtering",
        ),
        pytest.param(
            [],
            {"created_at": "2021-01-03"},
            [],
            id="empty_records_returns_empty",
        ),
    ],
)
def test_client_side_incremental_retriever_decorator_with_dict_records(
    records: list[dict[str, Any]],
    cursor_state: dict[str, Any],
    expected_ids: list[int],
):
    """Plain dict records are filtered through the cursor."""
    # A later end bound than the fixtures use, so the 2021-02-01 record stays in range.
    cursor = _build_cursor(cursor_state, datetime(2021, 12, 31, tzinfo=timezone.utc))
    decorator = ClientSideIncrementalRetrieverDecorator(
        retriever=MockRetriever(records),
        cursor=cursor,
    )

    emitted = list(
        decorator.read_records(
            records_schema={},
            stream_slice=StreamSlice(partition={}, cursor_slice={}),
        )
    )

    assert [r["id"] for r in emitted] == expected_ids


def test_client_side_incremental_retriever_decorator_with_record_objects(
    cursor_with_state,
):
    """Record instances are filtered through the cursor."""
    stream_slice = StreamSlice(partition={}, cursor_slice={})
    rows = [
        {"id": 1, "created_at": "2020-01-03"},
        {"id": 2, "created_at": "2021-01-03"},
        {"id": 3, "created_at": "2021-01-04"},
    ]
    records = [
        Record(data=row, associated_slice=stream_slice, stream_name="test_stream")
        for row in rows
    ]

    class MockRetrieverWithRecords(Retriever):
        def read_records(self, records_schema, stream_slice=None):
            yield from records

    decorator = ClientSideIncrementalRetrieverDecorator(
        retriever=MockRetrieverWithRecords(),
        cursor=cursor_with_state,
    )

    emitted = list(decorator.read_records(records_schema={}, stream_slice=stream_slice))

    assert [r["id"] for r in emitted] == [2, 3]


def test_client_side_incremental_retriever_decorator_passes_through_non_record_data(
    cursor_with_state,
):
    """Items that are neither dicts nor Records come out unchanged."""
    mixed_items = ["some_string", 123, {"id": 1, "created_at": "2021-01-04"}]

    class MockRetrieverWithMixedData(Retriever):
        def read_records(self, records_schema, stream_slice=None):
            yield "some_string"
            yield 123
            yield {"id": 1, "created_at": "2021-01-04"}

    decorator = ClientSideIncrementalRetrieverDecorator(
        retriever=MockRetrieverWithMixedData(),
        cursor=cursor_with_state,
    )

    emitted = list(
        decorator.read_records(
            records_schema={},
            stream_slice=StreamSlice(partition={}, cursor_slice={}),
        )
    )

    assert emitted == mixed_items