diff --git a/src/google/adk/integrations/agent_registry/agent_registry.py b/src/google/adk/integrations/agent_registry/agent_registry.py index a486215151..236465a623 100644 --- a/src/google/adk/integrations/agent_registry/agent_registry.py +++ b/src/google/adk/integrations/agent_registry/agent_registry.py @@ -19,6 +19,7 @@ from collections.abc import Generator from enum import Enum import logging +import os import re from typing import Any from typing import Callable @@ -39,9 +40,11 @@ from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams from google.adk.tools.mcp_tool.mcp_toolset import McpToolset import google.auth -import google.auth.transport.requests +from google.auth.transport import mtls +from google.auth.transport import requests as requests_auth import httpx from mcp import StdioServerParameters +import requests from typing_extensions import override # pylint: disable=g-import-not-at-top @@ -61,6 +64,9 @@ logger = logging.getLogger("google_adk." + __name__) AGENT_REGISTRY_BASE_URL = "https://agentregistry.googleapis.com/v1alpha" +AGENT_REGISTRY_MTLS_BASE_URL = ( + "https://agentregistry.mtls.googleapis.com/v1alpha" +) _TRANSPORT_MAPPING = { "HTTP_JSON": A2ATransport.http_json, @@ -120,6 +126,14 @@ async def get_tools( return tools +class _MtlsEndpoint(Enum): + """The mTLS endpoint setting.""" + + AUTO = "auto" + ALWAYS = "always" + NEVER = "never" + + class _ProtocolType(str, Enum): """Supported agent protocol types.""" @@ -224,23 +238,40 @@ def _make_request( self, path: str, params: Dict[str, Any] | None = None ) -> Dict[str, Any]: """Helper function to make GET requests to the Agent Registry API.""" + # Determine if mTLS should be used + session = requests_auth.AuthorizedSession(credentials=self._credentials) + + use_client_cert = _use_client_cert_effective() + client_cert_source = None + + if use_client_cert: + client_cert_source = ( + mtls.default_client_cert_source() + if mtls.has_default_client_cert_source() + else None + ) + session.configure_mtls_channel() + + base_url = _get_agent_registry_base_url(client_cert_source) + if path.startswith("projects/"): - url = f"{AGENT_REGISTRY_BASE_URL}/{path}" + url = f"{base_url}/{path}" else: - url = f"{AGENT_REGISTRY_BASE_URL}/{self._base_path}/{path}" + url = f"{base_url}/{self._base_path}/{path}" try: - headers = self._get_auth_headers() - with httpx.Client() as client: - response = client.get(url, headers=headers, params=params) - response.raise_for_status() - return response.json() - except httpx.HTTPStatusError as e: + # Using AuthorizedSession for internal API calls to handle mTLS/Auth. + response = session.get( + url, headers=self._get_auth_headers(), params=params + ) + response.raise_for_status() + return response.json() + except requests.exceptions.HTTPError as e: raise RuntimeError( f"API request failed with status {e.response.status_code}:" f" {e.response.text}" ) from e - except httpx.RequestError as e: + except requests.exceptions.RequestException as e: raise RuntimeError(f"API request failed (network error): {e}") from e except Exception as e: raise RuntimeError(f"API request failed: {e}") from e @@ -520,3 +551,33 @@ def get_remote_a2a_agent( description=description, httpx_client=httpx_client, ) + + +def _use_client_cert_effective() -> bool: + """Returns whether client certificate should be used for mTLS.""" + try: + return bool(mtls.should_use_client_cert()) + except (ImportError, AttributeError): + use_client_cert_str = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + return use_client_cert_str == "true" + + +def _get_agent_registry_base_url(client_cert_source: Any | None = None) -> str: + """Returns the base URL based on mTLS configuration and cert availability.""" + use_mtls_endpoint_str = os.getenv( + "GOOGLE_API_USE_MTLS_ENDPOINT", _MtlsEndpoint.AUTO.value + ).lower() + + try: + use_mtls_endpoint = _MtlsEndpoint(use_mtls_endpoint_str) + except ValueError: + use_mtls_endpoint = _MtlsEndpoint.AUTO + + if (use_mtls_endpoint is _MtlsEndpoint.ALWAYS) or ( + use_mtls_endpoint is _MtlsEndpoint.AUTO and client_cert_source + ): + return AGENT_REGISTRY_MTLS_BASE_URL + + return AGENT_REGISTRY_BASE_URL diff --git a/tests/unittests/integrations/agent_registry/test_agent_registry.py b/tests/unittests/integrations/agent_registry/test_agent_registry.py index f4ba47cf25..bd8288c483 100644 --- a/tests/unittests/integrations/agent_registry/test_agent_registry.py +++ b/tests/unittests/integrations/agent_registry/test_agent_registry.py @@ -13,6 +13,7 @@ # limitations under the License. +import os from unittest.mock import AsyncMock from unittest.mock import MagicMock from unittest.mock import patch @@ -26,28 +27,32 @@ from google.adk.integrations.agent_registry.agent_registry import _ProtocolType from google.adk.telemetry.tracing import GCP_MCP_SERVER_DESTINATION_ID from google.adk.tools.mcp_tool.mcp_toolset import McpToolset +from google.auth.transport import requests as requests_auth import httpx from mcp import ClientSession from mcp.types import ListToolsResult from mcp.types import Tool import pytest +import requests class TestAgentRegistry: @pytest.fixture def registry(self): - with patch("google.auth.default", return_value=(MagicMock(), "project-id")): + mock_creds = MagicMock() + mock_creds.quota_project_id = None + with patch("google.auth.default", return_value=(mock_creds, "project-id")): return AgentRegistry(project_id="test-project", location="global") @pytest.mark.asyncio - @patch("httpx.Client") + @patch("google.auth.transport.requests.AuthorizedSession") @patch( "google.adk.tools.mcp_tool.mcp_session_manager.MCPSessionManager.create_session", new_callable=AsyncMock, ) async def test_get_mcp_toolset_adds_destination_id( - self, mock_create_session, mock_httpx, registry + self, mock_create_session, mock_session_class, registry ): """Test that tools from get_mcp_toolset have the destination ID.""" # Arrange @@ -63,9 +68,7 @@ async def test_get_mcp_toolset_adds_destination_id( "protocolBinding": "JSONRPC", }], } - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_api_response - ) + mock_session_class.return_value.get.return_value = mock_api_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -109,13 +112,13 @@ async def test_get_mcp_toolset_adds_destination_id( ) @pytest.mark.asyncio - @patch("httpx.Client") + @patch("google.auth.transport.requests.AuthorizedSession") @patch( "google.adk.tools.mcp_tool.mcp_session_manager.MCPSessionManager.create_session", new_callable=AsyncMock, ) async def test_get_mcp_toolset_handles_missing_destination_id( - self, mock_create_session, mock_httpx, registry + self, mock_create_session, mock_session_class, registry ): """Test get_mcp_toolset when the destination ID is missing.""" # Arrange @@ -129,9 +132,7 @@ async def test_get_mcp_toolset_handles_missing_destination_id( "protocolBinding": "JSONRPC", }], } - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_api_response - ) + mock_session_class.return_value.get.return_value = mock_api_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -258,14 +259,12 @@ def test_get_connection_uri_returns_none_if_no_url_in_interfaces( assert version is None assert binding is None - @patch("httpx.Client") - def test_list_agents(self, mock_httpx, registry): + @patch("google.auth.transport.requests.AuthorizedSession") + def test_list_agents(self, mock_session_class, registry): mock_response = MagicMock() mock_response.json.return_value = {"agents": []} mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + mock_session_class.return_value.get.return_value = mock_response # Mock auth refresh registry._credentials.token = "token" @@ -274,14 +273,12 @@ def test_list_agents(self, mock_httpx, registry): agents = registry.list_agents() assert agents == {"agents": []} - @patch("httpx.Client") - def test_get_mcp_server(self, mock_httpx, registry): + @patch("google.auth.transport.requests.AuthorizedSession") + def test_get_mcp_server(self, mock_session_class, registry): mock_response = MagicMock() mock_response.json.return_value = {"name": "test-mcp"} mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + mock_session_class.return_value.get.return_value = mock_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -289,14 +286,12 @@ def test_get_mcp_server(self, mock_httpx, registry): server = registry.get_mcp_server("test-mcp") assert server == {"name": "test-mcp"} - @patch("httpx.Client") - def test_list_endpoints(self, mock_httpx, registry): + @patch("google.auth.transport.requests.AuthorizedSession") + def test_list_endpoints(self, mock_session_class, registry): mock_response = MagicMock() mock_response.json.return_value = {"endpoints": []} mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + mock_session_class.return_value.get.return_value = mock_response # Mock auth refresh registry._credentials.token = "token" @@ -305,14 +300,12 @@ def test_list_endpoints(self, mock_httpx, registry): endpoints = registry.list_endpoints() assert endpoints == {"endpoints": []} - @patch("httpx.Client") - def test_get_endpoint(self, mock_httpx, registry): + @patch("google.auth.transport.requests.AuthorizedSession") + def test_get_endpoint(self, mock_session_class, registry): mock_response = MagicMock() mock_response.json.return_value = {"name": "test-endpoint"} mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + mock_session_class.return_value.get.return_value = mock_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -329,9 +322,14 @@ def test_get_endpoint(self, mock_httpx, registry): ("https://mcp.googleapis.com/v1", True, True), ], ) - @patch("httpx.Client") + @patch("google.auth.transport.requests.AuthorizedSession") def test_get_mcp_toolset_auth_headers( - self, mock_httpx, registry, url, expected_auth, use_custom_provider + self, + mock_session_class, + registry, + url, + expected_auth, + use_custom_provider, ): mock_response = MagicMock() mock_response.json.return_value = { @@ -342,16 +340,17 @@ def test_get_mcp_toolset_auth_headers( }], } mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + mock_session_class.return_value.get.return_value = mock_response + + mock_creds = MagicMock() + mock_creds.quota_project_id = None if use_custom_provider: custom_header_provider = lambda context: { "Authorization": "Bearer custom_token" } with patch( - "google.auth.default", return_value=(MagicMock(), "project-id") + "google.auth.default", return_value=(mock_creds, "project-id") ): registry = AgentRegistry( project_id="test-project", @@ -375,8 +374,8 @@ def test_get_mcp_toolset_auth_headers( else: assert "Authorization" not in headers - @patch("httpx.Client") - def test_get_mcp_toolset_with_auth(self, mock_httpx, registry): + @patch("google.auth.transport.requests.AuthorizedSession") + def test_get_mcp_toolset_with_auth(self, mock_session_class, registry): mock_response = MagicMock() mock_response.json.return_value = { "displayName": "TestPrefix", @@ -386,9 +385,7 @@ def test_get_mcp_toolset_with_auth(self, mock_httpx, registry): }], } mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + mock_session_class.return_value.get.return_value = mock_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -408,9 +405,9 @@ def test_get_mcp_toolset_with_auth(self, mock_httpx, registry): assert auth_config.auth_scheme == auth_scheme assert auth_config.raw_auth_credential == auth_credential - @patch("httpx.Client") + @patch("google.auth.transport.requests.AuthorizedSession") def test_get_mcp_toolset_with_auth_blocks_gcp_headers( - self, mock_httpx, registry + self, mock_session_class, registry ): mock_response = MagicMock() mock_response.json.return_value = { @@ -421,9 +418,7 @@ def test_get_mcp_toolset_with_auth_blocks_gcp_headers( }], } mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + mock_session_class.return_value.get.return_value = mock_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -442,8 +437,8 @@ def test_get_mcp_toolset_with_auth_blocks_gcp_headers( headers = toolset._header_provider(MagicMock()) assert "Authorization" not in headers - @patch("httpx.Client") - def test_get_remote_a2a_agent(self, mock_httpx, registry): + @patch("google.auth.transport.requests.AuthorizedSession") + def test_get_remote_a2a_agent(self, mock_session_class, registry): mock_response = MagicMock() mock_response.json.return_value = { "displayName": "TestAgent", @@ -460,9 +455,7 @@ def test_get_remote_a2a_agent(self, mock_httpx, registry): "skills": [{"id": "s1", "name": "Skill 1", "description": "Desc 1"}], } mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + mock_session_class.return_value.get.return_value = mock_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -478,8 +471,8 @@ def test_get_remote_a2a_agent(self, mock_httpx, registry): assert agent._agent_card.preferred_transport == A2ATransport.http_json assert agent._agent_card.protocol_version == "0.4.0" - @patch("httpx.Client") - def test_get_remote_a2a_agent_defaults(self, mock_httpx, registry): + @patch("google.auth.transport.requests.AuthorizedSession") + def test_get_remote_a2a_agent_defaults(self, mock_session_class, registry): mock_response = MagicMock() mock_response.json.return_value = { "displayName": "TestAgent", @@ -493,9 +486,7 @@ def test_get_remote_a2a_agent_defaults(self, mock_httpx, registry): }], } mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + mock_session_class.return_value.get.return_value = mock_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -505,8 +496,8 @@ def test_get_remote_a2a_agent_defaults(self, mock_httpx, registry): assert agent._agent_card.preferred_transport == A2ATransport.http_json assert agent._agent_card.protocol_version == "0.3.0" - @patch("httpx.Client") - def test_get_remote_a2a_agent_with_card(self, mock_httpx, registry): + @patch("google.auth.transport.requests.AuthorizedSession") + def test_get_remote_a2a_agent_with_card(self, mock_session_class, registry): mock_response = MagicMock() mock_response.json.return_value = { "name": "projects/p/locations/l/agents/a", @@ -530,9 +521,7 @@ def test_get_remote_a2a_agent_with_card(self, mock_httpx, registry): }, } mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + mock_session_class.return_value.get.return_value = mock_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -547,8 +536,10 @@ def test_get_remote_a2a_agent_with_card(self, mock_httpx, registry): assert len(agent._agent_card.skills) == 1 assert agent._agent_card.skills[0].name == "S1" - @patch("httpx.Client") - def test_get_remote_a2a_agent_with_httpx_client(self, mock_httpx, registry): + @patch("google.auth.transport.requests.AuthorizedSession") + def test_get_remote_a2a_agent_with_httpx_client( + self, mock_session_class, registry + ): mock_response = MagicMock() mock_response.json.return_value = { "displayName": "TestAgent", @@ -562,9 +553,7 @@ def test_get_remote_a2a_agent_with_httpx_client(self, mock_httpx, registry): }], } mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + mock_session_class.return_value.get.return_value = mock_response custom_client = httpx.AsyncClient() agent = registry.get_remote_a2a_agent( @@ -572,9 +561,9 @@ def test_get_remote_a2a_agent_with_httpx_client(self, mock_httpx, registry): ) assert agent._httpx_client is custom_client - @patch("httpx.Client") + @patch("google.auth.transport.requests.AuthorizedSession") def test_get_remote_a2a_agent_configures_transports( - self, mock_httpx, registry + self, mock_session_class, registry ): mock_response = MagicMock() mock_response.json.return_value = { @@ -588,9 +577,7 @@ def test_get_remote_a2a_agent_configures_transports( }], } mock_response.raise_for_status = MagicMock() - mock_httpx.return_value.__enter__.return_value.get.return_value = ( - mock_response - ) + mock_session_class.return_value.get.return_value = mock_response registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -616,15 +603,17 @@ def test_get_auth_headers_fallback_to_project_id(self, registry): assert headers["Authorization"] == "Bearer fake-token" assert headers["x-goog-user-project"] == "test-project" - @patch("httpx.Client") - def test_make_request_raises_http_status_error(self, mock_httpx, registry): + @patch("google.auth.transport.requests.AuthorizedSession") + def test_make_request_raises_http_status_error( + self, mock_session_class, registry + ): mock_response = MagicMock() mock_response.status_code = 404 mock_response.text = "Not Found" - error = httpx.HTTPStatusError( + error = requests.exceptions.HTTPError( "Error", request=MagicMock(), response=mock_response ) - mock_httpx.return_value.__enter__.return_value.get.side_effect = error + mock_session_class.return_value.get.side_effect = error registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -634,10 +623,14 @@ def test_make_request_raises_http_status_error(self, mock_httpx, registry): ): registry._make_request("test-path") - @patch("httpx.Client") - def test_make_request_raises_request_error(self, mock_httpx, registry): - error = httpx.RequestError("Connection failed", request=MagicMock()) - mock_httpx.return_value.__enter__.return_value.get.side_effect = error + @patch("google.auth.transport.requests.AuthorizedSession") + def test_make_request_raises_request_error( + self, mock_session_class, registry + ): + error = requests.exceptions.RequestException( + "Connection failed", request=MagicMock() + ) + mock_session_class.return_value.get.side_effect = error registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -647,11 +640,11 @@ def test_make_request_raises_request_error(self, mock_httpx, registry): ): registry._make_request("test-path") - @patch("httpx.Client") - def test_make_request_raises_generic_exception(self, mock_httpx, registry): - mock_httpx.return_value.__enter__.return_value.get.side_effect = Exception( - "Generic error" - ) + @patch("google.auth.transport.requests.AuthorizedSession") + def test_make_request_raises_generic_exception( + self, mock_session_class, registry + ): + mock_session_class.return_value.get.side_effect = Exception("Generic error") registry._credentials.token = "token" registry._credentials.refresh = MagicMock() @@ -741,3 +734,103 @@ def side_effect(*args, **kwargs): == "projects/123/locations/l/authProviders/ap-789" ) assert toolset._auth_scheme.continue_uri == "https://override.com/continue" + + +class TestAgentRegistryMtls: + + @pytest.fixture + def registry(self): + with patch( + "google.auth.default", return_value=(MagicMock(), "test-project") + ): + return AgentRegistry(project_id="test-project", location="global") + + @patch("google.auth.transport.requests.AuthorizedSession") + @patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ) + def test_make_request_uses_authorized_session_no_mtls( + self, mock_has_cert, mock_session_class, registry + ): + """Verifies that AuthorizedSession is used for standard requests.""" + mock_session = mock_session_class.return_value + mock_response = MagicMock() + mock_response.json.return_value = {"key": "value"} + mock_session.get.return_value = mock_response + + result = registry._make_request("test-path") + + # Assert session initialization and usage + mock_session_class.assert_called_once_with( + credentials=registry._credentials + ) + mock_session.get.assert_called_once() + assert mock_session.configure_mtls_channel.call_count == 0 + assert result == {"key": "value"} + + @patch("google.auth.transport.requests.AuthorizedSession") + @patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ) + @patch("google.auth.transport.mtls.default_client_cert_source") + @patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}) + def test_make_request_configures_mtls( + self, mock_cert_source, mock_has_cert, mock_session_class, registry + ): + """Verifies that mTLS is configured when supported and enabled.""" + mock_session = mock_session_class.return_value + mock_cert_source.return_value = lambda: (b"cert", b"key") + + registry._make_request("test-path") + + # Verify mTLS configuration and endpoint + mock_session.configure_mtls_channel.assert_called_once() + args, kwargs = mock_session.get.call_args + assert "agentregistry.mtls.googleapis.com" in args[0] + + @pytest.mark.parametrize( + "env_val, has_cert, expected", + [ + ("true", True, True), + ("true", False, True), + ("false", True, False), + ("false", False, False), + ], + ) + def test_use_client_cert_effective( + self, env_val, has_cert, expected, registry + ): + """Tests the logic for enabling mTLS based on env vars and cert availability.""" + with patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": env_val}): + with patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=has_cert, + ): + from google.adk.integrations.agent_registry.agent_registry import _use_client_cert_effective + + assert _use_client_cert_effective() == expected + + def test_get_agent_registry_base_url(self, registry): + """Verifies correct base URL selection for mTLS vs non-mTLS.""" + from google.adk.integrations.agent_registry.agent_registry import _get_agent_registry_base_url + + # Non-mTLS + assert "agentregistry.googleapis.com" in _get_agent_registry_base_url(None) + + # mTLS + assert "agentregistry.mtls.googleapis.com" in _get_agent_registry_base_url( + lambda: True + ) + + @patch("google.auth.transport.requests.AuthorizedSession") + def test_make_request_error_handling(self, mock_session_class, registry): + """Ensures exceptions from AuthorizedSession are handled gracefully.""" + mock_session = mock_session_class.return_value + mock_session.get.side_effect = Exception("Connection error") + + with pytest.raises( + RuntimeError, match="API request failed: Connection error" + ): + registry._make_request("test-path")