From 22c09ed7acb788d4d88d65fe1d93c0977ebfeb8e Mon Sep 17 00:00:00 2001 From: Nicolas Borges Date: Mon, 13 Apr 2026 15:54:28 -0400 Subject: [PATCH] feat: add runtime passthrough support --- .../runtime/agent_core_runtime_client.py | 226 +++++++++++++- .../runtime/test_agent_core_runtime_client.py | 295 ++++++++++++++++++ 2 files changed, 512 insertions(+), 9 deletions(-) diff --git a/src/bedrock_agentcore/runtime/agent_core_runtime_client.py b/src/bedrock_agentcore/runtime/agent_core_runtime_client.py index e905119f..de7b90b8 100644 --- a/src/bedrock_agentcore/runtime/agent_core_runtime_client.py +++ b/src/bedrock_agentcore/runtime/agent_core_runtime_client.py @@ -8,15 +8,20 @@ import datetime import logging import secrets +import time import uuid -from typing import Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple from urllib.parse import quote, urlencode, urlparse import boto3 from botocore.auth import SigV4Auth, SigV4QueryAuth from botocore.awsrequest import AWSRequest +from botocore.config import Config +from botocore.exceptions import ClientError from .._utils.endpoints import get_data_plane_endpoint +from .._utils.snake_case import accept_snake_case_kwargs +from .._utils.user_agent import build_user_agent_suffix from .utils import is_valid_partition DEFAULT_PRESIGNED_URL_TIMEOUT = 300 @@ -35,21 +40,224 @@ class AgentCoreRuntimeClient: session (boto3.Session): The boto3 session for AWS credentials. """ - def __init__(self, region: str, session: Optional[boto3.Session] = None) -> None: + _ALLOWED_DP_METHODS = { + "invoke_agent_runtime", + "stop_runtime_session", + } + + _ALLOWED_CP_METHODS = { + "create_agent_runtime", + "update_agent_runtime", + "get_agent_runtime", + "get_agent_runtime_endpoint", + "delete_agent_runtime", + "delete_agent_runtime_endpoint", + } + + def __init__( + self, + region: Optional[str] = None, + session: Optional[boto3.Session] = None, + integration_source: Optional[str] = None, + ) -> None: """Initialize an AgentCoreRuntime client for the specified AWS region. Args: - region (str): The AWS region to use for the AgentCore Runtime service. - session (Optional[boto3.Session]): Optional boto3 session. If not provided, - a new session will be created using default credentials. + region: AWS region name. If not provided, uses the session's region or "us-west-2". + session: Optional boto3 Session to use. If not provided, a default session is created + integration_source: Optional integration source for user-agent telemetry. """ - self.region = region + session = session if session else boto3.Session() + self.region = region or session.region_name or "us-west-2" + self.session = session self.logger = logging.getLogger(__name__) + self.integration_source = integration_source - if session is None: - session = boto3.Session() + user_agent_extra = build_user_agent_suffix(integration_source) + client_config = Config(user_agent_extra=user_agent_extra) - self.session = session + self.cp_client = session.client("bedrock-agentcore-control", region_name=self.region, config=client_config) + self.dp_client = session.client("bedrock-agentcore", region_name=self.region, config=client_config) + self.logger.info( + "Initialized AgentCoreRuntimeClient for control plane: %s, data plane: %s", + self.cp_client.meta.region_name, + self.dp_client.meta.region_name, + ) + + def __getattr__(self, name: str): + """Dynamically forward allowlisted method calls to the appropriate boto3 client. + + Methods are looked up in the following order: + 1. dp_client (bedrock-agentcore) - for data plane operations + 2. cp_client (bedrock-agentcore-control) - for control plane operations + """ + if name in self._ALLOWED_DP_METHODS and hasattr(self.dp_client, name): + method = getattr(self.dp_client, name) + self.logger.debug("Forwarding method '%s' to dp_client", name) + return accept_snake_case_kwargs(method) + + if name in self._ALLOWED_CP_METHODS and hasattr(self.cp_client, name): + method = getattr(self.cp_client, name) + self.logger.debug("Forwarding method '%s' to cp_client", name) + return accept_snake_case_kwargs(method) + + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'. " + f"Method not found on dp_client or cp_client. " + f"Available methods can be found in the boto3 documentation for " + f"'bedrock-agentcore' and 'bedrock-agentcore-control' services." + ) + + def wait_for_runtime_ready( + self, + agent_runtime_id: str, + max_wait: int = 300, + poll_interval: int = 5, + ) -> Dict[str, Any]: + """Wait for an agent runtime to reach READY status. + + Args: + agent_runtime_id: The agent runtime ID. + max_wait: Maximum seconds to wait (default: 300). + poll_interval: Seconds between status checks (default: 5). + + Returns: + Runtime response dict when READY. + + Raises: + RuntimeError: If the runtime reaches a failed state. + TimeoutError: If the runtime doesn't become READY within max_wait. + """ + start_time = time.time() + while time.time() - start_time < max_wait: + try: + resp = self.cp_client.get_agent_runtime(agentRuntimeId=agent_runtime_id) + status = resp.get("status", "UNKNOWN") + + if status == "READY": + self.logger.info("Runtime %s is READY", agent_runtime_id) + return resp + elif status in ("CREATE_FAILED", "UPDATE_FAILED", "FAILED"): + raise RuntimeError( + "Runtime %s: %s" % (status.lower().replace("_", " "), resp.get("failureReason", "Unknown")) + ) + except ClientError as e: + if e.response["Error"]["Code"] != "ResourceNotFoundException": + raise + time.sleep(poll_interval) + + raise TimeoutError("Runtime %s did not become READY within %d seconds" % (agent_runtime_id, max_wait)) + + def wait_for_endpoint_ready( + self, + agent_runtime_id: str, + endpoint_name: str = "DEFAULT", + max_wait: int = 120, + poll_interval: int = 1, + ) -> Dict[str, Any]: + """Wait for an agent runtime endpoint to reach READY status. + + Args: + agent_runtime_id: The agent runtime ID. + endpoint_name: Endpoint name (default: "DEFAULT"). + max_wait: Maximum seconds to wait (default: 120). + poll_interval: Seconds between status checks (default: 1). + + Returns: + Endpoint response dict when READY. + + Raises: + RuntimeError: If the endpoint reaches a failed state. + TimeoutError: If the endpoint doesn't become READY within max_wait. + """ + start_time = time.time() + while time.time() - start_time < max_wait: + try: + resp = self.cp_client.get_agent_runtime_endpoint( + agentRuntimeId=agent_runtime_id, + endpointName=endpoint_name, + ) + status = resp.get("status", "UNKNOWN") + + if status == "READY": + self.logger.info("Endpoint '%s' is READY", endpoint_name) + return resp + elif status in ("CREATE_FAILED", "UPDATE_FAILED"): + raise RuntimeError( + "Endpoint %s: %s" % (status.lower().replace("_", " "), resp.get("failureReason", "Unknown")) + ) + except ClientError as e: + if e.response["Error"]["Code"] != "ResourceNotFoundException": + raise + time.sleep(poll_interval) + + raise TimeoutError("Endpoint '%s' did not become READY within %d seconds" % (endpoint_name, max_wait)) + + def get_aggregated_status( + self, + agent_runtime_id: str, + endpoint_name: str = "DEFAULT", + ) -> Dict[str, Any]: + """Get aggregated status of runtime and endpoint. + + Args: + agent_runtime_id: The agent runtime ID. + endpoint_name: Endpoint name (default: "DEFAULT"). + + Returns: + Dict with 'runtime' and 'endpoint' status details. + """ + result: Dict[str, Any] = {"runtime": None, "endpoint": None} + + try: + result["runtime"] = self.cp_client.get_agent_runtime(agentRuntimeId=agent_runtime_id) + except ClientError as e: + result["runtime"] = {"error": str(e)} + + try: + result["endpoint"] = self.cp_client.get_agent_runtime_endpoint( + agentRuntimeId=agent_runtime_id, + endpointName=endpoint_name, + ) + except ClientError as e: + result["endpoint"] = {"error": str(e)} + + return result + + def teardown( + self, + agent_runtime_id: str, + endpoint_name: str = "DEFAULT", + ) -> None: + """Delete endpoint then runtime in correct order. + + Silently ignores ResourceNotFoundException for either resource + (already deleted). + + Args: + agent_runtime_id: The agent runtime ID. + endpoint_name: Endpoint name (default: "DEFAULT"). + """ + # Delete endpoint + try: + self.cp_client.delete_agent_runtime_endpoint( + agentRuntimeId=agent_runtime_id, + endpointName=endpoint_name, + ) + self.logger.info("Deleted endpoint '%s' for runtime %s", endpoint_name, agent_runtime_id) + except ClientError as e: + if e.response["Error"]["Code"] != "ResourceNotFoundException": + raise + self.logger.info("Endpoint '%s' not found, skipping", endpoint_name) + + # Delete runtime + try: + self.cp_client.delete_agent_runtime(agentRuntimeId=agent_runtime_id) + self.logger.info("Deleted runtime %s", agent_runtime_id) + except ClientError as e: + if e.response["Error"]["Code"] != "ResourceNotFoundException": + raise + self.logger.info("Runtime %s not found, skipping", agent_runtime_id) def _parse_runtime_arn(self, runtime_arn: str) -> Dict[str, str]: """Parse runtime ARN and extract components. diff --git a/tests/unit/runtime/test_agent_core_runtime_client.py b/tests/unit/runtime/test_agent_core_runtime_client.py index 4cb33085..e924b0ea 100644 --- a/tests/unit/runtime/test_agent_core_runtime_client.py +++ b/tests/unit/runtime/test_agent_core_runtime_client.py @@ -4,6 +4,7 @@ from urllib.parse import quote import pytest +from botocore.exceptions import ClientError from bedrock_agentcore.runtime.agent_core_runtime_client import AgentCoreRuntimeClient @@ -21,6 +22,300 @@ def test_init_creates_logger(self): client = AgentCoreRuntimeClient(region="us-west-2") assert client.logger is not None + def test_init_default_region(self): + """Test that region defaults to session region or us-west-2.""" + mock_session = Mock() + mock_session.region_name = "eu-west-1" + client = AgentCoreRuntimeClient(session=mock_session) + assert client.region == "eu-west-1" + + def test_init_default_region_fallback(self): + """Test that region falls back to us-west-2 when session has no region.""" + mock_session = Mock() + mock_session.region_name = None + client = AgentCoreRuntimeClient(session=mock_session) + assert client.region == "us-west-2" + + def test_init_creates_boto3_clients(self): + """Test that initialization creates cp_client and dp_client.""" + mock_session = Mock() + mock_session.region_name = "us-west-2" + AgentCoreRuntimeClient(region="us-west-2", session=mock_session) + + assert mock_session.client.call_count == 2 + call_args = [call[0][0] for call in mock_session.client.call_args_list] + assert "bedrock-agentcore-control" in call_args + assert "bedrock-agentcore" in call_args + + def test_init_with_integration_source(self): + """Test that integration_source is stored and passed to config.""" + mock_session = Mock() + mock_session.region_name = "us-west-2" + client = AgentCoreRuntimeClient(region="us-west-2", session=mock_session, integration_source="langchain") + assert client.integration_source == "langchain" + + +class TestAgentCoreRuntimeClientPassthrough: + """Tests for __getattr__ passthrough to boto3 clients.""" + + def _make_client(self): + mock_session = Mock() + mock_session.region_name = "us-west-2" + client = AgentCoreRuntimeClient(region="us-west-2", session=mock_session) + client.cp_client = Mock() + client.dp_client = Mock() + return client + + def test_cp_method_forwarded(self): + """Test that allowlisted CP methods forward to cp_client.""" + client = self._make_client() + client.cp_client.get_agent_runtime.return_value = {"agentRuntimeId": "rt-123"} + + result = client.get_agent_runtime(agentRuntimeId="rt-123") + + client.cp_client.get_agent_runtime.assert_called_once_with(agentRuntimeId="rt-123") + assert result["agentRuntimeId"] == "rt-123" + + def test_dp_method_forwarded(self): + """Test that allowlisted DP methods forward to dp_client.""" + client = self._make_client() + client.dp_client.invoke_agent_runtime.return_value = {"response": "ok"} + + result = client.invoke_agent_runtime(agentRuntimeArn="arn:test") + + client.dp_client.invoke_agent_runtime.assert_called_once_with(agentRuntimeArn="arn:test") + assert result["response"] == "ok" + + def test_snake_case_kwargs_converted(self): + """Test that snake_case kwargs are converted to camelCase.""" + client = self._make_client() + client.cp_client.get_agent_runtime.return_value = {"agentRuntimeId": "rt-123"} + + client.get_agent_runtime(agent_runtime_id="rt-123") + + client.cp_client.get_agent_runtime.assert_called_once_with(agentRuntimeId="rt-123") + + def test_non_allowlisted_method_raises_attribute_error(self): + """Test that non-allowlisted methods raise AttributeError.""" + client = self._make_client() + + with pytest.raises(AttributeError, match="has no attribute 'not_a_real_method'"): + client.not_a_real_method() + + def test_all_cp_methods_in_allowlist(self): + """Test all expected CP methods are in the allowlist.""" + expected = { + "create_agent_runtime", + "update_agent_runtime", + "get_agent_runtime", + "get_agent_runtime_endpoint", + "delete_agent_runtime", + "delete_agent_runtime_endpoint", + } + assert expected == AgentCoreRuntimeClient._ALLOWED_CP_METHODS + + def test_all_dp_methods_in_allowlist(self): + """Test all expected DP methods are in the allowlist.""" + expected = { + "invoke_agent_runtime", + "stop_runtime_session", + } + assert expected == AgentCoreRuntimeClient._ALLOWED_DP_METHODS + + +class TestAgentCoreRuntimeClientHighLevel: + """Tests for higher-level abstraction methods.""" + + def _make_client(self): + mock_session = Mock() + mock_session.region_name = "us-west-2" + client = AgentCoreRuntimeClient(region="us-west-2", session=mock_session) + client.cp_client = Mock() + client.dp_client = Mock() + return client + + def test_wait_for_endpoint_ready_immediate(self): + """Test wait_for_endpoint_ready when endpoint is already READY.""" + client = self._make_client() + client.cp_client.get_agent_runtime_endpoint.return_value = { + "status": "READY", + "agentRuntimeEndpointArn": "arn:test", + } + + result = client.wait_for_endpoint_ready("rt-123") + + assert result["status"] == "READY" + client.cp_client.get_agent_runtime_endpoint.assert_called_once_with( + agentRuntimeId="rt-123", endpointName="DEFAULT" + ) + + def test_wait_for_runtime_ready_immediate(self): + """Test wait_for_runtime_ready when runtime is already READY.""" + client = self._make_client() + client.cp_client.get_agent_runtime.return_value = { + "status": "READY", + "agentRuntimeId": "rt-123", + } + + result = client.wait_for_runtime_ready("rt-123") + + assert result["status"] == "READY" + client.cp_client.get_agent_runtime.assert_called_once_with(agentRuntimeId="rt-123") + + @patch("time.sleep") + @patch("time.time", side_effect=[0, 0, 0, 1, 1]) + def test_wait_for_runtime_ready_after_creating(self, _mock_time, _mock_sleep): + """Test wait_for_runtime_ready polls through CREATING status.""" + client = self._make_client() + client.cp_client.get_agent_runtime.side_effect = [ + {"status": "CREATING"}, + {"status": "READY", "agentRuntimeId": "rt-123"}, + ] + + result = client.wait_for_runtime_ready("rt-123") + + assert result["status"] == "READY" + assert client.cp_client.get_agent_runtime.call_count == 2 + + def test_wait_for_runtime_ready_create_failed(self): + """Test wait_for_runtime_ready raises on CREATE_FAILED.""" + client = self._make_client() + client.cp_client.get_agent_runtime.return_value = { + "status": "CREATE_FAILED", + "failureReason": "Bad config", + } + + with pytest.raises(RuntimeError, match="Bad config"): + client.wait_for_runtime_ready("rt-123") + + @patch("time.sleep") + @patch("time.time", side_effect=[0, 0, 0, 301]) + def test_wait_for_runtime_ready_timeout(self, _mock_time, _mock_sleep): + """Test wait_for_runtime_ready raises TimeoutError.""" + client = self._make_client() + client.cp_client.get_agent_runtime.return_value = {"status": "CREATING"} + + with pytest.raises(TimeoutError, match="did not become READY"): + client.wait_for_runtime_ready("rt-123", max_wait=300) + + @patch("time.sleep") + @patch("time.time", side_effect=[0, 0, 0, 1, 1]) + def test_wait_for_endpoint_ready_after_creating(self, _mock_time, _mock_sleep): + """Test wait_for_endpoint_ready polls through CREATING status.""" + client = self._make_client() + client.cp_client.get_agent_runtime_endpoint.side_effect = [ + {"status": "CREATING"}, + {"status": "READY", "agentRuntimeEndpointArn": "arn:test"}, + ] + + result = client.wait_for_endpoint_ready("rt-123") + + assert result["status"] == "READY" + assert client.cp_client.get_agent_runtime_endpoint.call_count == 2 + + def test_wait_for_endpoint_ready_create_failed(self): + """Test wait_for_endpoint_ready raises on CREATE_FAILED.""" + client = self._make_client() + client.cp_client.get_agent_runtime_endpoint.return_value = { + "status": "CREATE_FAILED", + "failureReason": "Bad config", + } + + with pytest.raises(RuntimeError, match="Bad config"): + client.wait_for_endpoint_ready("rt-123") + + @patch("time.sleep") + @patch("time.time", side_effect=[0, 0, 0, 121]) + def test_wait_for_endpoint_ready_timeout(self, _mock_time, _mock_sleep): + """Test wait_for_endpoint_ready raises TimeoutError.""" + client = self._make_client() + client.cp_client.get_agent_runtime_endpoint.return_value = {"status": "CREATING"} + + with pytest.raises(TimeoutError, match="did not become READY"): + client.wait_for_endpoint_ready("rt-123", max_wait=120) + + @patch("time.sleep") + @patch("time.time", side_effect=[0, 0, 0, 1, 1]) + def test_wait_for_endpoint_ready_not_found_then_ready(self, _mock_time, _mock_sleep): + """Test wait_for_endpoint_ready handles ResourceNotFoundException during polling.""" + client = self._make_client() + not_found = ClientError( + {"Error": {"Code": "ResourceNotFoundException", "Message": "Not found"}}, + "GetAgentRuntimeEndpoint", + ) + client.cp_client.get_agent_runtime_endpoint.side_effect = [ + not_found, + {"status": "READY", "agentRuntimeEndpointArn": "arn:test"}, + ] + + result = client.wait_for_endpoint_ready("rt-123") + + assert result["status"] == "READY" + + def test_get_aggregated_status_success(self): + """Test get_aggregated_status returns both runtime and endpoint.""" + client = self._make_client() + client.cp_client.get_agent_runtime.return_value = {"status": "ACTIVE"} + client.cp_client.get_agent_runtime_endpoint.return_value = {"status": "READY"} + + result = client.get_aggregated_status("rt-123") + + assert result["runtime"]["status"] == "ACTIVE" + assert result["endpoint"]["status"] == "READY" + + def test_get_aggregated_status_partial_failure(self): + """Test get_aggregated_status captures errors without raising.""" + client = self._make_client() + client.cp_client.get_agent_runtime.return_value = {"status": "ACTIVE"} + client.cp_client.get_agent_runtime_endpoint.side_effect = ClientError( + {"Error": {"Code": "ResourceNotFoundException", "Message": "Not found"}}, + "GetAgentRuntimeEndpoint", + ) + + result = client.get_aggregated_status("rt-123") + + assert result["runtime"]["status"] == "ACTIVE" + assert "error" in result["endpoint"] + + def test_teardown_deletes_endpoint_then_runtime(self): + """Test teardown deletes in correct order.""" + client = self._make_client() + + client.teardown("rt-123") + + # Verify order: endpoint first, then runtime + calls = client.cp_client.method_calls + assert calls[0] == ( + "delete_agent_runtime_endpoint", + (), + {"agentRuntimeId": "rt-123", "endpointName": "DEFAULT"}, + ) + assert calls[1] == ("delete_agent_runtime", (), {"agentRuntimeId": "rt-123"}) + + def test_teardown_endpoint_not_found_continues(self): + """Test teardown continues if endpoint already deleted.""" + client = self._make_client() + client.cp_client.delete_agent_runtime_endpoint.side_effect = ClientError( + {"Error": {"Code": "ResourceNotFoundException", "Message": "Not found"}}, + "DeleteAgentRuntimeEndpoint", + ) + + client.teardown("rt-123") + + # Should still delete runtime + client.cp_client.delete_agent_runtime.assert_called_once_with(agentRuntimeId="rt-123") + + def test_teardown_runtime_not_found_continues(self): + """Test teardown doesn't raise if runtime already deleted.""" + client = self._make_client() + client.cp_client.delete_agent_runtime.side_effect = ClientError( + {"Error": {"Code": "ResourceNotFoundException", "Message": "Not found"}}, + "DeleteAgentRuntime", + ) + + # Should not raise + client.teardown("rt-123") + class TestParseRuntimeArn: """Tests for _parse_runtime_arn helper."""