Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions src/auth0_server_python/auth_server/mfa_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def __init__(
client_secret: str,
secret: str,
state_store=None,
state_identifier: str = "_a0_session"
state_identifier: str = "_a0_session",
headers: Optional[dict[str, str]] = None
):
if callable(domain):
self._domain = None
Expand All @@ -72,6 +73,12 @@ def __init__(
self._secret = secret
self._state_store = state_store
self._state_identifier = state_identifier
self._headers = headers or {}

def _get_http_client(self, **kwargs) -> httpx.AsyncClient:
"""Return an httpx.AsyncClient with default headers injected."""
headers = {**kwargs.pop("headers", {}), **self._headers}
return httpx.AsyncClient(headers=headers, **kwargs)

async def _resolve_base_url(
self,
Expand Down Expand Up @@ -157,7 +164,7 @@ async def list_authenticators(
url = f"{base_url}/mfa/authenticators"

try:
async with httpx.AsyncClient() as client:
async with self._get_http_client() as client:
response = await client.get(
url,
auth=BearerAuth(mfa_token)
Expand Down Expand Up @@ -232,7 +239,7 @@ async def enroll_authenticator(
body["email"] = options["email"]

try:
async with httpx.AsyncClient() as client:
async with self._get_http_client() as client:
response = await client.post(
url,
json=body,
Expand Down Expand Up @@ -311,7 +318,7 @@ async def challenge_authenticator(
body["authenticator_id"] = options["authenticator_id"]

try:
async with httpx.AsyncClient() as client:
async with self._get_http_client() as client:
response = await client.post(
url,
json=body,
Expand Down Expand Up @@ -395,7 +402,7 @@ async def verify(
base_url = await self._resolve_base_url(store_options)
token_endpoint = f"{base_url}/oauth/token"

async with httpx.AsyncClient() as client:
async with self._get_http_client() as client:
response = await client.post(
token_endpoint,
data=body,
Expand Down
19 changes: 13 additions & 6 deletions src/auth0_server_python/auth_server/my_account_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,21 @@ class MyAccountClient:
Client for interacting with the Auth0 MyAccount API.
"""

def __init__(self, domain: str):
def __init__(self, domain: str, headers: Optional[dict[str, str]] = None):
"""
Initialize the MyAccount API client.

Args:
domain: Auth0 domain (e.g., '<tenant>.<locality>.auth0.com')
headers: Optional default headers to include on every request
"""
self._domain = domain
self._headers = headers or {}

def _get_http_client(self, **kwargs) -> httpx.AsyncClient:
"""Return an httpx.AsyncClient with default headers injected."""
headers = {**kwargs.pop("headers", {}), **self._headers}
return httpx.AsyncClient(headers=headers, **kwargs)

@property
def audience(self):
Expand Down Expand Up @@ -64,7 +71,7 @@ async def connect_account(
ApiError: If the request fails due to network or other issues
"""
try:
async with httpx.AsyncClient() as client:
async with self._get_http_client() as client:
response = await client.post(
url=f"{self.audience}v1/connected-accounts/connect",
json=request.model_dump(exclude_none=True),
Expand Down Expand Up @@ -114,7 +121,7 @@ async def complete_connect_account(
ApiError: If the request fails due to network or other issues
"""
try:
async with httpx.AsyncClient() as client:
async with self._get_http_client() as client:
response = await client.post(
url=f"{self.audience}v1/connected-accounts/complete",
json=request.model_dump(exclude_none=True),
Expand Down Expand Up @@ -176,7 +183,7 @@ async def list_connected_accounts(
raise InvalidArgumentError("take", "The 'take' parameter must be a positive integer.")

try:
async with httpx.AsyncClient() as client:
async with self._get_http_client() as client:
params = {}
if connection:
params["connection"] = connection
Expand Down Expand Up @@ -243,7 +250,7 @@ async def delete_connected_account(
raise MissingRequiredArgumentError("connected_account_id")

try:
async with httpx.AsyncClient() as client:
async with self._get_http_client() as client:
response = await client.delete(
url=f"{self.audience}v1/connected-accounts/accounts/{connected_account_id}",
auth=BearerAuth(access_token)
Expand Down Expand Up @@ -298,7 +305,7 @@ async def list_connected_account_connections(
raise InvalidArgumentError("take", "The 'take' parameter must be a positive integer.")

try:
async with httpx.AsyncClient() as client:
async with self._get_http_client() as client:
params = {}
if from_param:
params["from"] = from_param
Expand Down
34 changes: 24 additions & 10 deletions src/auth0_server_python/auth_server/server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
PollingApiError,
StartLinkUserError,
)
from auth0_server_python.telemetry import Telemetry
from auth0_server_python.utils import PKCE, URL, State
from auth0_server_python.utils.helpers import (
build_domain_resolver_context,
Expand Down Expand Up @@ -152,13 +153,20 @@ def __init__(
self._transaction_identifier = transaction_identifier
self._state_identifier = state_identifier

# Initialize telemetry
self._telemetry = Telemetry.default()
self._telemetry_headers = self._telemetry.headers

# Initialize OAuth client
self._oauth = AsyncOAuth2Client(
client_id=client_id,
client_secret=client_secret,
headers=self._telemetry_headers,
)

self._my_account_client = MyAccountClient(domain=domain)
self._my_account_client = MyAccountClient(
domain=domain, headers=self._telemetry_headers
)

# Unified cache for OIDC metadata and JWKS per domain (LRU eviction + TTL)
self._discovery_cache: OrderedDict[str, dict] = OrderedDict()
Expand All @@ -172,9 +180,15 @@ def __init__(
client_secret=self._client_secret,
secret=self._secret,
state_store=self._state_store,
state_identifier=self._state_identifier
state_identifier=self._state_identifier,
headers=self._telemetry_headers,
)

def _get_http_client(self, **kwargs) -> httpx.AsyncClient:
"""Return an httpx.AsyncClient with telemetry headers injected."""
headers = {**kwargs.pop("headers", {}), **self._telemetry_headers}
return httpx.AsyncClient(headers=headers, **kwargs)

def _normalize_url(self, value: str) -> str:
"""
Normalize a URL-like value (domain or issuer) for comparison.
Expand Down Expand Up @@ -281,7 +295,7 @@ async def _fetch_oidc_metadata(self, domain: str) -> dict:
"""Fetch OIDC metadata from domain."""
normalized_domain = self._normalize_url(domain)
metadata_url = f"{normalized_domain}/.well-known/openid-configuration"
async with httpx.AsyncClient() as client:
async with self._get_http_client() as client:
response = await client.get(metadata_url)
response.raise_for_status()
return response.json()
Expand Down Expand Up @@ -352,7 +366,7 @@ async def _fetch_jwks(self, jwks_uri: str) -> dict:
ApiError: If JWKS fetch fails
"""
try:
async with httpx.AsyncClient() as client:
async with self._get_http_client() as client:
response = await client.get(jwks_uri)
response.raise_for_status()
return response.json()
Expand Down Expand Up @@ -516,7 +530,7 @@ async def start_interactive_login(

auth_params["client_id"] = self._client_id
# Post the auth_params to the PAR endpoint
async with httpx.AsyncClient() as client:
async with self._get_http_client() as client:
par_response = await client.post(
par_endpoint,
data=auth_params,
Expand Down Expand Up @@ -1077,7 +1091,7 @@ async def get_token_by_refresh_token(self, options: dict[str, Any]) -> dict[str,
token_params["scope"] = merged_scope

# Exchange the refresh token for an access token
async with httpx.AsyncClient() as client:
async with self._get_http_client() as client:
response = await client.post(
token_endpoint,
data=token_params,
Expand Down Expand Up @@ -1391,7 +1405,7 @@ async def initiate_backchannel_authentication(
params.update(authorization_params)

# Make the backchannel authentication request
async with httpx.AsyncClient() as client:
async with self._get_http_client() as client:
backchannel_response = await client.post(
backchannel_endpoint,
data=params,
Expand Down Expand Up @@ -1466,7 +1480,7 @@ async def backchannel_authentication_grant(
}

# Exchange the auth_req_id for an access token
async with httpx.AsyncClient() as client:
async with self._get_http_client() as client:
response = await client.post(
token_endpoint,
data=token_params,
Expand Down Expand Up @@ -1918,7 +1932,7 @@ async def get_token_for_connection(self, options: dict[str, Any]) -> dict[str, A
params["login_hint"] = options["login_hint"]

# Make the request
async with httpx.AsyncClient() as client:
async with self._get_http_client() as client:
response = await client.post(
token_endpoint,
data=params,
Expand Down Expand Up @@ -2272,7 +2286,7 @@ async def custom_token_exchange(
params[key] = value

# Make the token exchange request
async with httpx.AsyncClient() as client:
async with self._get_http_client() as client:
response = await client.post(
token_endpoint,
data=params,
Expand Down
39 changes: 39 additions & 0 deletions src/auth0_server_python/telemetry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""
Telemetry support for auth0-server-python SDK.

Builds and caches the Auth0-Client and User-Agent headers sent
on every HTTP request to Auth0 endpoints.
"""

import base64
import importlib.metadata
import json
import platform
from typing import Optional


class Telemetry:
"""Builds telemetry headers for Auth0 HTTP requests."""

_PACKAGE_NAME = "auth0-server-python"

def __init__(self, name: str, version: str, env: Optional[dict[str, str]] = None):
self.name = name
self.version = version
self.env = env if env is not None else {"python": platform.python_version()}
payload = {"name": self.name, "version": self.version, "env": self.env}
self.headers: dict[str, str] = {
"Auth0-Client": base64.b64encode(
json.dumps(payload).encode("utf-8")
).decode("utf-8"),
"User-Agent": f"Python/{platform.python_version()}",
}

@staticmethod
def default() -> "Telemetry":
"""Create a Telemetry instance with this SDK's package metadata."""
try:
version = importlib.metadata.version(Telemetry._PACKAGE_NAME)
except importlib.metadata.PackageNotFoundError:
version = "unknown"
return Telemetry(name=Telemetry._PACKAGE_NAME, version=version)
Loading
Loading