Skip to content

Commit deb70b3

Browse files
chore(internal): refactor authentication internals
1 parent b9079a9 commit deb70b3

File tree

4 files changed

+51
-11
lines changed

4 files changed

+51
-11
lines changed

src/cas_parser/_base_client.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
)
6464
from ._utils import is_dict, is_list, asyncify, is_given, lru_cache, is_mapping
6565
from ._compat import PYDANTIC_V1, model_copy, model_dump
66-
from ._models import GenericModel, FinalRequestOptions, validate_type, construct_type
66+
from ._models import GenericModel, SecurityOptions, FinalRequestOptions, validate_type, construct_type
6767
from ._response import (
6868
APIResponse,
6969
BaseAPIResponse,
@@ -432,9 +432,27 @@ def _make_status_error(
432432
) -> _exceptions.APIStatusError:
433433
raise NotImplementedError()
434434

435+
def _auth_headers(
436+
self,
437+
security: SecurityOptions, # noqa: ARG002
438+
) -> dict[str, str]:
439+
return {}
440+
441+
def _auth_query(
442+
self,
443+
security: SecurityOptions, # noqa: ARG002
444+
) -> dict[str, str]:
445+
return {}
446+
447+
def _custom_auth(
448+
self,
449+
security: SecurityOptions, # noqa: ARG002
450+
) -> httpx.Auth | None:
451+
return None
452+
435453
def _build_headers(self, options: FinalRequestOptions, *, retries_taken: int = 0) -> httpx.Headers:
436454
custom_headers = options.headers or {}
437-
headers_dict = _merge_mappings(self.default_headers, custom_headers)
455+
headers_dict = _merge_mappings({**self._auth_headers(options.security), **self.default_headers}, custom_headers)
438456
self._validate_headers(headers_dict, custom_headers)
439457

440458
# headers are case-insensitive while dictionaries are not.
@@ -506,7 +524,7 @@ def _build_request(
506524
raise RuntimeError(f"Unexpected JSON data type, {type(json_data)}, cannot merge with `extra_body`")
507525

508526
headers = self._build_headers(options, retries_taken=retries_taken)
509-
params = _merge_mappings(self.default_query, options.params)
527+
params = _merge_mappings({**self._auth_query(options.security), **self.default_query}, options.params)
510528
content_type = headers.get("Content-Type")
511529
files = options.files
512530

@@ -671,7 +689,6 @@ def default_headers(self) -> dict[str, str | Omit]:
671689
"Content-Type": "application/json",
672690
"User-Agent": self.user_agent,
673691
**self.platform_headers(),
674-
**self.auth_headers,
675692
**self._custom_headers,
676693
}
677694

@@ -990,8 +1007,9 @@ def request(
9901007
self._prepare_request(request)
9911008

9921009
kwargs: HttpxSendArgs = {}
993-
if self.custom_auth is not None:
994-
kwargs["auth"] = self.custom_auth
1010+
custom_auth = self._custom_auth(options.security)
1011+
if custom_auth is not None:
1012+
kwargs["auth"] = custom_auth
9951013

9961014
if options.follow_redirects is not None:
9971015
kwargs["follow_redirects"] = options.follow_redirects
@@ -1952,6 +1970,7 @@ def make_request_options(
19521970
idempotency_key: str | None = None,
19531971
timeout: float | httpx.Timeout | None | NotGiven = not_given,
19541972
post_parser: PostParser | NotGiven = not_given,
1973+
security: SecurityOptions | None = None,
19551974
) -> RequestOptions:
19561975
"""Create a dict of type RequestOptions without keys of NotGiven values."""
19571976
options: RequestOptions = {}
@@ -1977,6 +1996,9 @@ def make_request_options(
19771996
# internal
19781997
options["post_parser"] = post_parser # type: ignore
19791998

1999+
if security is not None:
2000+
options["security"] = security
2001+
19802002
return options
19812003

19822004

src/cas_parser/_client.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
)
2222
from ._utils import is_given, get_async_library
2323
from ._compat import cached_property
24+
from ._models import SecurityOptions
2425
from ._version import __version__
2526
from ._streaming import Stream as Stream, AsyncStream as AsyncStream
2627
from ._exceptions import APIStatusError, CasParserError
@@ -274,9 +275,14 @@ def with_streaming_response(self) -> CasParserWithStreamedResponse:
274275
def qs(self) -> Querystring:
275276
return Querystring(array_format="comma")
276277

277-
@property
278278
@override
279-
def auth_headers(self) -> dict[str, str]:
279+
def _auth_headers(self, security: SecurityOptions) -> dict[str, str]:
280+
return {
281+
**(self._api_key_auth if security.get("api_key_auth", False) else {}),
282+
}
283+
284+
@property
285+
def _api_key_auth(self) -> dict[str, str]:
280286
api_key = self.api_key
281287
return {"x-api-key": api_key}
282288

@@ -578,9 +584,14 @@ def with_streaming_response(self) -> AsyncCasParserWithStreamedResponse:
578584
def qs(self) -> Querystring:
579585
return Querystring(array_format="comma")
580586

581-
@property
582587
@override
583-
def auth_headers(self) -> dict[str, str]:
588+
def _auth_headers(self, security: SecurityOptions) -> dict[str, str]:
589+
return {
590+
**(self._api_key_auth if security.get("api_key_auth", False) else {}),
591+
}
592+
593+
@property
594+
def _api_key_auth(self) -> dict[str, str]:
584595
api_key = self.api_key
585596
return {"x-api-key": api_key}
586597

src/cas_parser/_models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,10 @@ def _create_pydantic_model(type_: _T) -> Type[RootModel[_T]]:
791791
return RootModel[type_] # type: ignore
792792

793793

794+
class SecurityOptions(TypedDict, total=False):
795+
api_key_auth: bool
796+
797+
794798
class FinalRequestOptionsInput(TypedDict, total=False):
795799
method: Required[str]
796800
url: Required[str]
@@ -804,6 +808,7 @@ class FinalRequestOptionsInput(TypedDict, total=False):
804808
json_data: Body
805809
extra_json: AnyMapping
806810
follow_redirects: bool
811+
security: SecurityOptions
807812

808813

809814
@final
@@ -818,6 +823,7 @@ class FinalRequestOptions(pydantic.BaseModel):
818823
idempotency_key: Union[str, None] = None
819824
post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven()
820825
follow_redirects: Union[bool, None] = None
826+
security: SecurityOptions = {"api_key_auth": True}
821827

822828
content: Union[bytes, bytearray, IO[bytes], Iterable[bytes], AsyncIterable[bytes], None] = None
823829
# It should be noted that we cannot use `json` here as that would override

src/cas_parser/_types.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from httpx import URL, Proxy, Timeout, Response, BaseTransport, AsyncBaseTransport
3737

3838
if TYPE_CHECKING:
39-
from ._models import BaseModel
39+
from ._models import BaseModel, SecurityOptions
4040
from ._response import APIResponse, AsyncAPIResponse
4141

4242
Transport = BaseTransport
@@ -121,6 +121,7 @@ class RequestOptions(TypedDict, total=False):
121121
extra_json: AnyMapping
122122
idempotency_key: str
123123
follow_redirects: bool
124+
security: SecurityOptions
124125

125126

126127
# Sentinel class used until PEP 0661 is accepted

0 commit comments

Comments
 (0)