Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/pulp-glue/+aiohttp.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
WIP: Added async api to Pulp glue.
2 changes: 2 additions & 0 deletions CHANGES/pulp-glue/+aiohttp.removal
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Replaced requests with aiohttp.
Breaking change: Reworked the contract around the `AuthProvider` to allow authentication to be coded independently of the underlying library.
2 changes: 1 addition & 1 deletion lint_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ mypy~=1.19.1
shellcheck-py~=0.11.0.1

# Type annotation stubs
types-aiofiles
types-pygments
types-PyYAML
types-requests
types-setuptools
types-toml

Expand Down
2 changes: 2 additions & 0 deletions lower_bounds_constraints.lock
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
aiofiles==25.1.0
aiohttp==3.12.0
click==8.0.0
packaging==22.0
PyYAML==5.3
Expand Down
3 changes: 2 additions & 1 deletion pulp-glue/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ classifiers = [
"Typing :: Typed",
]
dependencies = [
"aiofiles>=25.1.0,<25.2",
"aiohttp>=3.12.0,<3.14",
"multidict>=6.0.5,<6.8",
"packaging>=22.0,<=26.0", # CalVer
"requests>=2.24.0,<2.33",
"tomli>=2.0.0,<2.1;python_version<'3.11'",
]

Expand Down
168 changes: 88 additions & 80 deletions pulp-glue/src/pulp_glue/common/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from io import BufferedReader
from urllib.parse import urlencode, urljoin

import requests
import urllib3
import aiofiles
import aiofiles.os
import aiohttp
from multidict import CIMultiDict, CIMultiDictProxy, MutableMultiMapping

from pulp_glue.common import __version__
Expand Down Expand Up @@ -136,38 +137,13 @@ def __init__(
if cid:
self._headers["Correlation-Id"] = cid

self._setup_session()

self._oauth2_lock = asyncio.Lock()
self._oauth2_token: str | None = None
self._oauth2_expires: datetime = datetime.now()

self._patch_api_hook: t.Callable[[t.Any], t.Any] = patch_api_hook or (lambda data: data)
self.load_api(refresh_cache=refresh_cache)

def _setup_session(self) -> None:
# This is specific requests library.

if self._verify_ssl is False:
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

self._session: requests.Session = requests.session()
# Don't redirect, because carrying auth accross redirects is unsafe.
self._session.max_redirects = 0
self._session.headers.update(self._headers)
session_settings = self._session.merge_environment_settings(
self._base_url, {}, None, self._verify_ssl, None
)
self._session.verify = session_settings["verify"]
self._session.proxies = session_settings["proxies"]

if self._auth_provider is not None and self._auth_provider.can_complete_mutualTLS():
cert, key = self._auth_provider.tls_credentials()
if key is not None:
self._session.cert = (cert, key)
else:
self._session.cert = cert

@property
def base_url(self) -> str:
return self._base_url
Expand All @@ -191,7 +167,10 @@ def ssl_context(self) -> t.Union[ssl.SSLContext, bool]:
return _ssl_context

def load_api(self, refresh_cache: bool = False) -> None:
# TODO: Find a way to invalidate caches on upstream change
asyncio.run(self._load_api(refresh_cache=refresh_cache))

async def _load_api(self, refresh_cache: bool = False) -> None:
# TODO: Find a way to invalidate caches on upstream change.
xdg_cache_home: str = os.environ.get("XDG_CACHE_HOME") or "~/.cache"
apidoc_cache: str = os.path.join(
os.path.expanduser(xdg_cache_home),
Expand All @@ -203,17 +182,17 @@ def load_api(self, refresh_cache: bool = False) -> None:
if refresh_cache:
# Fake that we did not find the cache.
raise OSError()
with open(apidoc_cache, "rb") as f:
data: bytes = f.read()
async with aiofiles.open(apidoc_cache, mode="rb") as f:
data: bytes = await f.read()
self._parse_api(data)
except Exception:
# Try again with a freshly downloaded version
data = self._download_api()
# Try again with a freshly downloaded version.
data = await self._download_api()
self._parse_api(data)
# Write to cache as it seems to be valid
os.makedirs(os.path.dirname(apidoc_cache), exist_ok=True)
with open(apidoc_cache, "bw") as f:
f.write(data)
# Write to cache as it seems to be valid.
await aiofiles.os.makedirs(os.path.dirname(apidoc_cache), exist_ok=True)
async with aiofiles.open(apidoc_cache, mode="bw") as f:
await f.write(data)

def _parse_api(self, data: bytes) -> None:
raw_spec = self._patch_api_hook(json.loads(data))
Expand All @@ -229,15 +208,18 @@ def _parse_api(self, data: bytes) -> None:
if method in METHODS
}

def _download_api(self) -> bytes:
try:
response: requests.Response = self._session.get(urljoin(self._base_url, self._doc_path))
except requests.RequestException as e:
raise OpenAPIError(str(e))
response.raise_for_status()
if "Correlation-Id" in response.headers:
self._set_correlation_id(response.headers["Correlation-Id"])
return response.content
async def _download_api(self) -> bytes:
response = await self._send_request(
_Request(
operation_id="",
method="get",
url=urljoin(self._base_url, self._doc_path),
headers=self._headers,
)
)
if response.status_code != 200:
raise OpenAPIError(_("Failed to find api docs."))
return response.body

def _set_correlation_id(self, correlation_id: str) -> None:
if "Correlation-Id" in self._headers:
Expand All @@ -249,8 +231,6 @@ def _set_correlation_id(self, correlation_id: str) -> None:
)
else:
self._headers["Correlation-Id"] = correlation_id
# Do it for requests too...
self._session.headers["Correlation-Id"] = correlation_id

