Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 71 additions & 10 deletions src/google/adk/integrations/agent_registry/agent_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from collections.abc import Generator
from enum import Enum
import logging
import os
import re
from typing import Any
from typing import Callable
Expand All @@ -39,9 +40,11 @@
from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
import google.auth
import google.auth.transport.requests
from google.auth.transport import mtls
from google.auth.transport import requests as requests_auth
import httpx
from mcp import StdioServerParameters
import requests
from typing_extensions import override

# pylint: disable=g-import-not-at-top
Expand All @@ -61,6 +64,9 @@
logger = logging.getLogger("google_adk." + __name__)

AGENT_REGISTRY_BASE_URL = "https://agentregistry.googleapis.com/v1alpha"
AGENT_REGISTRY_MTLS_BASE_URL = (
"https://agentregistry.mtls.googleapis.com/v1alpha"
)

_TRANSPORT_MAPPING = {
"HTTP_JSON": A2ATransport.http_json,
Expand Down Expand Up @@ -120,6 +126,14 @@ async def get_tools(
return tools


class _MtlsEndpoint(Enum):
"""The mTLS endpoint setting."""

AUTO = "auto"
ALWAYS = "always"
NEVER = "never"


class _ProtocolType(str, Enum):
"""Supported agent protocol types."""

Expand Down Expand Up @@ -224,23 +238,40 @@ def _make_request(
self, path: str, params: Dict[str, Any] | None = None
) -> Dict[str, Any]:
"""Helper function to make GET requests to the Agent Registry API."""
# Determine if mTLS should be used
session = requests_auth.AuthorizedSession(credentials=self._credentials)

use_client_cert = _use_client_cert_effective()
client_cert_source = None

if use_client_cert:
client_cert_source = (
mtls.default_client_cert_source()
if mtls.has_default_client_cert_source()
else None
)
session.configure_mtls_channel()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

recommended to pass in client_cert_source: session.configure_mtls_channel(client_cert_source)


base_url = _get_agent_registry_base_url(client_cert_source)

if path.startswith("projects/"):
url = f"{AGENT_REGISTRY_BASE_URL}/{path}"
url = f"{base_url}/{path}"
else:
url = f"{AGENT_REGISTRY_BASE_URL}/{self._base_path}/{path}"
url = f"{base_url}/{self._base_path}/{path}"

try:
headers = self._get_auth_headers()
with httpx.Client() as client:
response = client.get(url, headers=headers, params=params)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
# Using AuthorizedSession for internal API calls to handle mTLS/Auth.
response = session.get(
url, headers=self._get_auth_headers(), params=params
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we shouldn't be manually injecting _get_auth_headers() into every request (I think it defeats the point of AuthorizedSession)

perhaps we should have something like

response = session.get(
    url, headers=self._get_request_headers(), params=params
)

)
response.raise_for_status()
return response.json()
except requests.exceptions.HTTPError as e:
raise RuntimeError(
f"API request failed with status {e.response.status_code}:"
f" {e.response.text}"
) from e
except httpx.RequestError as e:
except requests.exceptions.RequestException as e:
raise RuntimeError(f"API request failed (network error): {e}") from e
except Exception as e:
raise RuntimeError(f"API request failed: {e}") from e
Expand Down Expand Up @@ -520,3 +551,33 @@ def get_remote_a2a_agent(
description=description,
httpx_client=httpx_client,
)


def _use_client_cert_effective() -> bool:
"""Returns whether client certificate should be used for mTLS."""
try:
return bool(mtls.should_use_client_cert())
except (ImportError, AttributeError):
use_client_cert_str = os.getenv(
"GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"
).lower()
return use_client_cert_str == "true"


def _get_agent_registry_base_url(client_cert_source: Any | None = None) -> str:
"""Returns the base URL based on mTLS configuration and cert availability."""
use_mtls_endpoint_str = os.getenv(
"GOOGLE_API_USE_MTLS_ENDPOINT", _MtlsEndpoint.AUTO.value
).lower()

try:
use_mtls_endpoint = _MtlsEndpoint(use_mtls_endpoint_str)
except ValueError:
use_mtls_endpoint = _MtlsEndpoint.AUTO

if (use_mtls_endpoint is _MtlsEndpoint.ALWAYS) or (
use_mtls_endpoint is _MtlsEndpoint.AUTO and client_cert_source
):
return AGENT_REGISTRY_MTLS_BASE_URL

return AGENT_REGISTRY_BASE_URL
Loading