Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions okta_jwt_verifier/request_executor.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
"""Module contains tools to perform http requests."""
import time
import asyncio

from acachecontrol import AsyncCacheControl
from acachecontrol.cache import AsyncCache
from retry.api import retry_call

from .constants import MAX_RETRIES, MAX_REQUESTS, REQUEST_TIMEOUT


class RequestExecutor:
"""Wrapper around HTTP API requests."""

def __init__(self,
max_retries=MAX_RETRIES,
max_requests=MAX_REQUESTS,
Expand All @@ -33,7 +33,7 @@ async def fire_request(self, uri, **params):
resp_json = await resp.json()
return resp_json

def get(self, uri, **params):
async def get(self, uri, **params):
"""Perform http(s) GET request with retry.

Return response in json-format.
Expand All @@ -44,14 +44,21 @@ def get(self, uri, **params):
request_params['proxy'] = self.proxy

while self.requests_count >= self.max_requests:
time.sleep(0.1)
await asyncio.sleep(0.1)

self.requests_count += 1
response = retry_call(self.fire_request,
fargs=(uri,),
fkwargs=request_params,
tries=self.max_retries)
self.requests_count -= 1
return response
try:
last_error = None
for attempt in range(self.max_retries):
try:
return await self.fire_request(uri, **request_params)
except Exception as e:
last_error = e
if attempt < self.max_retries - 1:
await asyncio.sleep(0.5 * (2 ** attempt))
raise last_error
finally:
self.requests_count -= 1

def clear_cache(self):
"""Remove all cached data from all adapters in cached session."""
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ readme = "README.md"
python = "^3.8"
PyJWT = "^2.8.0"
acachecontrol = "^0.3.6"
retry2 = "^0.9.5"

[tool.poetry.group.dev.dependencies]
cryptography = ">=43.0.0"
Expand Down
4 changes: 1 addition & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
requests>=2.31.0
pyjwt>=2.8.0
acachecontrol>=0.3.6
retry2
aiohttp>=3.12.14
certifi>=2023.7.22
urllib3>=1.26.18
setuptools>=78.1.1
cryptography>=43.0.0
cryptography>=43.0.0
101 changes: 72 additions & 29 deletions tests/unit/test_request_executor.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,77 @@
import asyncio

import pytest
from unittest.mock import AsyncMock, MagicMock

from okta_jwt_verifier import BaseJWTVerifier, __version__ as version
from okta_jwt_verifier.constants import REQUEST_TIMEOUT
from okta_jwt_verifier.request_executor import RequestExecutor


@pytest.mark.asyncio
async def test_proxy(mocker):
class AsyncMock(mocker.MagicMock):
async def __call__(self, *args, **kwargs):
return super().__call__(self, *args, **kwargs)

issuer = 'https://test_issuer.com'
jwt_verifier = BaseJWTVerifier(issuer)

mock_fire_request = AsyncMock()
jwt_verifier.request_executor.fire_request = mock_fire_request
await jwt_verifier.get_jwks()

mock_fire_request.assert_called_with(mock_fire_request,
f'{issuer}/oauth2/v1/keys',
headers={'User-Agent': f'okta-jwt-verifier-python/{version}',
'Content-Type': 'application/json'},
timeout=REQUEST_TIMEOUT)

jwt_verifier = BaseJWTVerifier(issuer, proxy='http://test_proxy.com')
jwt_verifier.request_executor.fire_request = mock_fire_request
await jwt_verifier.get_jwks()

mock_fire_request.assert_called_with(mock_fire_request,
f'{issuer}/oauth2/v1/keys',
headers={'User-Agent': f'okta-jwt-verifier-python/{version}',
'Content-Type': 'application/json'},
timeout=REQUEST_TIMEOUT,
proxy='http://test_proxy.com')
async def test_proxy():
"""Test that proxy parameter is passed to requests."""
issuer = 'https://test.okta.com'

# Without proxy
verifier = BaseJWTVerifier(issuer)
verifier.request_executor.fire_request = AsyncMock(return_value={'keys': []})
await verifier.get_jwks()

verifier.request_executor.fire_request.assert_called_with(
f'{issuer}/oauth2/v1/keys',
headers={'User-Agent': f'okta-jwt-verifier-python/{version}',
'Content-Type': 'application/json'},
timeout=30
)

# With proxy
verifier = BaseJWTVerifier(issuer, proxy='http://proxy:8080')
verifier.request_executor.fire_request = AsyncMock(return_value={'keys': []})
await verifier.get_jwks()

verifier.request_executor.fire_request.assert_called_with(
f'{issuer}/oauth2/v1/keys',
headers={'User-Agent': f'okta-jwt-verifier-python/{version}',
'Content-Type': 'application/json'},
timeout=30,
proxy='http://proxy:8080'
)


@pytest.mark.asyncio
async def test_retry_success():
"""Test that transient failures are retried."""
executor = RequestExecutor(max_retries=3)
executor.fire_request = AsyncMock(side_effect=[
Exception('fail'),
Exception('fail'),
{'keys': []}
])

result = await executor.get('https://test.com/keys')

assert result == {'keys': []}
assert executor.fire_request.call_count == 3


@pytest.mark.asyncio
async def test_retry_exhausted():
"""Test that exception is raised when retries exhausted."""
executor = RequestExecutor(max_retries=2)
executor.fire_request = AsyncMock(side_effect=Exception('network error'))

with pytest.raises(Exception):
await executor.get('https://test.com/keys')

assert executor.fire_request.call_count == 2


@pytest.mark.asyncio
async def test_clear_cache():
"""Test that clear_cache calls underlying cache."""
executor = RequestExecutor()
executor.cache.clear_cache = MagicMock()

executor.clear_cache()

executor.cache.clear_cache.assert_called_once()