def param_spec(
self, operation_id: str, param_type: str, required: bool = False
Expand Down Expand Up @@ -467,7 +447,7 @@ def _render_request(
security=security,
)

def _log_request(self, request: _Request) -> None:
async def _log_request(self, request: _Request) -> None:
if request.params:
qs = urlencode(request.params)
self._debug_callback(1, f"{request.operation_id} : {request.method} {request.url}?{qs}")
Expand All @@ -493,7 +473,6 @@ def _select_proposal(
if (
request.security
and "Authorization" not in request.headers
and "Authorization" not in self._session.headers
and self._auth_provider is not None
):
security_schemes: dict[str, dict[str, t.Any]] = self.api_spec["components"][
Expand Down Expand Up @@ -565,7 +544,7 @@ async def _fetch_oauth2_token(self, flow: dict[str, t.Any]) -> bool:
headers={"Authorization": f"Basic {secret.decode()}"},
data=data,
)
response = self._send_request(request)
response = await self._send_request(request)
if response.status_code < 200 or response.status_code >= 300:
raise OpenAPIError("Failed to fetch OAuth2 token")
result = json.loads(response.body)
Expand All @@ -574,38 +553,55 @@ async def _fetch_oauth2_token(self, flow: dict[str, t.Any]) -> bool:
new_token = True
return new_token

def _send_request(
async def _send_request(
self,
request: _Request,
) -> _Response:
# This function uses requests to translate the _Request into a _Response.
# This function uses aiohttp to translate the _Request into a _Response.
data: aiohttp.FormData | dict[str, t.Any] | str | None
if request.files:
assert isinstance(request.data, dict)
# Maybe assert on the content type header.
data = aiohttp.FormData(default_to_multipart=True)
for key, value in request.data.items():
data.add_field(key, encode_param(value))
for key, (name, value, content_type) in request.files.items():
data.add_field(key, value, filename=name, content_type=content_type)
else:
data = request.data
try:
r = self._session.request(
request.method,
request.url,
params=request.params,
headers=request.headers,
data=request.data,
files=request.files,
)
response = _Response(status_code=r.status_code, headers=r.headers, body=r.content)
except requests.TooManyRedirects as e:
assert e.response is not None
async with aiohttp.ClientSession() as session:
async with session.request(
request.method,
request.url,
params=request.params,
headers=request.headers,
data=data,
ssl=self.ssl_context,
max_redirects=0,
) as r:
response_body = await r.read()
response = _Response(
status_code=r.status, headers=r.headers, body=response_body
)
except aiohttp.TooManyRedirects as e:
# We could handle that in the middleware...
assert e.history[-1] is not None
raise OpenAPIError(
_(
"Received redirect to '{new_url} from {old_url}'."
" Please check your configuration."
).format(
new_url=e.response.headers["location"],
new_url=e.history[-1].headers["location"],
old_url=request.url,
)
)
except requests.RequestException as e:
except aiohttp.ClientResponseError as e:
raise OpenAPIError(str(e))

return response

def _log_response(self, response: _Response) -> None:
async def _log_response(self, response: _Response) -> None:
self._debug_callback(
1, _("Response: {status_code}").format(status_code=response.status_code)
)
Expand Down Expand Up @@ -652,6 +648,22 @@ def call(
parameters: dict[str, t.Any] | None = None,
body: dict[str, t.Any] | None = None,
validate_body: bool = True,
) -> t.Any:
return asyncio.run(
self.async_call(
operation_id=operation_id,
parameters=parameters,
body=body,
validate_body=validate_body,
)
)

async def async_call(
self,
operation_id: str,
parameters: dict[str, t.Any] | None = None,
body: dict[str, t.Any] | None = None,
validate_body: bool = True,
) -> t.Any:
"""
Make a call to the server.
Expand Down Expand Up @@ -706,37 +718,33 @@ def call(
body,
validate_body=validate_body,
)
self._log_request(request)
await self._log_request(request)

if self._dry_run and request.method.upper() not in SAFE_METHODS:
raise UnsafeCallError(_("Call aborted due to safe mode"))

may_retry = False
if proposal := self._select_proposal(request):
assert len(proposal) == 1, "More complex security proposals are not implemented."
may_retry = asyncio.run(self._authenticate_request(request, proposal))
may_retry = await self._authenticate_request(request, proposal)

response = self._send_request(request)
response = await self._send_request(request)

if proposal is not None:
assert self._auth_provider is not None
if may_retry and response.status_code == 401:
self._oauth2_token = None
asyncio.run(self._authenticate_request(request, proposal))
response = self._send_request(request)
await self._authenticate_request(request, proposal)
response = await self._send_request(request)

if response.status_code >= 200 and response.status_code < 300:
asyncio.run(
self._auth_provider.auth_success_hook(
proposal, self.api_spec["components"]["securitySchemes"]
)
await self._auth_provider.auth_success_hook(
proposal, self.api_spec["components"]["securitySchemes"]
)
elif response.status_code == 401:
asyncio.run(
self._auth_provider.auth_failure_hook(
proposal, self.api_spec["components"]["securitySchemes"]
)
await self._auth_provider.auth_failure_hook(
proposal, self.api_spec["components"]["securitySchemes"]
)

self._log_response(response)
await self._log_response(response)
return self._parse_response(method_spec, response)
10 changes: 2 additions & 8 deletions pulp-glue/tests/test_auth_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,7 @@ def test_can_complete_basic(self, provider: AuthProviderBase) -> None:
assert provider.can_complete_http_basic()

def test_provides_username_and_password(self, provider: AuthProviderBase) -> None:
assert asyncio.run(provider.http_basic_credentials()) == (
b"user1",
b"password1",
)
assert asyncio.run(provider.http_basic_credentials()) == (b"user1", b"password1")

def test_cannot_complete_mutualTLS(self, provider: AuthProviderBase) -> None:
assert not provider.can_complete_mutualTLS()
Expand Down Expand Up @@ -104,10 +101,7 @@ def test_client_id_needs_client_secret(self) -> None:
def test_can_complete_oauth2_client_credentials_and_provide_them(self) -> None:
provider = GlueAuthProvider(client_id="client1", client_secret="secret1")
assert provider.can_complete_oauth2_client_credentials([]) is True
assert asyncio.run(provider.oauth2_client_credentials()) == (
b"client1",
b"secret1",
)
assert asyncio.run(provider.oauth2_client_credentials()) == (b"client1", b"secret1")

def test_can_complete_mutualTLS_and_provide_cert(self) -> None:
provider = GlueAuthProvider(cert="FAKECERTIFICATE")
Expand Down
2 changes: 1 addition & 1 deletion pulp-glue/tests/test_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@
).encode()


def mock_send_request(request: _Request) -> _Response:
async def mock_send_request(request: _Request) -> _Response:
if request.url.endswith("oauth/token"):
assert request.method.lower() == "post"
# $ echo -n "client1:secret1" | base64
Expand Down
Loading
Loading