Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/auth0/authentication/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

from typing import Any

from .client_authentication import add_client_authentication
from .rest import RestClient, RestClientOptions
from .types import RequestData, TimeoutType

from .client_authentication import add_client_authentication

UNKNOWN_ERROR = "a0.sdk.internal.unknown"


Expand All @@ -22,6 +21,9 @@ class AuthenticationBase:
telemetry (bool, optional): Enable or disable telemetry (defaults to True)
timeout (float or tuple, optional): Change the requests connect and read timeout. Pass a tuple to specify both values separately or a float to set both to it. (defaults to 5.0 for both)
protocol (str, optional): Useful for testing. (defaults to 'https')
client_info (dict, optional): Custom telemetry data for the Auth0-Client header.
When provided, overrides the default SDK telemetry. Useful for wrapper
SDKs that need to identify themselves. Ignored when telemetry is False.
"""

def __init__(
Expand All @@ -34,6 +36,7 @@ def __init__(
telemetry: bool = True,
timeout: TimeoutType = 5.0,
protocol: str = "https",
client_info: dict[str, Any] | None = None,
) -> None:
self.domain = domain
self.client_id = client_id
Expand All @@ -43,7 +46,9 @@ def __init__(
self.protocol = protocol
self.client = RestClient(
None,
options=RestClientOptions(telemetry=telemetry, timeout=timeout, retries=0),
options=RestClientOptions(
telemetry=telemetry, timeout=timeout, retries=0, client_info=client_info
),
)

def _add_client_authentication(self, payload: dict[str, Any]) -> dict[str, Any]:
Expand Down
21 changes: 16 additions & 5 deletions src/auth0/authentication/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from urllib.parse import urlencode

import requests

from .exceptions import Auth0Error, RateLimitError
from .types import RequestData, TimeoutType

Expand Down Expand Up @@ -38,17 +37,26 @@ class RestClientOptions:
times using an exponential backoff strategy, before
raising a RateLimitError exception. 10 retries max.
(defaults to 3)
client_info (dict, optional): Custom telemetry data to send
in the Auth0-Client header instead of the default SDK
info. Useful for wrapper SDKs that need to identify
themselves. When provided, this dict is JSON-encoded
and base64-encoded as the header value. Ignored when
telemetry is False.
(defaults to None)
"""

def __init__(
self,
telemetry: bool = True,
timeout: TimeoutType = 5.0,
retries: int = 3,
client_info: dict[str, Any] | None = None,
) -> None:
self.telemetry = telemetry
self.timeout = timeout
self.retries = retries
self.client_info = client_info


