From 90c163e00b9a3a3eec2452b171aa51b9dc58234e Mon Sep 17 00:00:00 2001 From: Eric Avdey Date: Thu, 16 Oct 2025 23:45:32 -0300 Subject: [PATCH 1/3] refactor: use CloudantBaseService as a super of CloudantV1 --- ibmcloudant/__init__.py | 15 +-- ibmcloudant/cloudant_base_service.py | 186 ++++++++++++++------------- ibmcloudant/cloudant_v1.py | 7 +- 3 files changed, 99 insertions(+), 109 deletions(-) diff --git a/ibmcloudant/__init__.py b/ibmcloudant/__init__.py index 3b83fbad..c385b4f3 100644 --- a/ibmcloudant/__init__.py +++ b/ibmcloudant/__init__.py @@ -1,5 +1,5 @@ # coding: utf-8 -# © Copyright IBM Corporation 2020, 2024. +# © Copyright IBM Corporation 2020, 2025. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,7 +19,6 @@ from ibm_cloud_sdk_core import IAMTokenManager, DetailedResponse, BaseService, ApiException, get_authenticator from .couchdb_session_authenticator import CouchDbSessionAuthenticator from .couchdb_session_get_authenticator_patch import new_construct_authenticator -from .cloudant_base_service import new_init, new_prepare_request, new_set_default_headers, new_set_http_client, new_set_service_url, new_set_disable_ssl_verification from .couchdb_session_token_manager import CouchDbSessionTokenManager from .cloudant_v1 import CloudantV1 from .features.changes_follower import ChangesFollower @@ -27,15 +26,3 @@ # sdk-core's __construct_authenticator works with a long switch-case so monkey-patching is required get_authenticator.__construct_authenticator = new_construct_authenticator - -CloudantV1.__init__ = new_init - -CloudantV1.set_service_url = new_set_service_url - -CloudantV1.set_default_headers = new_set_default_headers - -CloudantV1.prepare_request = new_prepare_request - -CloudantV1.set_http_client = new_set_http_client - -CloudantV1.set_disable_ssl_verification = new_set_disable_ssl_verification diff --git a/ibmcloudant/cloudant_base_service.py b/ibmcloudant/cloudant_base_service.py index 172e791f..e76275a2 100644 --- a/ibmcloudant/cloudant_base_service.py +++ b/ibmcloudant/cloudant_base_service.py @@ -1,6 +1,6 @@ # coding: utf-8 -# © Copyright IBM Corporation 2020, 2024. +# © Copyright IBM Corporation 2020, 2025. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,12 +24,12 @@ from json.decoder import JSONDecodeError from io import BytesIO +from ibm_cloud_sdk_core import BaseService from ibm_cloud_sdk_core.authenticators import Authenticator from requests import Response, Session from requests.cookies import RequestsCookieJar from .common import get_sdk_headers -from .cloudant_v1 import CloudantV1 from .couchdb_session_authenticator import CouchDbSessionAuthenticator # pylint: disable=missing-docstring @@ -72,52 +72,99 @@ def __hash__(self): # Since Py3.6 dict is ordered so use a key only dict for our set rules_by_operation.setdefault(operation_id, dict()).setdefault(rule) -_old_init = CloudantV1.__init__ - -def new_init(self, authenticator: Authenticator = None): - _old_init(self, authenticator) - # Overwrite default read timeout to 2.5 minutes - if not ('timeout' in self.http_config): - new_http_config = self.http_config.copy() - new_http_config['timeout'] = (CONNECT_TIMEOUT, READ_TIMEOUT) - self.set_http_config(new_http_config) - # Custom actions for CouchDbSessionAuthenticator - if isinstance(authenticator, CouchDbSessionAuthenticator): - # Replacing BaseService's http.cookiejar.CookieJar as RequestsCookieJar supports update(CookieJar) - self.jar = RequestsCookieJar(self.jar) - self.authenticator.set_jar(self.jar) # Authenticators don't have access to cookie jars by default - add_hooks(self) - -_old_set_service_url = CloudantV1.set_service_url - -def new_set_service_url(self, service_url: str): - _old_set_service_url(self, service_url) - try: +class CloudantBaseService(BaseService): + """ + The base class for service classes. + """ + def __init__( + self, + service_url: str = None, + authenticator: Authenticator = None, + ) -> None: + """ + Construct a new client for the Cloudant service. + + :param Authenticator authenticator: The authenticator specifies the authentication mechanism. + Get up to date information from https://github.com/IBM/python-sdk-core/blob/main/README.md + about initializing the authenticator of your choice. + """ + BaseService.__init__(self, service_url=service_url, authenticator=authenticator) + # Overwrite default read timeout to 2.5 minutes + if not ('timeout' in self.http_config): + new_http_config = self.http_config.copy() + new_http_config['timeout'] = (CONNECT_TIMEOUT, READ_TIMEOUT) + self.set_http_config(new_http_config) + # Custom actions for CouchDbSessionAuthenticator + if isinstance(authenticator, CouchDbSessionAuthenticator): + # Replacing BaseService's http.cookiejar.CookieJar as RequestsCookieJar supports update(CookieJar) + self.jar = RequestsCookieJar(self.jar) + self.authenticator.set_jar(self.jar) # Authenticators don't have access to cookie jars by default + add_hooks(self) + + def set_service_url(self, service_url: str): + super().set_service_url(service_url) + try: + if isinstance(self.authenticator, CouchDbSessionAuthenticator): + self.authenticator.token_manager.set_service_url(service_url) + except AttributeError: + pass # in case no authenticator is configured yet, pass + + def set_default_headers(self, headers: Dict[str, str]): + super().set_default_headers(headers) if isinstance(self.authenticator, CouchDbSessionAuthenticator): - self.authenticator.token_manager.set_service_url(service_url) - except AttributeError: - pass # in case no authenticator is configured yet, pass - -_old_set_default_headers = CloudantV1.set_default_headers - -def new_set_default_headers(self, headers: Dict[str, str]): - _old_set_default_headers(self, headers) - if isinstance(self.authenticator, CouchDbSessionAuthenticator): - combined_headers = {} - combined_headers.update(headers) - combined_headers.update(get_sdk_headers( - service_name=self.DEFAULT_SERVICE_NAME, - service_version='V1', - operation_id='authenticator_post_session') - ) - self.authenticator.token_manager.set_default_headers(combined_headers) - -_old_set_disable_ssl_verification = CloudantV1.set_disable_ssl_verification - -def new_set_disable_ssl_verification(self, status: bool = False) -> None: - _old_set_disable_ssl_verification(self, status) - if isinstance(self.authenticator, CouchDbSessionAuthenticator): - self.authenticator.token_manager.set_disable_ssl_verification(status) + combined_headers = {} + combined_headers.update(headers) + combined_headers.update(get_sdk_headers( + service_name=self.DEFAULT_SERVICE_NAME, + service_version='V1', + operation_id='authenticator_post_session') + ) + self.authenticator.token_manager.set_default_headers(combined_headers) + + def set_disable_ssl_verification(self, status: bool = False) -> None: + super().set_disable_ssl_verification(status) + if isinstance(self.authenticator, CouchDbSessionAuthenticator): + self.authenticator.token_manager.set_disable_ssl_verification(status) + + def set_http_client(self, http_client: Session) -> None: + super().set_http_client(http_client) + add_hooks(self) + + def prepare_request(self, + method: str, + url: str, + *args, + headers: Optional[dict] = None, + params: Optional[dict] = None, + data: Optional[Union[str, dict]] = None, + files: Optional[Union[Dict[str, Tuple[str]], + List[Tuple[str, + Tuple[str, + ...]]]]] = None, + **kwargs) -> dict: + # Extract the operation ID from the request headers. + operation_id = None + header = headers.get('X-IBMCloud-SDK-Analytics') + if header is not None: + for element in header.split(';'): + if element.startswith('operation_id'): + operation_id = element.split('=')[1] + break + if operation_id is not None: + # Check each validation rule that applies to the operation. + # Until the request URL is passed to old_prepare_request it does not include the + # service URL and is relative to it + request_url_path_segments = urlsplit(url).path.strip('/').split('/') + if len(request_url_path_segments) == 1 and request_url_path_segments[0] == '': + request_url_path_segments = [] + # Note the get returns a value-less dict, we are iterating only the keys + for rule in rules_by_operation.get(operation_id, {}): + if len(request_url_path_segments) > rule.path_segment_index: + segment_to_validate = request_url_path_segments[rule.path_segment_index] + if segment_to_validate.startswith('_'): + raise ValueError('{0} {1} starts with the invalid _ character.'.format(rule.error_parameter_name, + unquote(segment_to_validate))) + return super().prepare_request(method, url, *args, headers=headers, params=params, data=data, files=files, **kwargs) def _error_response_hook(response:Response, *args, **kwargs) -> Optional[Response]: # pylint: disable=W0613 @@ -186,52 +233,7 @@ def _error_response_hook(response:Response, *args, **kwargs) -> Optional[Respons # so the exception can surface elsewhere. pass return response - -_old_prepare_request = CloudantV1.prepare_request - -def new_prepare_request(self, - method: str, - url: str, - *args, - headers: Optional[dict] = None, - params: Optional[dict] = None, - data: Optional[Union[str, dict]] = None, - files: Optional[Union[Dict[str, Tuple[str]], - List[Tuple[str, - Tuple[str, - ...]]]]] = None, - **kwargs) -> dict: - # Extract the operation ID from the request headers. - operation_id = None - header = headers.get('X-IBMCloud-SDK-Analytics') - if header is not None: - for element in header.split(';'): - if element.startswith('operation_id'): - operation_id = element.split('=')[1] - break - if operation_id is not None: - # Check each validation rule that applies to the operation. - # Until the request URL is passed to old_prepare_request it does not include the - # service URL and is relative to it - request_url_path_segments = urlsplit(url).path.strip('/').split('/') - if len(request_url_path_segments) == 1 and request_url_path_segments[0] == '': - request_url_path_segments = [] - # Note the get returns a value-less dict, we are iterating only the keys - for rule in rules_by_operation.get(operation_id, {}): - if len(request_url_path_segments) > rule.path_segment_index: - segment_to_validate = request_url_path_segments[rule.path_segment_index] - if segment_to_validate.startswith('_'): - raise ValueError('{0} {1} starts with the invalid _ character.'.format(rule.error_parameter_name, - unquote(segment_to_validate))) - return _old_prepare_request(self, method, url, *args, headers=headers, params=params, data=data, files=files, **kwargs) - def add_hooks(self): response_hooks = self.get_http_client().hooks['response'] if _error_response_hook not in response_hooks: response_hooks.append(_error_response_hook) - -_old_set_http_client = CloudantV1.set_http_client - -def new_set_http_client(self, http_client: Session) -> None: - _old_set_http_client(self, http_client) - add_hooks(self) diff --git a/ibmcloudant/cloudant_v1.py b/ibmcloudant/cloudant_v1.py index af92973a..843868ef 100644 --- a/ibmcloudant/cloudant_v1.py +++ b/ibmcloudant/cloudant_v1.py @@ -27,11 +27,12 @@ import json import logging -from ibm_cloud_sdk_core import BaseService, DetailedResponse +from ibm_cloud_sdk_core import DetailedResponse from ibm_cloud_sdk_core.authenticators.authenticator import Authenticator from ibm_cloud_sdk_core.get_authenticator import get_authenticator_from_environment from ibm_cloud_sdk_core.utils import convert_list, convert_model, datetime_to_string, string_to_datetime +from .cloudant_base_service import CloudantBaseService from .common import get_sdk_headers ############################################################################## @@ -39,7 +40,7 @@ ############################################################################## -class CloudantV1(BaseService): +class CloudantV1(CloudantBaseService): """The Cloudant V1 service.""" DEFAULT_SERVICE_URL = 'https://~replace-with-cloudant-host~.cloudantnosqldb.appdomain.cloud' @@ -72,7 +73,7 @@ def __init__( Get up to date information from https://github.com/IBM/python-sdk-core/blob/main/README.md about initializing the authenticator of your choice. """ - BaseService.__init__(self, service_url=self.DEFAULT_SERVICE_URL, authenticator=authenticator) + CloudantBaseService.__init__(self, service_url=self.DEFAULT_SERVICE_URL, authenticator=authenticator) # enable gzip compression of request bodies self.set_enable_gzip_compression(True) From bdf065ac0739e8cf518f661b99e10d189da86bae Mon Sep 17 00:00:00 2001 From: Eric Avdey Date: Fri, 17 Oct 2025 12:03:07 -0300 Subject: [PATCH 2/3] refactor: make session auth to use base service's http client --- ibmcloudant/cloudant_base_service.py | 8 +-- ibmcloudant/couchdb_session_authenticator.py | 24 ++++----- ibmcloudant/couchdb_session_token_manager.py | 15 ++++-- test/integration/test_timeout.py | 54 ++++++++++++-------- 4 files changed, 61 insertions(+), 40 deletions(-) diff --git a/ibmcloudant/cloudant_base_service.py b/ibmcloudant/cloudant_base_service.py index e76275a2..3884671b 100644 --- a/ibmcloudant/cloudant_base_service.py +++ b/ibmcloudant/cloudant_base_service.py @@ -27,7 +27,6 @@ from ibm_cloud_sdk_core import BaseService from ibm_cloud_sdk_core.authenticators import Authenticator from requests import Response, Session -from requests.cookies import RequestsCookieJar from .common import get_sdk_headers from .couchdb_session_authenticator import CouchDbSessionAuthenticator @@ -96,9 +95,8 @@ def __init__( self.set_http_config(new_http_config) # Custom actions for CouchDbSessionAuthenticator if isinstance(authenticator, CouchDbSessionAuthenticator): - # Replacing BaseService's http.cookiejar.CookieJar as RequestsCookieJar supports update(CookieJar) - self.jar = RequestsCookieJar(self.jar) - self.authenticator.set_jar(self.jar) # Authenticators don't have access to cookie jars by default + # Make token manager of CouchDbSessionAuthenticator to use the same http client as main service + self.authenticator._set_http_client(self.get_http_client(), self.jar) add_hooks(self) def set_service_url(self, service_url: str): @@ -128,6 +126,8 @@ def set_disable_ssl_verification(self, status: bool = False) -> None: def set_http_client(self, http_client: Session) -> None: super().set_http_client(http_client) + if isinstance(self.authenticator, CouchDbSessionAuthenticator): + self.authenticator._set_http_client(self.get_http_client(), self.jar) add_hooks(self) def prepare_request(self, diff --git a/ibmcloudant/couchdb_session_authenticator.py b/ibmcloudant/couchdb_session_authenticator.py index 815a5f8e..90de04d5 100644 --- a/ibmcloudant/couchdb_session_authenticator.py +++ b/ibmcloudant/couchdb_session_authenticator.py @@ -1,6 +1,6 @@ # coding: utf-8 -# © Copyright IBM Corporation 2020, 2022. +# © Copyright IBM Corporation 2020, 2025. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,7 +16,9 @@ """ Module for handling session authentication """ -from requests import Request +from requests import Request, Session +from requests.cookies import RequestsCookieJar + from ibm_cloud_sdk_core.authenticators import Authenticator from .couchdb_session_token_manager import CouchDbSessionTokenManager @@ -46,8 +48,6 @@ def __init__(self, if not isinstance(disable_ssl_verification, bool): raise TypeError('disable_ssl_verification must be a bool') - self.jar = None - self.token_manager = CouchDbSessionTokenManager( username, password, @@ -55,11 +55,15 @@ def __init__(self, ) self.validate() - def set_jar(self, jar): - """Sets the cookie jar for the authenticator. + def _set_http_client(self, http_client: Session, jar: RequestsCookieJar) -> None: + """Sets base serivice's http client for the authenticator. This is an internal method called by BaseService. Not to be called directly. """ - self.jar = jar + if isinstance(http_client, Session): + self.token_manager.http_client = http_client + self.token_manager.jar = jar + else: + raise TypeError("http_client parameter must be a requests.sessions.Session") def validate(self): """Validates the username, and password for session token requests. @@ -82,11 +86,7 @@ def authenticate(self, req: Request): Args: req: Ignored. BaseService uses the cookie jar for every request """ - jar = self.token_manager.get_token() - # Requests seem to save cookies only for Sessions. BaseService is - # hard-coded to work with "regular" requests requests so updating - # the jar manually is necessary - self.jar.update(jar) + self.token_manager.get_token() def authentication_type(self) -> str: """Returns this authenticator's type ('COUCHDB_SESSION').""" diff --git a/ibmcloudant/couchdb_session_token_manager.py b/ibmcloudant/couchdb_session_token_manager.py index d39005f7..32512abf 100644 --- a/ibmcloudant/couchdb_session_token_manager.py +++ b/ibmcloudant/couchdb_session_token_manager.py @@ -1,6 +1,6 @@ # coding: utf-8 -# © Copyright IBM Corporation 2020, 2022. +# © Copyright IBM Corporation 2020, 2025. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ """ Module for managing session authentication token """ +from requests import Session + from ibm_cloud_sdk_core.token_managers.token_manager import TokenManager @@ -52,7 +54,9 @@ def __init__(self, username: str, password: str, self.token = None + self.http_client = None self.http_config = {} + self.jar = None self.headers = None def request_token(self): @@ -63,14 +67,19 @@ def request_token(self): A CookieJar of Cookies the server sent back. """ - response = self._request( + if not isinstance(self.http_client, Session): + raise TypeError("http_client parameter must be a requests.sessions.Session") + + response = self.http_client.request( method='POST', url=self.url + "/_session", headers=self.headers, json={ 'username': self.username, 'password': self.password, - } + }, + cookies=self.jar, + **self.http_config ) return response diff --git a/test/integration/test_timeout.py b/test/integration/test_timeout.py index 3a6a475a..ece750b9 100644 --- a/test/integration/test_timeout.py +++ b/test/integration/test_timeout.py @@ -1,6 +1,6 @@ # coding: utf-8 -# © Copyright IBM Corporation 2021. +# © Copyright IBM Corporation 2021, 2025. # # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -86,8 +86,24 @@ def get_authenticate_arguments(req): return req.mock_calls[0] @staticmethod - def get_request_arguments(tci, srv, idx): - tci.assertEqual(srv.http_client.request.call_count, idx + 1) + def get_request_arguments(tci, srv, idx, call_count=None): + """ + Retrieve the arguments passed to a mocked HTTP request. + + Parameters: + tci (unittest.mock.Mock): A Mock object used for assertions. + srv (Service): The instance of CloudantV1 service. + idx (int): The index of the request in srv.http_client.request.mock_calls. + call_count (int, optional): The total call count for the request. Defaults to None. + + Returns: + list: The arguments passed to the HTTP request at the specified index. + + Raises: + AssertionError: If the actual call count of the request does not match the provided or calculated call_count. + """ + call_count = idx + 1 if call_count is None else call_count + tci.assertEqual(srv.http_client.request.call_count, call_count) return srv.http_client.request.mock_calls[idx] # Every method tests an authenticator. @@ -142,31 +158,27 @@ def test_timeout_cloudantv1_sessionauth(self): ) my_service.set_service_url("http://cloudant.example") - # Mock out authentication - orig_request = requests.request - requests.request = Mock(return_value=Helpers.get_mocked_response()) - # Mock out request response Helpers.mock_out_cloudant_request(my_service) - testcases = Helpers.defineTestCases(my_service) + # Call the server + my_service.get_server_information() - for tc_num, tc in enumerate(testcases): - tc['set_timeout'](CUSTOM_TIMEOUT_CONFIG) - - # Call the server - my_service.get_server_information() + # Assert timeout is set to the authenticator + auth_args = Helpers.get_request_arguments(self, my_service.authenticator.token_manager, 0, call_count=2) + self.assertEqual(auth_args.kwargs['timeout'], DEFAULT_TIMEOUT) - # Assert timeout is set to the authenticator - auth_args = Helpers.get_authenticate_arguments(requests.request) - Helpers.assert_default_timeout_setting(self, auth_args) + # Assert timeout is set to the server request + req_args = Helpers.get_request_arguments(self, my_service, 1, call_count=2) + self.assertEqual(req_args.kwargs['timeout'], DEFAULT_TIMEOUT) - # Assert timeout is set in the server request - req_args = Helpers.get_request_arguments(self, my_service, tc_num) - tc['assert_func'](self, req_args) + # Set a custom timeout and repeat the request. Client should be already authenticated. + my_service.set_http_config(CUSTOM_TIMEOUT_CONFIG) + my_service.get_server_information() - # Set back requests.request - requests.request = orig_request + # Assert the custom timeout is set to the server request + req_args = Helpers.get_request_arguments(self, my_service, 2, call_count=3) + self.assertEqual(req_args.kwargs['timeout'], CUSTOM_TIMEOUT) def test_timeout_cloudantv1_iamauth(self): authenticator = IAMAuthenticator('apikey') From 16867c74b4d122e834e5cef5586ef114afef84e5 Mon Sep 17 00:00:00 2001 From: Eric Avdey Date: Fri, 17 Oct 2025 12:03:47 -0300 Subject: [PATCH 3/3] fix: rename pagination test class name to avoid warning --- test/unit/features/test_pagination_base.py | 92 +++++++++++----------- 1 file changed, 46 insertions(+), 46 deletions(-) diff --git a/test/unit/features/test_pagination_base.py b/test/unit/features/test_pagination_base.py index 77d2a16f..85006d51 100644 --- a/test/unit/features/test_pagination_base.py +++ b/test/unit/features/test_pagination_base.py @@ -32,7 +32,7 @@ def __init__(self, total_items: int, page_size: int): # Override plus_one_paging to accommodate this weird hybrid self.plus_one_paging = False -class TestPageIterator(_BasePageIterator): +class BaseTestPageIterator(_BasePageIterator): """ A test subclass of the _BasePager under test. """ @@ -40,7 +40,7 @@ class TestPageIterator(_BasePageIterator): page_keys: list[str] = [] def __init__(self, client, opts): - super().__init__(client, TestPageIterator.operation or client.post_view, TestPageIterator.page_keys, opts) + super().__init__(client, BaseTestPageIterator.operation or client.post_view, BaseTestPageIterator.page_keys, opts) def _result_converter(self) -> Callable[[dict], ViewResult]: return lambda d: ViewResult.from_dict(d) @@ -58,7 +58,7 @@ class TestBasePageIterator(MockClientBaseCase): def test_init(self): operation = self.client.post_view opts = {'db': 'test', 'limit': 20} - page_iterator: Iterable[ViewResultRow] = TestPageIterator(self.client, opts) + page_iterator: Iterable[ViewResultRow] = BaseTestPageIterator(self.client, opts) # Assert client is set self.assertEqual(page_iterator._client, self.client, 'The supplied client should be set.') # Assert operation is set @@ -72,8 +72,8 @@ def test_partial_options(self): page_opts = {'foo': 'boo', 'bar': 'far'} opts = {**static_opts, **page_opts} # Use page_opts.keys() to pass the list of names for page options - with patch('test_pagination_base.TestPageIterator.page_keys', page_opts.keys()): - page_iterator: Iterable[ViewResultRow] = TestPageIterator(self.client, opts) + with patch('test_pagination_base.BaseTestPageIterator.page_keys', page_opts.keys()): + page_iterator: Iterable[ViewResultRow] = BaseTestPageIterator(self.client, opts) # Assert partial function has only static opts self.assertEqual(page_iterator._next_request_function.keywords, static_opts, 'The partial function kwargs should be only the static options.') # Assert next page options @@ -81,7 +81,7 @@ def test_partial_options(self): def test_default_page_size(self): opts = {'db': 'test'} - page_iterator: Iterable[ViewResultRow] = TestPageIterator(self.client, opts) + page_iterator: Iterable[ViewResultRow] = BaseTestPageIterator(self.client, opts) # Assert the default page size expected_page_size = 200 self.assertEqual(page_iterator._page_size, expected_page_size, 'The default page size should be set.') @@ -89,7 +89,7 @@ def test_default_page_size(self): def test_limit_page_size(self): opts = {'db': 'test', 'limit': 42} - page_iterator: Iterable[ViewResultRow] = TestPageIterator(self.client, opts) + page_iterator: Iterable[ViewResultRow] = BaseTestPageIterator(self.client, opts) # Assert the provided page size expected_page_size = 42 self.assertEqual(page_iterator._page_size, expected_page_size, 'The default page size should be set.') @@ -97,15 +97,15 @@ def test_limit_page_size(self): def test_has_next_initially_true(self): opts = {'limit': 1} - page_iterator: Iterable[ViewResultRow] = TestPageIterator(self.client, opts) + page_iterator: Iterable[ViewResultRow] = BaseTestPageIterator(self.client, opts) # Assert _has_next self.assertTrue(page_iterator._has_next, '_has_next should initially return True.') def test_has_next_true_for_result_equal_to_limit(self): page_size = 1 # Init with mock that returns only a single row - with patch('test_pagination_base.TestPageIterator.operation', BasePageMockResponses(1, page_size).get_next_page): - page_iterator: Iterable[ViewResultRow] = TestPageIterator( + with patch('test_pagination_base.BaseTestPageIterator.operation', BasePageMockResponses(1, page_size).get_next_page): + page_iterator: Iterable[ViewResultRow] = BaseTestPageIterator( self.client, {'limit': page_size}) # Get first page with 1 result @@ -116,8 +116,8 @@ def test_has_next_true_for_result_equal_to_limit(self): def test_has_next_false_for_result_less_than_limit(self): page_size = 1 # Init with mock that returns zero rows - with patch('test_pagination_base.TestPageIterator.operation', BasePageMockResponses(0, page_size).get_next_page): - page_iterator: Iterable[ViewResultRow] = TestPageIterator( + with patch('test_pagination_base.BaseTestPageIterator.operation', BasePageMockResponses(0, page_size).get_next_page): + page_iterator: Iterable[ViewResultRow] = BaseTestPageIterator( self.client, {'limit': page_size}) # Get first page with 0 result @@ -129,8 +129,8 @@ def test_next_first_page(self): page_size = 25 # Mock that returns one page of 25 items mock = BasePageMockResponses(page_size, page_size) - with patch('test_pagination_base.TestPageIterator.operation', mock.get_next_page): - page_iterator: Iterable[ViewResultRow] = TestPageIterator( + with patch('test_pagination_base.BaseTestPageIterator.operation', mock.get_next_page): + page_iterator: Iterable[ViewResultRow] = BaseTestPageIterator( self.client, {'limit': page_size}) # Get first page @@ -142,8 +142,8 @@ def test_next_two_pages(self): page_size = 3 # Mock that returns two pages of 3 items mock = BasePageMockResponses(2*page_size, page_size) - with patch('test_pagination_base.TestPageIterator.operation', mock.get_next_page): - page_iterator: Iterable[ViewResultRow] = TestPageIterator( + with patch('test_pagination_base.BaseTestPageIterator.operation', mock.get_next_page): + page_iterator: Iterable[ViewResultRow] = BaseTestPageIterator( self.client, {'limit': page_size}) # Get first page @@ -163,8 +163,8 @@ def test_next_until_empty(self): page_size = 3 # Mock that returns 3 pages of 3 items mock = BasePageMockResponses(3*page_size, page_size) - with patch('test_pagination_base.TestPageIterator.operation', mock.get_next_page): - page_iterator: Iterable[ViewResultRow] = TestPageIterator( + with patch('test_pagination_base.BaseTestPageIterator.operation', mock.get_next_page): + page_iterator: Iterable[ViewResultRow] = BaseTestPageIterator( self.client, {'limit': page_size}) page_count = 0 @@ -182,8 +182,8 @@ def test_next_until_smaller(self): page_size = 3 # Mock that returns 3 pages of 3 items, then 1 more page with 1 item mock = BasePageMockResponses(3*page_size + 1, page_size) - with patch('test_pagination_base.TestPageIterator.operation', mock.get_next_page): - page_iterator: Iterable[ViewResultRow] = TestPageIterator( + with patch('test_pagination_base.BaseTestPageIterator.operation', mock.get_next_page): + page_iterator: Iterable[ViewResultRow] = BaseTestPageIterator( self.client, {'limit': page_size}) page_count = 0 @@ -201,8 +201,8 @@ def test_next_exception(self): page_size = 2 # Mock that returns one page of one item mock = BasePageMockResponses(page_size - 1, page_size) - with patch('test_pagination_base.TestPageIterator.operation', mock.get_next_page): - page_iterator: Iterable[ViewResultRow] = TestPageIterator( + with patch('test_pagination_base.BaseTestPageIterator.operation', mock.get_next_page): + page_iterator: Iterable[ViewResultRow] = BaseTestPageIterator( self.client, {'limit': page_size}) # Get first and only page @@ -218,8 +218,8 @@ def test_next_exception(self): def test_pages_immutable(self): page_size = 1 mock = BasePageMockResponses(page_size, page_size) - with patch('test_pagination_base.TestPageIterator.operation', mock.get_next_page): - page_iterator: Iterable[ViewResultRow] = TestPageIterator( + with patch('test_pagination_base.BaseTestPageIterator.operation', mock.get_next_page): + page_iterator: Iterable[ViewResultRow] = BaseTestPageIterator( self.client, {'limit': page_size}) # Get page @@ -230,8 +230,8 @@ def test_pages_immutable(self): def test_set_next_page_options(self): page_size = 1 mock = BasePageMockResponses(5*page_size, page_size) - with patch('test_pagination_base.TestPageIterator.operation', mock.get_next_page): - page_iterator: Iterable[ViewResultRow] = TestPageIterator( + with patch('test_pagination_base.BaseTestPageIterator.operation', mock.get_next_page): + page_iterator: Iterable[ViewResultRow] = BaseTestPageIterator( self.client, {'limit': page_size}) self.assertIsNone(page_iterator._next_page_opts.get('start_key'), "The start key should intially be None.") @@ -248,8 +248,8 @@ def test_set_next_page_options(self): def test_next_resumes_after_error(self): page_size = 1 mock = BasePageMockResponses(3*page_size, page_size) - with patch('test_pagination_base.TestPageIterator.operation', mock.get_next_page): - page_iterator: Iterable[ViewResultRow] = TestPageIterator( + with patch('test_pagination_base.BaseTestPageIterator.operation', mock.get_next_page): + page_iterator: Iterable[ViewResultRow] = BaseTestPageIterator( self.client, {'limit': page_size}) self.assertIsNone(page_iterator._next_page_opts.get('start_key'), "The start key should intially be None.") @@ -269,8 +269,8 @@ def test_next_resumes_after_error(self): def test_pages_iterable(self): page_size = 23 mock = BasePageMockResponses(3*page_size-1, page_size) - pagination = Pagination(self.client, TestPageIterator, {'limit': page_size}) - with patch('test_pagination_base.TestPageIterator.operation', mock.get_next_page): + pagination = Pagination(self.client, BaseTestPageIterator, {'limit': page_size}) + with patch('test_pagination_base.BaseTestPageIterator.operation', mock.get_next_page): # Check pages are iterable page_number = 0 for page in pagination.pages(): @@ -282,8 +282,8 @@ def test_pages_iterable(self): def test_rows_iterable(self): page_size = 23 mock = BasePageMockResponses(3*page_size-1, page_size) - pagination = Pagination(self.client, TestPageIterator, {'limit': page_size}) - with patch('test_pagination_base.TestPageIterator.operation', mock.get_next_page): + pagination = Pagination(self.client, BaseTestPageIterator, {'limit': page_size}) + with patch('test_pagination_base.BaseTestPageIterator.operation', mock.get_next_page): actual_items = [] # Check rows are iterable for row in pagination.rows(): @@ -294,8 +294,8 @@ def test_as_pager_get_next_first_page(self): page_size = 7 # Mock that returns two pages of 7 items mock = BasePageMockResponses(2*page_size, page_size) - pagination = Pagination(self.client, TestPageIterator, {'limit': page_size}) - with patch('test_pagination_base.TestPageIterator.operation', mock.get_next_page): + pagination = Pagination(self.client, BaseTestPageIterator, {'limit': page_size}) + with patch('test_pagination_base.BaseTestPageIterator.operation', mock.get_next_page): pager = pagination.pager() # Get first page actual_page: list[ViewResultRow] = pager.get_next() @@ -305,8 +305,8 @@ def test_as_pager_get_all(self): page_size = 11 # Mock that returns 6 pages of 11 items, then 1 more page with 5 items mock = BasePageMockResponses(71, page_size) - pagination = Pagination(self.client, TestPageIterator, {'limit': page_size}) - with patch('test_pagination_base.TestPageIterator.operation', mock.get_next_page): + pagination = Pagination(self.client, BaseTestPageIterator, {'limit': page_size}) + with patch('test_pagination_base.BaseTestPageIterator.operation', mock.get_next_page): pager: Pager[ViewResultRow] = pagination.pager() actual_items = pager.get_all() self.assertSequenceEqual(actual_items, mock.all_expected_items(), "The results should match all the pages.") @@ -326,8 +326,8 @@ def test_as_pager_get_all_restarts_after_error(self): first_page, mock.get_next_page() ]) - pagination = Pagination(self.client, TestPageIterator, {'limit': page_size}) - with patch('test_pagination_base.TestPageIterator.operation', mockmock): + pagination = Pagination(self.client, BaseTestPageIterator, {'limit': page_size}) + with patch('test_pagination_base.BaseTestPageIterator.operation', mockmock): pager = pagination.pager() with self.assertRaises(Exception): pager.get_all() @@ -337,8 +337,8 @@ def test_as_pager_get_next_get_all_raises(self): page_size = 11 # Mock that returns 6 pages of 11 items, then 1 more page with 5 items mock = BasePageMockResponses(71, page_size) - pagination = Pagination(self.client, TestPageIterator, {'limit': page_size}) - with patch('test_pagination_base.TestPageIterator.operation', mock.get_next_page): + pagination = Pagination(self.client, BaseTestPageIterator, {'limit': page_size}) + with patch('test_pagination_base.BaseTestPageIterator.operation', mock.get_next_page): pager: Pager[ViewResultRow] = pagination.pager() first_page = pager.get_next() self.assertSequenceEqual(first_page, mock.get_expected_page(1), "The actual page should match the expected page") @@ -358,8 +358,8 @@ def test_as_pager_get_all_get_next_raises(self): first_page, Exception('test exception') ]) - pagination = Pagination(self.client, TestPageIterator, {'limit': page_size}) - with patch('test_pagination_base.TestPageIterator.operation', mockmock): + pagination = Pagination(self.client, BaseTestPageIterator, {'limit': page_size}) + with patch('test_pagination_base.BaseTestPageIterator.operation', mockmock): pager = pagination.pager() # Stop get all part way through so it isn't consumed when we call get Next with self.assertRaises(Exception): @@ -378,8 +378,8 @@ def test_as_pager_get_next_resumes_after_error(self): Exception('test exception'), mock.get_next_page() ]) - pagination = Pagination(self.client, TestPageIterator, {'limit': page_size}) - with patch('test_pagination_base.TestPageIterator.operation', mockmock): + pagination = Pagination(self.client, BaseTestPageIterator, {'limit': page_size}) + with patch('test_pagination_base.BaseTestPageIterator.operation', mockmock): pager = pagination.pager() # Assert first page self.assertSequenceEqual(pager.get_next(), mock.get_expected_page(1), "The actual page should match the expected page") @@ -392,8 +392,8 @@ def test_as_pager_get_next_until_consumed(self): page_size = 7 # Mock that returns two pages of 7 items mock = BasePageMockResponses(2*page_size, page_size) - pagination = Pagination(self.client, TestPageIterator, {'limit': page_size}) - with patch('test_pagination_base.TestPageIterator.operation', mock.get_next_page): + pagination = Pagination(self.client, BaseTestPageIterator, {'limit': page_size}) + with patch('test_pagination_base.BaseTestPageIterator.operation', mock.get_next_page): pager = pagination.pager() page_count = 0 while pager.has_next():