From 0bba64ab90d0296a893ae3a171894adafc348e83 Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Wed, 28 Jan 2026 22:15:47 +0000 Subject: [PATCH 1/2] Add `response_cache` argument for HTTP-based crawlers for caching HTTP responses --- .../http_crawlers/selectolax_crawler.py | 6 +- .../_abstract_http/_abstract_http_crawler.py | 81 +++++++++- src/crawlee/crawlers/_abstract_http/_types.py | 26 +++ .../crawlers/_basic/_context_pipeline.py | 65 +++++++- .../_beautifulsoup/_beautifulsoup_crawler.py | 2 +- src/crawlee/crawlers/_http/_http_crawler.py | 2 +- .../crawlers/_parsel/_parsel_crawler.py | 2 +- .../crawlers/_basic/test_context_pipeline.py | 148 ++++++++++++++++++ .../unit/crawlers/_http/test_http_crawler.py | 54 ++++++- 9 files changed, 368 insertions(+), 18 deletions(-) create mode 100644 src/crawlee/crawlers/_abstract_http/_types.py diff --git a/docs/guides/code_examples/http_crawlers/selectolax_crawler.py b/docs/guides/code_examples/http_crawlers/selectolax_crawler.py index 677a6a3b00..b3cba91ce6 100644 --- a/docs/guides/code_examples/http_crawlers/selectolax_crawler.py +++ b/docs/guides/code_examples/http_crawlers/selectolax_crawler.py @@ -38,9 +38,9 @@ async def final_step( yield SelectolaxLexborContext.from_parsed_http_crawling_context(context) # Build context pipeline: HTTP request -> parsing -> custom context. - kwargs['_context_pipeline'] = ( - self._create_static_content_crawler_pipeline().compose(final_step) - ) + kwargs['_context_pipeline'] = self._create_static_content_crawler_pipeline( + **kwargs + ).compose(final_step) super().__init__( parser=SelectolaxLexborParser(), **kwargs, diff --git a/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py b/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py index db0ef366c8..d8c9d29d39 100644 --- a/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py +++ b/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py @@ -2,10 +2,16 @@ import asyncio import logging +import sys from abc import ABC from datetime import timedelta from typing import TYPE_CHECKING, Any, Generic +if sys.version_info >= (3, 14): + from compression import zstd as _compressor +else: + import zlib as _compressor + from more_itertools import partition from pydantic import ValidationError from typing_extensions import NotRequired, TypeVar @@ -19,6 +25,7 @@ from crawlee.statistics import StatisticsState from ._http_crawling_context import HttpCrawlingContext, ParsedHttpCrawlingContext, TParseResult, TSelectResult +from ._types import CachedHttpResponse if TYPE_CHECKING: from collections.abc import AsyncGenerator, Awaitable, Callable, Iterator @@ -27,6 +34,7 @@ from crawlee import RequestTransformAction from crawlee._types import BasicCrawlingContext, EnqueueLinksKwargs, ExtractLinksFunction + from crawlee.storages import KeyValueStore from ._abstract_http_parser import AbstractHttpParser @@ -46,6 +54,9 @@ class HttpCrawlerOptions( navigation_timeout: NotRequired[timedelta | None] """Timeout for the HTTP request.""" + response_cache: NotRequired[KeyValueStore | None] + """Key-value store for caching HTTP responses.""" + @docs_group('Crawlers') class AbstractHttpCrawler( @@ -72,12 +83,14 @@ def __init__( *, parser: AbstractHttpParser[TParseResult, TSelectResult], navigation_timeout: timedelta | None = None, + response_cache: KeyValueStore | None = None, **kwargs: Unpack[BasicCrawlerOptions[TCrawlingContext, StatisticsState]], ) -> None: self._parser = parser self._navigation_timeout = navigation_timeout or timedelta(minutes=1) 
self._pre_navigation_hooks: list[Callable[[BasicCrawlingContext], Awaitable[None]]] = [] self._shared_navigation_timeouts: dict[int, SharedTimeout] = {} + self._response_cache = response_cache if '_context_pipeline' not in kwargs: raise ValueError( @@ -106,7 +119,7 @@ def __init__( parser: AbstractHttpParser[TParseResult, TSelectResult] = static_parser, **kwargs: Unpack[BasicCrawlerOptions[ParsedHttpCrawlingContext[TParseResult]]], ) -> None: - kwargs['_context_pipeline'] = self._create_static_content_crawler_pipeline() + kwargs['_context_pipeline'] = self._create_static_content_crawler_pipeline(**kwargs) super().__init__( parser=parser, **kwargs, @@ -114,12 +127,25 @@ def __init__( return _ParsedHttpCrawler - def _create_static_content_crawler_pipeline(self) -> ContextPipeline[ParsedHttpCrawlingContext[TParseResult]]: + def _create_static_content_crawler_pipeline( + self, + response_cache: KeyValueStore | None = None, + **_kwargs: BasicCrawlerOptions[TCrawlingContext, StatisticsState], + ) -> ContextPipeline[ParsedHttpCrawlingContext[TParseResult]]: """Create static content crawler context pipeline with expected pipeline steps.""" + pipeline = ContextPipeline().compose(self._execute_pre_navigation_hooks) + + if response_cache: + return ( + pipeline.compose_with_skip(self._try_load_from_cache, skip_to='parse') + .compose(self._make_http_request) + .compose(self._handle_status_code_response) + .compose(self._save_response_to_cache) + .compose(self._parse_http_response, name='parse') + .compose(self._handle_blocked_request_by_content) + ) return ( - ContextPipeline() - .compose(self._execute_pre_navigation_hooks) - .compose(self._make_http_request) + pipeline.compose(self._make_http_request) .compose(self._handle_status_code_response) .compose(self._parse_http_response) .compose(self._handle_blocked_request_by_content) @@ -308,3 +334,48 @@ def pre_navigation_hook(self, hook: Callable[[BasicCrawlingContext], Awaitable[N hook: A coroutine function to be called before each navigation. """ self._pre_navigation_hooks.append(hook) + + async def _try_load_from_cache( + self, context: BasicCrawlingContext + ) -> AsyncGenerator[HttpCrawlingContext | None, None]: + """Try to load a cached HTTP response. 
Yields HttpCrawlingContext if found, None otherwise.""" + if not self._response_cache: + raise RuntimeError('Response cache is not configured.') + + key = f'response_{context.request.unique_key}' + raw = await self._response_cache.get_value(key) + + if raw is None: + yield None + return + + compressed: bytes = raw + data = _compressor.decompress(compressed) + cached = CachedHttpResponse.model_validate_json(data) + + context.request.loaded_url = cached.loaded_url or context.request.url + context.request.state = RequestState.AFTER_NAV + + yield HttpCrawlingContext.from_basic_crawling_context(context=context, http_response=cached) + + async def _save_response_to_cache(self, context: HttpCrawlingContext) -> AsyncGenerator[HttpCrawlingContext, None]: + """Save the HTTP response to cache after a successful request.""" + if not self._response_cache: + raise RuntimeError('Response cache is not configured.') + + body = await context.http_response.read() + + cached = CachedHttpResponse( + http_version=context.http_response.http_version, + status_code=context.http_response.status_code, + headers=context.http_response.headers, + body=body, + loaded_url=context.request.loaded_url, + ) + + compressed = _compressor.compress(cached.model_dump_json().encode()) + + key = f'response_{context.request.unique_key}' + await self._response_cache.set_value(key, compressed) + + yield context diff --git a/src/crawlee/crawlers/_abstract_http/_types.py b/src/crawlee/crawlers/_abstract_http/_types.py new file mode 100644 index 0000000000..c4b0d42ff6 --- /dev/null +++ b/src/crawlee/crawlers/_abstract_http/_types.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from pydantic import BaseModel + +from crawlee._types import HttpHeaders + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + + +class CachedHttpResponse(BaseModel): + """An `HttpResponse` implementation that serves pre-stored response data from cache.""" + + http_version: str + status_code: int + headers: HttpHeaders + body: bytes + loaded_url: str | None = None + + async def read(self) -> bytes: + return self.body + + async def read_stream(self) -> AsyncIterator[bytes]: + yield self.body diff --git a/src/crawlee/crawlers/_basic/_context_pipeline.py b/src/crawlee/crawlers/_basic/_context_pipeline.py index 5a7dcc44c4..b120b446ce 100644 --- a/src/crawlee/crawlers/_basic/_context_pipeline.py +++ b/src/crawlee/crawlers/_basic/_context_pipeline.py @@ -69,9 +69,13 @@ def __init__( ] | None = None, _parent: ContextPipeline[BasicCrawlingContext] | None = None, + name: str | None = None, + skip_to: str | None = None, ) -> None: self._middleware = _middleware self._parent = _parent + self.name = name + self.skip_to = skip_to def _middleware_chain(self) -> Generator[ContextPipeline[Any], None, None]: yield self @@ -91,14 +95,24 @@ async def __call__( chain = list(self._middleware_chain()) cleanup_stack: list[_Middleware[Any]] = [] final_consumer_exception: Exception | None = None + skip_to_middleware: str | None = None try: for member in reversed(chain): + if skip_to_middleware is not None: + if member.name == skip_to_middleware: + skip_to_middleware = None + else: + continue + if member._middleware: # noqa: SLF001 - middleware_instance = _Middleware(middleware=member._middleware, input_context=crawling_context) # noqa: SLF001 + middleware_instance = _Middleware( + middleware=member._middleware, # noqa: SLF001 + input_context=crawling_context, + ) try: result = await middleware_instance.action() - except 
SessionError: # Session errors get special treatment + except SessionError: raise except StopAsyncIteration as e: raise RuntimeError('The middleware did not yield') from e @@ -107,12 +121,26 @@ async def __call__( except Exception as e: raise ContextPipelineInitializationError(e, crawling_context) from e + if result is None: + if member.skip_to is None: + raise RuntimeError( + 'Middleware yielded None but no skip_to target is configured. ' + 'Use compose_with_skip() for conditional middleware.' + ) + # Keep the existing context for next middleware + result = crawling_context + elif member.skip_to: + skip_to_middleware = member.skip_to + crawling_context = result cleanup_stack.append(middleware_instance) + if skip_to_middleware is not None: + raise RuntimeError(f'Skip target middleware "{skip_to_middleware}" not found in pipeline') + try: await final_context_consumer(cast('TCrawlingContext', crawling_context)) - except SessionError as e: # Session errors get special treatment + except SessionError as e: final_consumer_exception = e raise except Exception as e: @@ -128,6 +156,7 @@ def compose( [TCrawlingContext], AsyncGenerator[TMiddlewareCrawlingContext, None], ], + name: str | None = None, ) -> ContextPipeline[TMiddlewareCrawlingContext]: """Add a middleware to the pipeline. @@ -143,4 +172,34 @@ def compose( middleware, ), _parent=cast('ContextPipeline[BasicCrawlingContext]', self), + name=name, + ) + + def compose_with_skip( + self, + middleware: Callable[ + [TCrawlingContext], + AsyncGenerator[TMiddlewareCrawlingContext | None, None], + ], + skip_to: str, + ) -> ContextPipeline[TMiddlewareCrawlingContext]: + """Add a conditional middleware that can skip to a named target middleware. + + If middleware yields a context, that context is used and pipeline skips to the target middleware. + If middleware yields None, pipeline continues normally without changing context. + + Args: + middleware: Middleware that yields context (activates skip) or None (continue normally). + skip_to: Name of the target middleware to skip to (must exist in pipeline). + + Returns: + The extended pipeline instance, providing a fluent interface. 
+ """ + return ContextPipeline[TMiddlewareCrawlingContext]( + _middleware=cast( + 'Callable[[BasicCrawlingContext], AsyncGenerator[TMiddlewareCrawlingContext, Exception | None]]', + middleware, + ), + _parent=cast('ContextPipeline[BasicCrawlingContext]', self), + skip_to=skip_to, ) diff --git a/src/crawlee/crawlers/_beautifulsoup/_beautifulsoup_crawler.py b/src/crawlee/crawlers/_beautifulsoup/_beautifulsoup_crawler.py index 919f26221e..d561c1abf8 100644 --- a/src/crawlee/crawlers/_beautifulsoup/_beautifulsoup_crawler.py +++ b/src/crawlee/crawlers/_beautifulsoup/_beautifulsoup_crawler.py @@ -73,7 +73,7 @@ async def final_step( """Enhance `ParsedHttpCrawlingContext[BeautifulSoup]` with `soup` property.""" yield BeautifulSoupCrawlingContext.from_parsed_http_crawling_context(context) - kwargs['_context_pipeline'] = self._create_static_content_crawler_pipeline().compose(final_step) + kwargs['_context_pipeline'] = self._create_static_content_crawler_pipeline(**kwargs).compose(final_step) super().__init__( parser=BeautifulSoupParser(parser=parser), diff --git a/src/crawlee/crawlers/_http/_http_crawler.py b/src/crawlee/crawlers/_http/_http_crawler.py index 2c098ecbc6..2b6f059d22 100644 --- a/src/crawlee/crawlers/_http/_http_crawler.py +++ b/src/crawlee/crawlers/_http/_http_crawler.py @@ -55,7 +55,7 @@ def __init__( Args: kwargs: Additional keyword arguments to pass to the underlying `AbstractHttpCrawler`. """ - kwargs['_context_pipeline'] = self._create_static_content_crawler_pipeline() + kwargs['_context_pipeline'] = self._create_static_content_crawler_pipeline(**kwargs) super().__init__( parser=NoParser(), **kwargs, diff --git a/src/crawlee/crawlers/_parsel/_parsel_crawler.py b/src/crawlee/crawlers/_parsel/_parsel_crawler.py index ac8e9c9f09..149c00cdbe 100644 --- a/src/crawlee/crawlers/_parsel/_parsel_crawler.py +++ b/src/crawlee/crawlers/_parsel/_parsel_crawler.py @@ -70,7 +70,7 @@ async def final_step( """Enhance `ParsedHttpCrawlingContext[Selector]` with a `selector` property.""" yield ParselCrawlingContext.from_parsed_http_crawling_context(context) - kwargs['_context_pipeline'] = self._create_static_content_crawler_pipeline().compose(final_step) + kwargs['_context_pipeline'] = self._create_static_content_crawler_pipeline(**kwargs).compose(final_step) super().__init__( parser=ParselParser(), **kwargs, diff --git a/tests/unit/crawlers/_basic/test_context_pipeline.py b/tests/unit/crawlers/_basic/test_context_pipeline.py index 51f5556cac..4931e9f87b 100644 --- a/tests/unit/crawlers/_basic/test_context_pipeline.py +++ b/tests/unit/crawlers/_basic/test_context_pipeline.py @@ -194,3 +194,151 @@ async def step_2(context: BasicCrawlingContext) -> AsyncGenerator[BasicCrawlingC assert consumer.called assert not cleanup.called + + +@pytest.mark.parametrize( + 'condition', + [ + pytest.param(True, id='skip to pipeline middleware to step 3'), + pytest.param(False, id='do not skip any middleware'), + ], +) +async def test_pipeline_with_skip(*, condition: bool) -> None: + consumer = AsyncMock() + mock_step_1 = AsyncMock() + mock_step_2 = AsyncMock() + mock_step_3 = AsyncMock() + + async def step_1(context: BasicCrawlingContext) -> AsyncGenerator[BasicCrawlingContext, None]: + await mock_step_1() + if condition: + yield context + else: + yield None + + async def step_2(context: BasicCrawlingContext) -> AsyncGenerator[BasicCrawlingContext, None]: + await mock_step_2() + yield context + + async def step_3(context: BasicCrawlingContext) -> AsyncGenerator[BasicCrawlingContext, None]: + await mock_step_3() + 
yield context + + pipeline = ( + ContextPipeline[BasicCrawlingContext]() + .compose_with_skip(step_1, skip_to='step_3') + .compose(step_2, name='step_2') + .compose(step_3, name='step_3') + ) + context = BasicCrawlingContext( + request=Request.from_url(url='https://test.io/'), + send_request=AsyncMock(), + add_requests=AsyncMock(), + session=Session(), + proxy_info=AsyncMock(), + push_data=AsyncMock(), + use_state=AsyncMock(), + get_key_value_store=AsyncMock(), + log=logging.getLogger(), + ) + await pipeline(context, consumer) + + assert mock_step_1.called + if condition: + assert not mock_step_2.called + else: + assert mock_step_2.called + assert mock_step_3.called + assert consumer.called + + +async def test_pipeline_with_error_in_skip() -> None: + consumer = AsyncMock() + mock_step_1 = AsyncMock() + mock_step_2 = AsyncMock() + mock_step_3 = AsyncMock() + + async def step_1(_context: BasicCrawlingContext) -> AsyncGenerator[BasicCrawlingContext, None]: + await mock_step_1() + raise RuntimeError('Crash during middleware initialization') + yield None + + async def step_2(context: BasicCrawlingContext) -> AsyncGenerator[BasicCrawlingContext, None]: + await mock_step_2() + yield context + + async def step_3(context: BasicCrawlingContext) -> AsyncGenerator[BasicCrawlingContext, None]: + await mock_step_3() + yield context + + pipeline = ( + ContextPipeline[BasicCrawlingContext]() + .compose_with_skip(step_1, skip_to='step_3') + .compose(step_2, name='step_2') + .compose(step_3, name='step_3') + ) + context = BasicCrawlingContext( + request=Request.from_url(url='https://test.io/'), + send_request=AsyncMock(), + add_requests=AsyncMock(), + session=Session(), + proxy_info=AsyncMock(), + push_data=AsyncMock(), + use_state=AsyncMock(), + get_key_value_store=AsyncMock(), + log=logging.getLogger(), + ) + + with pytest.raises(ContextPipelineInitializationError): + await pipeline(context, consumer) + + assert mock_step_1.called + assert not mock_step_2.called + assert not mock_step_3.called + assert not consumer.called + + +async def test_pipeline_with_skip_without_target() -> None: + consumer = AsyncMock() + mock_step_1 = AsyncMock() + mock_step_2 = AsyncMock() + mock_step_3 = AsyncMock() + + async def step_1(_context: BasicCrawlingContext) -> AsyncGenerator[BasicCrawlingContext, None]: + await mock_step_1() + raise RuntimeError('Crash during middleware initialization') + yield None + + async def step_2(context: BasicCrawlingContext) -> AsyncGenerator[BasicCrawlingContext, None]: + await mock_step_2() + yield context + + async def step_3(context: BasicCrawlingContext) -> AsyncGenerator[BasicCrawlingContext, None]: + await mock_step_3() + yield context + + pipeline = ( + ContextPipeline[BasicCrawlingContext]() + .compose_with_skip(step_1, skip_to='step_4') # step_4 does not exist + .compose(step_2, name='step_2') + .compose(step_3, name='step_3') + ) + context = BasicCrawlingContext( + request=Request.from_url(url='https://test.io/'), + send_request=AsyncMock(), + add_requests=AsyncMock(), + session=Session(), + proxy_info=AsyncMock(), + push_data=AsyncMock(), + use_state=AsyncMock(), + get_key_value_store=AsyncMock(), + log=logging.getLogger(), + ) + + with pytest.raises(ContextPipelineInitializationError): + await pipeline(context, consumer) + + assert mock_step_1.called + assert not mock_step_2.called + assert not mock_step_3.called + assert not consumer.called diff --git a/tests/unit/crawlers/_http/test_http_crawler.py b/tests/unit/crawlers/_http/test_http_crawler.py index 21bfde2eaf..4c3669070f 
100644 --- a/tests/unit/crawlers/_http/test_http_crawler.py +++ b/tests/unit/crawlers/_http/test_http_crawler.py @@ -1,8 +1,8 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING -from unittest.mock import AsyncMock, Mock +from typing import TYPE_CHECKING, Any +from unittest.mock import AsyncMock, Mock, patch from urllib.parse import parse_qs, urlencode import pytest @@ -11,11 +11,11 @@ from crawlee.crawlers import HttpCrawler from crawlee.sessions import SessionPool from crawlee.statistics import Statistics -from crawlee.storages import RequestQueue +from crawlee.storages import KeyValueStore, RequestQueue from tests.unit.server_endpoints import HELLO_WORLD if TYPE_CHECKING: - from collections.abc import Awaitable, Callable + from collections.abc import AsyncGenerator, Awaitable, Callable from yarl import URL @@ -632,3 +632,49 @@ async def failed_request_handler(context: BasicCrawlingContext, _error: Exceptio } await queue.drop() + + +async def test_save_response_to_cache(http_client: HttpClient, server_url: URL) -> None: + cache_kvs = await KeyValueStore.open(alias='http-request-cache') + + call_tracker = Mock() + original_method = HttpCrawler._make_http_request + + response_data: dict[str, dict[str, Any]] = {} + + # Wrap the original _make_http_request to track calls + async def tracked_make_http_request( + self: HttpCrawler, context: HttpCrawlingContext + ) -> AsyncGenerator[HttpCrawlingContext, None]: + call_tracker() + async for result in original_method(self, context): + yield result + + with patch.object(HttpCrawler, '_make_http_request', tracked_make_http_request): + crawler = HttpCrawler(http_client=http_client, response_cache=cache_kvs) + + @crawler.router.default_handler + async def handler(context: HttpCrawlingContext) -> None: + run_key = context.request.user_data['run'] + if not isinstance(run_key, str): + raise TypeError('Invalid run key in user_data') + + response_data[run_key] = { + 'status': context.http_response.status_code, + 'body': await context.http_response.read(), + 'http_version': context.http_response.http_version, + 'headers': dict(context.http_response.headers), + 'loaded_url': context.request.loaded_url, + } + + # First run. Make request and save in cache + await crawler.run([Request.from_url(str(server_url), user_data={'run': 'first-run'})]) + assert call_tracker.call_count == 1 + + # Second run. 
Load from cache, no actual request + await crawler.run([Request.from_url(str(server_url), user_data={'run': 'second-run'})]) + assert call_tracker.call_count == 1 + + assert response_data['first-run'] == response_data['second-run'] + + await cache_kvs.drop() From 46b36d8d719ccb020ba727b0de28cb55bc34a3f4 Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Wed, 28 Jan 2026 22:27:53 +0000 Subject: [PATCH 2/2] standardize key length --- .../crawlers/_abstract_http/_abstract_http_crawler.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py b/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py index d8c9d29d39..2290ac2473 100644 --- a/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py +++ b/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py @@ -5,6 +5,7 @@ import sys from abc import ABC from datetime import timedelta +from hashlib import sha256 from typing import TYPE_CHECKING, Any, Generic if sys.version_info >= (3, 14): @@ -342,7 +343,7 @@ async def _try_load_from_cache( if not self._response_cache: raise RuntimeError('Response cache is not configured.') - key = f'response_{context.request.unique_key}' + key = self._get_cache_key(context.request.unique_key) raw = await self._response_cache.get_value(key) if raw is None: @@ -375,7 +376,13 @@ async def _save_response_to_cache(self, context: HttpCrawlingContext) -> AsyncGe compressed = _compressor.compress(cached.model_dump_json().encode()) - key = f'response_{context.request.unique_key}' + key = self._get_cache_key(context.request.unique_key) await self._response_cache.set_value(key, compressed) yield context + + @staticmethod + def _get_cache_key(unique_key: str) -> str: + """Generate a deterministic cache key for a unique_key.""" + hashed_key = sha256(unique_key.encode('utf-8')).hexdigest() + return f'response_{hashed_key[:15]}'
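
For context, a minimal usage sketch of the new `response_cache` argument, mirroring the wiring exercised by the added `test_save_response_to_cache`; the key-value store alias, the target URL, and the fields pushed from the handler are illustrative placeholders:

import asyncio

from crawlee.crawlers import HttpCrawler, HttpCrawlingContext
from crawlee.storages import KeyValueStore


async def main() -> None:
    # Responses are compressed and stored here under 'response_<sha256 prefix>' keys
    # (see _get_cache_key introduced in patch 2/2).
    cache_kvs = await KeyValueStore.open(alias='http-response-cache')

    crawler = HttpCrawler(response_cache=cache_kvs)

    @crawler.router.default_handler
    async def handler(context: HttpCrawlingContext) -> None:
        # On a cache hit the pipeline skips _make_http_request and serves a
        # CachedHttpResponse rebuilt from the stored status, headers and body.
        await context.push_data({
            'url': context.request.loaded_url,
            'status': context.http_response.status_code,
        })

    # Repeated runs over the same unique keys are served from the cache
    # without touching the network.
    await crawler.run(['https://example.com/'])


if __name__ == '__main__':
    asyncio.run(main())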
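
Similarly, a condensed sketch of the `compose_with_skip` semantics added to `ContextPipeline`, following the added `test_pipeline_with_skip`; the middleware names and the URL-suffix condition are illustrative, and the imports follow the module paths touched by this patch:

from __future__ import annotations

from collections.abc import AsyncGenerator

from crawlee._types import BasicCrawlingContext
from crawlee.crawlers._basic._context_pipeline import ContextPipeline


async def maybe_short_circuit(
    ctx: BasicCrawlingContext,
) -> AsyncGenerator[BasicCrawlingContext | None, None]:
    # Yield a context to jump straight to the middleware named 'finalize';
    # yield None to fall through to the next middleware with the context unchanged.
    yield ctx if ctx.request.url.endswith('.json') else None


async def fetch(ctx: BasicCrawlingContext) -> AsyncGenerator[BasicCrawlingContext, None]:
    yield ctx


async def finalize(ctx: BasicCrawlingContext) -> AsyncGenerator[BasicCrawlingContext, None]:
    yield ctx


pipeline = (
    ContextPipeline[BasicCrawlingContext]()
    .compose_with_skip(maybe_short_circuit, skip_to='finalize')
    .compose(fetch, name='fetch')
    .compose(finalize, name='finalize')
)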