class RestClient:
Expand Down Expand Up @@ -94,17 +102,20 @@ def __init__(

if options.telemetry:
py_version = platform.python_version()
version = sys.modules["auth0"].__version__

auth0_client = dumps(
{
if options.client_info is not None:
auth0_client_dict = options.client_info
else:
version = sys.modules["auth0"].__version__
auth0_client_dict = {
"name": "auth0-python",
"version": version,
"env": {
"python": py_version,
},
}
).encode("utf-8")

auth0_client = dumps(auth0_client_dict).encode("utf-8")

self.base_headers.update(
{
Expand Down
28 changes: 27 additions & 1 deletion src/auth0/management/management_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
import base64
from json import dumps
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union

import httpx
from .client import AsyncAuth0, Auth0
Expand Down Expand Up @@ -86,6 +88,10 @@ class ManagementClient:
The API audience. Defaults to https://{domain}/api/v2/
headers : Optional[Dict[str, str]]
Additional headers to send with requests.
client_info : Optional[Dict[str, Any]]
Custom telemetry data for the Auth0-Client header. When provided,
overrides the default SDK telemetry. Useful for wrapper SDKs that
need to identify themselves (e.g., ``{"name": "my-sdk", "version": "1.0.0"}``).
timeout : Optional[float]
Request timeout in seconds. Defaults to 60.
httpx_client : Optional[httpx.Client]
Expand All @@ -106,6 +112,7 @@ def __init__(
client_secret: Optional[str] = None,
audience: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
client_info: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None,
httpx_client: Optional[httpx.Client] = None,
):
Expand All @@ -128,6 +135,13 @@ def __init__(
else:
resolved_token = token # type: ignore[assignment]

# Encode client_info into Auth0-Client header to override default telemetry
if client_info is not None:
encoded = base64.b64encode(
dumps(client_info).encode("utf-8")
).decode()
headers = {**(headers or {}), "Auth0-Client": encoded}

# Create underlying client
self._api = Auth0(
base_url=f"https://{domain}/api/v2",
Expand Down Expand Up @@ -333,6 +347,10 @@ class AsyncManagementClient:
The API audience. Defaults to https://{domain}/api/v2/
headers : Optional[Dict[str, str]]
Additional headers to send with requests.
client_info : Optional[Dict[str, Any]]
Custom telemetry data for the Auth0-Client header. When provided,
overrides the default SDK telemetry. Useful for wrapper SDKs that
need to identify themselves (e.g., ``{"name": "my-sdk", "version": "1.0.0"}``).
timeout : Optional[float]
Request timeout in seconds. Defaults to 60.
httpx_client : Optional[httpx.AsyncClient]
Expand All @@ -353,6 +371,7 @@ def __init__(
client_secret: Optional[str] = None,
audience: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
client_info: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None,
httpx_client: Optional[httpx.AsyncClient] = None,
):
Expand All @@ -378,6 +397,13 @@ def __init__(
else:
resolved_token = token # type: ignore[assignment]

# Encode client_info into Auth0-Client header to override default telemetry
if client_info is not None:
encoded = base64.b64encode(
dumps(client_info).encode("utf-8")
).decode()
headers = {**(headers or {}), "Auth0-Client": encoded}

# Create underlying client
self._api = AsyncAuth0(
base_url=f"https://{domain}/api/v2",
Expand Down
35 changes: 33 additions & 2 deletions tests/authentication/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import unittest
from unittest import mock

import requests

from auth0.authentication.base import AuthenticationBase
from auth0.authentication.exceptions import Auth0Error, RateLimitError

Expand Down Expand Up @@ -42,6 +40,39 @@ def test_telemetry_disabled(self):

self.assertEqual(ab.client.base_headers, {"Content-Type": "application/json"})

def test_telemetry_with_custom_client_info(self):
custom_info = {
"name": "auth0-ai-langchain",
"version": "1.0.0",
"env": {"python": "3.11.0"},
}
ab = AuthenticationBase("auth0.com", "cid", client_info=custom_info)
base_headers = ab.client.base_headers

auth0_client_bytes = base64.b64decode(base_headers["Auth0-Client"])
auth0_client = json.loads(auth0_client_bytes.decode("utf-8"))

self.assertEqual(auth0_client, custom_info)

def test_telemetry_disabled_ignores_client_info(self):
custom_info = {"name": "my-sdk", "version": "2.0.0"}
ab = AuthenticationBase(
"auth0.com", "cid", telemetry=False, client_info=custom_info
)

self.assertNotIn("Auth0-Client", ab.client.base_headers)
self.assertNotIn("User-Agent", ab.client.base_headers)

def test_custom_client_info_preserves_user_agent(self):
custom_info = {"name": "my-sdk", "version": "1.0.0"}
ab = AuthenticationBase("auth0.com", "cid", client_info=custom_info)
base_headers = ab.client.base_headers

python_version = "{}.{}.{}".format(
sys.version_info.major, sys.version_info.minor, sys.version_info.micro
)
self.assertEqual(base_headers["User-Agent"], f"Python/{python_version}")

@mock.patch("requests.request")
def test_post(self, mock_request):
ab = AuthenticationBase("auth0.com", "cid", telemetry=False, timeout=(10, 2))
Expand Down
68 changes: 68 additions & 0 deletions tests/management/test_management_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import base64
import json
import time
from unittest.mock import MagicMock, patch

Expand Down Expand Up @@ -78,6 +80,53 @@ def test_init_with_custom_headers(self):
)
assert client._api is not None

def test_init_with_custom_client_info(self):
"""Should encode client_info as Auth0-Client header."""
custom_info = {
"name": "auth0-ai-langchain",
"version": "1.0.0",
"env": {"python": "3.11.0"},
}
client = ManagementClient(
domain="test.auth0.com",
token="my-token",
client_info=custom_info,
)
# Verify the header was set on the underlying client wrapper
custom_headers = client._api._client_wrapper.get_custom_headers()
assert custom_headers is not None
encoded_header = custom_headers.get("Auth0-Client")
assert encoded_header is not None
decoded = json.loads(base64.b64decode(encoded_header).decode("utf-8"))
assert decoded == custom_info

def test_init_with_client_info_and_custom_headers(self):
"""Should merge client_info with custom headers."""
custom_info = {"name": "my-sdk", "version": "2.0.0"}
client = ManagementClient(
domain="test.auth0.com",
token="my-token",
headers={"X-Custom": "value"},
client_info=custom_info,
)
custom_headers = client._api._client_wrapper.get_custom_headers()
assert custom_headers is not None
assert custom_headers.get("X-Custom") == "value"
assert "Auth0-Client" in custom_headers

def test_init_without_client_info_uses_default_telemetry(self):
"""Should use default auth0-python telemetry when client_info is not provided."""
client = ManagementClient(
domain="test.auth0.com",
token="my-token",
)
# get_headers() includes the default Auth0-Client telemetry
headers = client._api._client_wrapper.get_headers()
encoded = headers.get("Auth0-Client")
assert encoded is not None
decoded = json.loads(base64.b64decode(encoded).decode("utf-8"))
assert decoded["name"] == "auth0-python"


class TestManagementClientProperties:
"""Tests for ManagementClient sub-client properties."""
Expand Down Expand Up @@ -173,6 +222,25 @@ def test_init_requires_auth(self):
with pytest.raises(ValueError):
AsyncManagementClient(domain="test.auth0.com")

def test_init_with_custom_client_info(self):
"""Should encode client_info as Auth0-Client header."""
custom_info = {
"name": "auth0-ai-langchain",
"version": "1.0.0",
"env": {"python": "3.11.0"},
}
client = AsyncManagementClient(
domain="test.auth0.com",
token="my-token",
client_info=custom_info,
)
custom_headers = client._api._client_wrapper.get_custom_headers()
assert custom_headers is not None
encoded_header = custom_headers.get("Auth0-Client")
assert encoded_header is not None
decoded = json.loads(base64.b64decode(encoded_header).decode("utf-8"))
assert decoded == custom_info


class TestTokenProvider:
"""Tests for TokenProvider."""
Expand Down
Loading