diff --git a/src/auth0_server_python/auth_server/mfa_client.py b/src/auth0_server_python/auth_server/mfa_client.py index 7940c38..c42e71b 100644 --- a/src/auth0_server_python/auth_server/mfa_client.py +++ b/src/auth0_server_python/auth_server/mfa_client.py @@ -59,7 +59,8 @@ def __init__( client_secret: str, secret: str, state_store=None, - state_identifier: str = "_a0_session" + state_identifier: str = "_a0_session", + headers: Optional[dict[str, str]] = None ): if callable(domain): self._domain = None @@ -72,6 +73,12 @@ def __init__( self._secret = secret self._state_store = state_store self._state_identifier = state_identifier + self._headers = headers or {} + + def _get_http_client(self, **kwargs) -> httpx.AsyncClient: + """Return an httpx.AsyncClient with default headers injected.""" + headers = {**kwargs.pop("headers", {}), **self._headers} + return httpx.AsyncClient(headers=headers, **kwargs) async def _resolve_base_url( self, @@ -157,7 +164,7 @@ async def list_authenticators( url = f"{base_url}/mfa/authenticators" try: - async with httpx.AsyncClient() as client: + async with self._get_http_client() as client: response = await client.get( url, auth=BearerAuth(mfa_token) @@ -232,7 +239,7 @@ async def enroll_authenticator( body["email"] = options["email"] try: - async with httpx.AsyncClient() as client: + async with self._get_http_client() as client: response = await client.post( url, json=body, @@ -311,7 +318,7 @@ async def challenge_authenticator( body["authenticator_id"] = options["authenticator_id"] try: - async with httpx.AsyncClient() as client: + async with self._get_http_client() as client: response = await client.post( url, json=body, @@ -395,7 +402,7 @@ async def verify( base_url = await self._resolve_base_url(store_options) token_endpoint = f"{base_url}/oauth/token" - async with httpx.AsyncClient() as client: + async with self._get_http_client() as client: response = await client.post( token_endpoint, data=body, diff --git a/src/auth0_server_python/auth_server/my_account_client.py b/src/auth0_server_python/auth_server/my_account_client.py index 330e837..499b981 100644 --- a/src/auth0_server_python/auth_server/my_account_client.py +++ b/src/auth0_server_python/auth_server/my_account_client.py @@ -25,14 +25,21 @@ class MyAccountClient: Client for interacting with the Auth0 MyAccount API. """ - def __init__(self, domain: str): + def __init__(self, domain: str, headers: Optional[dict[str, str]] = None): """ Initialize the MyAccount API client. Args: domain: Auth0 domain (e.g., '..auth0.com') + headers: Optional default headers to include on every request """ self._domain = domain + self._headers = headers or {} + + def _get_http_client(self, **kwargs) -> httpx.AsyncClient: + """Return an httpx.AsyncClient with default headers injected.""" + headers = {**kwargs.pop("headers", {}), **self._headers} + return httpx.AsyncClient(headers=headers, **kwargs) @property def audience(self): @@ -64,7 +71,7 @@ async def connect_account( ApiError: If the request fails due to network or other issues """ try: - async with httpx.AsyncClient() as client: + async with self._get_http_client() as client: response = await client.post( url=f"{self.audience}v1/connected-accounts/connect", json=request.model_dump(exclude_none=True), @@ -114,7 +121,7 @@ async def complete_connect_account( ApiError: If the request fails due to network or other issues """ try: - async with httpx.AsyncClient() as client: + async with self._get_http_client() as client: response = await client.post( url=f"{self.audience}v1/connected-accounts/complete", json=request.model_dump(exclude_none=True), @@ -176,7 +183,7 @@ async def list_connected_accounts( raise InvalidArgumentError("take", "The 'take' parameter must be a positive integer.") try: - async with httpx.AsyncClient() as client: + async with self._get_http_client() as client: params = {} if connection: params["connection"] = connection @@ -243,7 +250,7 @@ async def delete_connected_account( raise MissingRequiredArgumentError("connected_account_id") try: - async with httpx.AsyncClient() as client: + async with self._get_http_client() as client: response = await client.delete( url=f"{self.audience}v1/connected-accounts/accounts/{connected_account_id}", auth=BearerAuth(access_token) @@ -298,7 +305,7 @@ async def list_connected_account_connections( raise InvalidArgumentError("take", "The 'take' parameter must be a positive integer.") try: - async with httpx.AsyncClient() as client: + async with self._get_http_client() as client: params = {} if from_param: params["from"] = from_param diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index 9222182..b5d2dbc 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -57,6 +57,7 @@ PollingApiError, StartLinkUserError, ) +from auth0_server_python.telemetry import Telemetry from auth0_server_python.utils import PKCE, URL, State from auth0_server_python.utils.helpers import ( build_domain_resolver_context, @@ -152,13 +153,20 @@ def __init__( self._transaction_identifier = transaction_identifier self._state_identifier = state_identifier + # Initialize telemetry + self._telemetry = Telemetry.default() + self._telemetry_headers = self._telemetry.headers + # Initialize OAuth client self._oauth = AsyncOAuth2Client( client_id=client_id, client_secret=client_secret, + headers=self._telemetry_headers, ) - self._my_account_client = MyAccountClient(domain=domain) + self._my_account_client = MyAccountClient( + domain=domain, headers=self._telemetry_headers + ) # Unified cache for OIDC metadata and JWKS per domain (LRU eviction + TTL) self._discovery_cache: OrderedDict[str, dict] = OrderedDict() @@ -172,9 +180,15 @@ def __init__( client_secret=self._client_secret, secret=self._secret, state_store=self._state_store, - state_identifier=self._state_identifier + state_identifier=self._state_identifier, + headers=self._telemetry_headers, ) + def _get_http_client(self, **kwargs) -> httpx.AsyncClient: + """Return an httpx.AsyncClient with telemetry headers injected.""" + headers = {**kwargs.pop("headers", {}), **self._telemetry_headers} + return httpx.AsyncClient(headers=headers, **kwargs) + def _normalize_url(self, value: str) -> str: """ Normalize a URL-like value (domain or issuer) for comparison. @@ -281,7 +295,7 @@ async def _fetch_oidc_metadata(self, domain: str) -> dict: """Fetch OIDC metadata from domain.""" normalized_domain = self._normalize_url(domain) metadata_url = f"{normalized_domain}/.well-known/openid-configuration" - async with httpx.AsyncClient() as client: + async with self._get_http_client() as client: response = await client.get(metadata_url) response.raise_for_status() return response.json() @@ -352,7 +366,7 @@ async def _fetch_jwks(self, jwks_uri: str) -> dict: ApiError: If JWKS fetch fails """ try: - async with httpx.AsyncClient() as client: + async with self._get_http_client() as client: response = await client.get(jwks_uri) response.raise_for_status() return response.json() @@ -516,7 +530,7 @@ async def start_interactive_login( auth_params["client_id"] = self._client_id # Post the auth_params to the PAR endpoint - async with httpx.AsyncClient() as client: + async with self._get_http_client() as client: par_response = await client.post( par_endpoint, data=auth_params, @@ -1077,7 +1091,7 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str, token_params["scope"] = merged_scope # Exchange the refresh token for an access token - async with httpx.AsyncClient() as client: + async with self._get_http_client() as client: response = await client.post( token_endpoint, data=token_params, @@ -1391,7 +1405,7 @@ async def initiate_backchannel_authentication( params.update(authorization_params) # Make the backchannel authentication request - async with httpx.AsyncClient() as client: + async with self._get_http_client() as client: backchannel_response = await client.post( backchannel_endpoint, data=params, @@ -1466,7 +1480,7 @@ async def backchannel_authentication_grant( } # Exchange the auth_req_id for an access token - async with httpx.AsyncClient() as client: + async with self._get_http_client() as client: response = await client.post( token_endpoint, data=token_params, @@ -1918,7 +1932,7 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A params["login_hint"] = options["login_hint"] # Make the request - async with httpx.AsyncClient() as client: + async with self._get_http_client() as client: response = await client.post( token_endpoint, data=params, @@ -2272,7 +2286,7 @@ async def custom_token_exchange( params[key] = value # Make the token exchange request - async with httpx.AsyncClient() as client: + async with self._get_http_client() as client: response = await client.post( token_endpoint, data=params, diff --git a/src/auth0_server_python/telemetry.py b/src/auth0_server_python/telemetry.py new file mode 100644 index 0000000..2f07165 --- /dev/null +++ b/src/auth0_server_python/telemetry.py @@ -0,0 +1,39 @@ +""" +Telemetry support for auth0-server-python SDK. + +Builds and caches the Auth0-Client and User-Agent headers sent +on every HTTP request to Auth0 endpoints. +""" + +import base64 +import importlib.metadata +import json +import platform +from typing import Optional + + +class Telemetry: + """Builds telemetry headers for Auth0 HTTP requests.""" + + _PACKAGE_NAME = "auth0-server-python" + + def __init__(self, name: str, version: str, env: Optional[dict[str, str]] = None): + self.name = name + self.version = version + self.env = env if env is not None else {"python": platform.python_version()} + payload = {"name": self.name, "version": self.version, "env": self.env} + self.headers: dict[str, str] = { + "Auth0-Client": base64.b64encode( + json.dumps(payload).encode("utf-8") + ).decode("utf-8"), + "User-Agent": f"Python/{platform.python_version()}", + } + + @staticmethod + def default() -> "Telemetry": + """Create a Telemetry instance with this SDK's package metadata.""" + try: + version = importlib.metadata.version(Telemetry._PACKAGE_NAME) + except importlib.metadata.PackageNotFoundError: + version = "unknown" + return Telemetry(name=Telemetry._PACKAGE_NAME, version=version) diff --git a/src/auth0_server_python/tests/test_telemetry.py b/src/auth0_server_python/tests/test_telemetry.py new file mode 100644 index 0000000..a980a34 --- /dev/null +++ b/src/auth0_server_python/tests/test_telemetry.py @@ -0,0 +1,131 @@ +import base64 +import importlib.metadata +import json +import platform +from unittest.mock import AsyncMock, patch + +import pytest + +from auth0_server_python.auth_server.server_client import ServerClient +from auth0_server_python.telemetry import Telemetry + + +class TestTelemetry: + """Tests for the Telemetry class.""" + + def test_headers_contains_expected_keys(self): + telemetry = Telemetry(name="test-sdk", version="1.0.0") + assert "Auth0-Client" in telemetry.headers + assert "User-Agent" in telemetry.headers + + def test_auth0_client_header_format(self): + telemetry = Telemetry( + name="auth0-server-python", + version="1.0.0b9", + env={"python": "3.10.16"}, + ) + decoded = json.loads(base64.b64decode(telemetry.headers["Auth0-Client"])) + assert decoded == { + "name": "auth0-server-python", + "version": "1.0.0b9", + "env": {"python": "3.10.16"}, + } + + def test_user_agent_header(self): + telemetry = Telemetry(name="test-sdk", version="1.0.0") + assert telemetry.headers["User-Agent"] == f"Python/{platform.python_version()}" + + def test_default_env_uses_python_version(self): + telemetry = Telemetry(name="test-sdk", version="1.0.0") + assert telemetry.env == {"python": platform.python_version()} + + def test_custom_env_override(self): + telemetry = Telemetry( + name="test-sdk", version="1.0.0", env={"python": "3.9.0", "framework": "fastapi"} + ) + decoded = json.loads(base64.b64decode(telemetry.headers["Auth0-Client"])) + assert decoded["env"] == {"python": "3.9.0", "framework": "fastapi"} + + def test_default_factory(self): + telemetry = Telemetry.default() + assert telemetry.name == "auth0-server-python" + assert telemetry.version != "" + assert "python" in telemetry.env + + @patch( + "auth0_server_python.telemetry.importlib.metadata.version", + side_effect=importlib.metadata.PackageNotFoundError("not installed"), + ) + def test_default_factory_unknown_version_on_error(self, _mock): + telemetry = Telemetry.default() + assert telemetry.version == "unknown" + + +class TestServerClientTelemetry: + """Tests that ServerClient injects telemetry headers into HTTP requests.""" + + def _make_client(self): + return ServerClient( + domain="auth0.local", + client_id="client_id", + client_secret="client_secret", + secret="test-secret", + state_store=AsyncMock(), + transaction_store=AsyncMock(), + ) + + def test_server_client_has_telemetry_headers(self): + client = self._make_client() + assert client._telemetry_headers is not None + assert "Auth0-Client" in client._telemetry_headers + assert "User-Agent" in client._telemetry_headers + + def test_server_client_telemetry_payload_structure(self): + client = self._make_client() + decoded = json.loads(base64.b64decode(client._telemetry_headers["Auth0-Client"])) + assert decoded["name"] == "auth0-server-python" + assert "version" in decoded + assert "python" in decoded["env"] + + @pytest.mark.asyncio + async def test_get_http_client_includes_telemetry_headers(self): + client = self._make_client() + http_client = client._get_http_client() + for key, value in client._telemetry_headers.items(): + assert http_client.headers.get(key) == value + await http_client.aclose() + + @pytest.mark.asyncio + async def test_get_http_client_per_request_headers_do_not_override_telemetry(self): + client = self._make_client() + http_client = client._get_http_client(headers={"User-Agent": "custom", "X-Custom": "val"}) + # Telemetry headers must win over caller-provided duplicates + assert http_client.headers.get("User-Agent") == client._telemetry_headers["User-Agent"] + assert http_client.headers.get("Auth0-Client") == client._telemetry_headers["Auth0-Client"] + # Non-conflicting caller headers are preserved + assert http_client.headers.get("X-Custom") == "val" + await http_client.aclose() + + def test_my_account_client_receives_telemetry_headers(self): + client = self._make_client() + assert client._my_account_client._headers == client._telemetry_headers + + def test_mfa_client_receives_telemetry_headers(self): + client = self._make_client() + assert client._mfa_client._headers == client._telemetry_headers + + @pytest.mark.asyncio + async def test_fetch_oidc_metadata_sends_telemetry(self): + client = self._make_client() + http_client = client._get_http_client() + # Verify the client that _fetch_oidc_metadata would use has telemetry headers + for key, value in client._telemetry_headers.items(): + assert http_client.headers.get(key) == value + await http_client.aclose() + + def test_oauth_client_receives_telemetry_headers(self): + client = self._make_client() + # AsyncOAuth2Client stores headers passed at construction on its session + oauth_headers = client._oauth.headers + for key, value in client._telemetry_headers.items(): + assert oauth_headers.get(key) == value