diff --git a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py index ebf8909d52..31a4b726a6 100644 --- a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py +++ b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py @@ -4,6 +4,7 @@ from __future__ import annotations import asyncio +import copy import json import logging import sys @@ -36,6 +37,7 @@ from boto3.session import Session as Boto3Session from botocore.client import BaseClient from botocore.config import Config as BotoConfig +from botocore.exceptions import ClientError from pydantic import BaseModel if sys.version_info >= (3, 13): @@ -115,13 +117,20 @@ class BedrockChatOptions(ChatOptions[ResponseModelT], Generic[ResponseModelT], t translates to ``toolConfig.tools``. tool_choice: How the model should use tools, translates to ``toolConfig.toolChoice``. + response_format: Structured output format. Accepts a Pydantic BaseModel + subclass or an OpenAI-style dict schema + (``{"json_schema": {"name": ..., "schema": ...}}``). + When provided, the Converse API request includes + ``outputConfig.textFormat`` with the schema serialized as a JSON + string. ``ChatResponse.value`` will be populated with the parsed + model instance. Only supported on models that support + ``outputConfig.textFormat``. Unsupported models raise a ValueError. # Options not supported in Bedrock Converse API: seed: Not supported. frequency_penalty: Not supported. presence_penalty: Not supported. allow_multiple_tool_calls: Not supported (models handle parallel calls automatically). - response_format: Not directly supported (use model-specific prompting). user: Not supported. store: Not supported. logit_bias: Not supported. @@ -161,9 +170,6 @@ class BedrockChatOptions(ChatOptions[ResponseModelT], Generic[ResponseModelT], t allow_multiple_tool_calls: None # type: ignore[misc] """Not supported. Bedrock models handle parallel tool calls automatically.""" - response_format: None # type: ignore[misc] - """Not directly supported. Use model-specific prompting for JSON output.""" - user: None # type: ignore[misc] """Not supported in Bedrock Converse API.""" @@ -324,10 +330,28 @@ def _create_session(settings: BedrockSettings) -> Boto3Session: return Boto3Session(**session_kwargs) def _invoke_converse(self, request: Mapping[str, Any]) -> dict[str, Any]: - response = self._bedrock_client.converse(**request) - if not isinstance(response, Mapping): - raise ChatClientInvalidResponseException("Bedrock converse response must be a mapping.") - return response + try: + response = self._bedrock_client.converse(**request) + if not isinstance(response, Mapping): + raise ChatClientInvalidResponseException("Bedrock converse response must be a mapping.") + return response + except ClientError as e: + error_details = e.response.get("Error", {}) + error_code = error_details.get("Code", "") + error_message = error_details.get("Message", "") + # "outputConfig" in error_message catches cases where Bedrock explicitly + # rejects the outputConfig field (unsupported model). Other ValidationExceptions + # (e.g. malformed schema shape, invalid property values) will not mention + # "outputConfig" and will bubble up as raw ClientError without being misdiagnosed. + if error_code == "ValidationException" and ( + "outputconfig" in error_message.lower() or "outputconfig" in str(e).lower() + ): + raise ValueError( + f"Model '{self.model}' does not support structured output via outputConfig.textFormat. " + "Check the model's Bedrock Converse outputConfig/textFormat support. " + f"AWS error Code: {error_code}. AWS error Message: {error_message}" + ) from e + raise @override def _inner_get_response( @@ -344,7 +368,7 @@ def _inner_get_response( # Streaming mode - simulate streaming by yielding a single update async def _stream() -> AsyncIterable[ChatResponseUpdate]: response = await asyncio.to_thread(self._invoke_converse, request) - parsed_response = self._process_converse_response(response) + parsed_response = self._process_converse_response(response, options) contents = list(parsed_response.messages[0].contents if parsed_response.messages else []) if parsed_response.usage_details: contents.append(Content.from_usage(usage_details=parsed_response.usage_details)) # type: ignore[arg-type] @@ -360,12 +384,12 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: raw_representation=parsed_response.raw_representation, ) - return self._build_response_stream(_stream()) + return self._build_response_stream(_stream(), response_format=options.get("response_format")) # Non-streaming mode async def _get_response() -> ChatResponse: raw_response = await asyncio.to_thread(self._invoke_converse, request) - return self._process_converse_response(raw_response) + return self._process_converse_response(raw_response, options) return _get_response() @@ -430,6 +454,9 @@ def _prepare_options( if tool_config: run_options["toolConfig"] = tool_config + if output_config := self._prepare_output_config(options.get("response_format")): + run_options["outputConfig"] = output_config + return run_options def _prepare_bedrock_messages( @@ -628,7 +655,9 @@ def _prepare_tools(self, tools: list[FunctionTool | MutableMapping[str, Any]] | def _generate_tool_call_id() -> str: return f"tool-call-{uuid4().hex}" - def _process_converse_response(self, response: dict[str, Any]) -> ChatResponse: + def _process_converse_response( + self, response: dict[str, Any], options: Mapping[str, Any] | None = None + ) -> ChatResponse: """Convert Bedrock Converse API response to ChatResponse.""" output = response.get("output") or {} message = output.get("message") or {} @@ -646,6 +675,7 @@ def _process_converse_response(self, response: dict[str, Any]) -> ChatResponse: usage_details=usage_details, model=model, finish_reason=finish_reason, + response_format=options.get("response_format") if options else None, raw_representation=response, ) @@ -728,6 +758,108 @@ def _map_finish_reason(self, reason: str | None) -> FinishReasonLiteral | None: return None return FINISH_REASON_MAP.get(reason.lower()) + def _prepare_output_config(self, response_format: Any | None) -> dict[str, Any] | None: + """Convert response_format into the AWS Bedrock outputConfig wire format. + + Args: + response_format: A Pydantic model class or a dict schema, or None. + + Returns: + A dict for the Converse API ``outputConfig`` parameter, or None if + response_format is not set. + """ + if response_format is None: + return None + + if isinstance(response_format, dict): + if "json_schema" in response_format: + # Shape A — OpenAI-style wrapper + json_schema_config = response_format["json_schema"] + schema_src = json_schema_config.get("schema", {}) + name = json_schema_config.get("name", "output_schema") + elif "schema" in response_format: + # Shape B — inner shape directly {"name": ..., "schema": ...} + schema_src = response_format["schema"] + name = response_format.get("name", "output_schema") + else: + # Shape C — assume entire dict is the raw schema + logger.warning( + "response_format dict has no 'json_schema' or 'schema' key; " + "treating entire dict as raw JSON schema." + ) + schema_src = response_format + name = "output_schema" + + if isinstance(schema_src, str): + schema_src = json.loads(schema_src) + schema = copy.deepcopy(schema_src) + else: + if not isinstance(response_format, type) or not issubclass(response_format, BaseModel): + raise TypeError( + "response_format must be None, a dict JSON schema, " + "or a Pydantic BaseModel subclass." + ) + # response_format is a Pydantic model class + schema = response_format.model_json_schema() + name = response_format.__name__ + + self._set_additional_properties_false(schema) + + json_schema: dict[str, Any] = { + "name": name, + "schema": json.dumps(schema), + } + + description = getattr(response_format, "__doc__", None) if not isinstance(response_format, dict) else None + if description and isinstance(description, str) and description.strip(): + json_schema["description"] = description.strip() + + return { + "textFormat": { + "type": "json_schema", + "structure": { + "jsonSchema": json_schema + }, + } + } + + def _set_additional_properties_false(self, schema: dict[str, Any]) -> None: + """Recursively set additionalProperties: false on all object types in a JSON schema. + + AWS requires strict schema enforcement. This mirrors the approach used by + AnthropicChatClient._prepare_response_format(). + + Args: + schema: The JSON schema dict to modify in-place. + """ + visited: set[int] = set() + + def walk(node: Any) -> None: + if isinstance(node, dict): + node_id = id(node) + if node_id in visited: + return + visited.add(node_id) + if node.get("type") == "object" or ( + "properties" in node and "type" not in node + ): + existing = node.get("additionalProperties") + if existing is None or existing is True: + node["additionalProperties"] = False + for value in node.values(): + if isinstance(value, (dict, list)): + walk(value) + elif isinstance(node, list): + node_id = id(node) + if node_id in visited: + return + visited.add(node_id) + for item in node: + if isinstance(item, (dict, list)): + walk(item) + + walk(schema) + def service_url(self) -> str: """Returns the service URL for the Bedrock runtime in the configured AWS region. diff --git a/python/packages/bedrock/tests/test_bedrock_structured_output.py b/python/packages/bedrock/tests/test_bedrock_structured_output.py new file mode 100644 index 0000000000..3b4da63411 --- /dev/null +++ b/python/packages/bedrock/tests/test_bedrock_structured_output.py @@ -0,0 +1,268 @@ +# Copyright (c) Microsoft. All rights reserved. + +from __future__ import annotations + +import json +from typing import Any + +import pytest +from agent_framework import Content, Message +from botocore.exceptions import ClientError +from pydantic import BaseModel + +from agent_framework_bedrock import BedrockChatClient + +# region Test models + + +class WeatherReport(BaseModel): + city: str + temperature: float + summary: str + + +class NestedAddress(BaseModel): + street: str + city: str + zip_code: str + + +class Person(BaseModel): + name: str + age: int + address: NestedAddress + + +# endregion + + +# region Helpers + + +class _StubBedrockRuntime: + """Stub that records calls and returns a canned response.""" + + def __init__(self, response_text: str = "Bedrock says hi") -> None: + self.calls: list[dict[str, Any]] = [] + self._response_text = response_text + + def converse(self, **kwargs: Any) -> dict[str, Any]: + self.calls.append(kwargs) + return { + "modelId": kwargs["modelId"], + "responseId": "resp-structured", + "usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}, + "output": { + "completionReason": "end_turn", + "message": { + "id": "msg-structured", + "role": "assistant", + "content": [{"text": self._response_text}], + }, + }, + } + + +def _make_client(response_text: str = "Bedrock says hi") -> tuple[BedrockChatClient, _StubBedrockRuntime]: + stub = _StubBedrockRuntime(response_text) + client = BedrockChatClient( + model="us.anthropic.claude-haiku-4-5-v1:0", + region="us-east-1", + client=stub, + ) + return client, stub + + +def _user_messages() -> list[Message]: + return [Message(role="user", contents=[Content.from_text(text="Give me a weather report")])] + + +# endregion + + +# region Tests + + +def test_prepare_output_config_correct_wire_shape() -> None: + """_prepare_output_config(WeatherReport) must produce the correct + textFormat → structure → jsonSchema shape with type: 'json_schema'.""" + client, _ = _make_client() + + output_config = client._prepare_output_config(WeatherReport) + + assert output_config is not None + text_format = output_config["textFormat"] + assert text_format["type"] == "json_schema" + assert "structure" in text_format + json_schema = text_format["structure"]["jsonSchema"] + assert json_schema["name"] == "WeatherReport" + assert "schema" in json_schema + + +def test_prepare_output_config_schema_is_json_string() -> None: + """The schema value inside jsonSchema must be a JSON string, not a dict.""" + client, _ = _make_client() + + output_config = client._prepare_output_config(WeatherReport) + + assert output_config is not None + schema_value = output_config["textFormat"]["structure"]["jsonSchema"]["schema"] + assert isinstance(schema_value, str), f"Expected str, got {type(schema_value)}" + # Verify it's valid JSON + parsed = json.loads(schema_value) + assert isinstance(parsed, dict) + assert parsed["type"] == "object" + + +def test_additional_properties_false_set_recursively() -> None: + """additionalProperties: false must be set on all nested object types.""" + client, _ = _make_client() + + output_config = client._prepare_output_config(Person) + + assert output_config is not None + schema_str = output_config["textFormat"]["structure"]["jsonSchema"]["schema"] + schema = json.loads(schema_str) + + # Top-level object + assert schema.get("additionalProperties") is False + + # Check $defs for NestedAddress + defs = schema.get("$defs", {}) + assert "NestedAddress" in defs, "Expected NestedAddress to be present in $defs" + assert defs["NestedAddress"].get("additionalProperties") is False, ( + "Expected additionalProperties=False on nested NestedAddress schema" + ) + + +def test_no_output_config_when_response_format_none() -> None: + """When response_format is None, no outputConfig key should appear in the request.""" + client, stub = _make_client() + messages = _user_messages() + + request = client._prepare_options(messages, {"max_tokens": 100}) + + assert "outputConfig" not in request, ( + f"outputConfig should not be present when response_format is None, got: {request.get('outputConfig')}" + ) + + +async def test_chat_response_value_populated() -> None: + """After a mocked response with response_format, .value should be a populated Pydantic model.""" + json_response = json.dumps({"city": "Seattle", "temperature": 72.5, "summary": "Sunny and warm"}) + client, stub = _make_client(response_text=json_response) + messages = _user_messages() + + response = await client.get_response( + messages=messages, + options={"max_tokens": 100, "response_format": WeatherReport}, + ) + + assert response.text == json_response + assert response.value is not None + assert isinstance(response.value, WeatherReport) + assert response.value.city == "Seattle" + assert response.value.temperature == 72.5 + assert response.value.summary == "Sunny and warm" + + # Verify outputConfig was sent to the API + assert len(stub.calls) == 1 + api_request = stub.calls[0] + assert "outputConfig" in api_request + assert api_request["outputConfig"]["textFormat"]["type"] == "json_schema" + + +def test_dict_schema_response_format() -> None: + """_prepare_output_config should work when response_format is a dict, not just a Pydantic class.""" + client, _ = _make_client() + + dict_schema = { + "json_schema": { + "name": "weather_output", + "schema": { + "type": "object", + "properties": { + "city": {"type": "string"}, + "temp": {"type": "number"}, + }, + }, + } + } + + output_config = client._prepare_output_config(dict_schema) + + assert output_config is not None + json_schema = output_config["textFormat"]["structure"]["jsonSchema"] + assert json_schema["name"] == "weather_output" + schema_parsed = json.loads(json_schema["schema"]) + assert schema_parsed["type"] == "object" + assert "city" in schema_parsed["properties"] + + +def test_prepare_output_config_none_returns_none() -> None: + """_prepare_output_config(None) must return None.""" + client, _ = _make_client() + + result = client._prepare_output_config(None) + + assert result is None + + +async def test_chat_response_value_populated_streaming() -> None: + """In streaming mode, .value should also be populated on the final response.""" + json_response = json.dumps({"city": "Portland", "temperature": 68.0, "summary": "Cloudy"}) + client, stub = _make_client(response_text=json_response) + messages = _user_messages() + + stream = client.get_response( + messages=messages, + stream=True, + options={"max_tokens": 100, "response_format": WeatherReport}, + ) + + # Consume stream and get final response + async for _ in stream: + pass + response = await stream.get_final_response() + + assert response.value is not None + assert isinstance(response.value, WeatherReport) + assert response.value.city == "Portland" + + # Verify outputConfig was sent + assert len(stub.calls) == 1 + assert "outputConfig" in stub.calls[0] + + +async def test_unsupported_model_validation_exception() -> None: + """When a model doesn't support outputConfig, a clear error should be raised.""" + class _FailingStubBedrockRuntime: + def converse(self, **kwargs: Any) -> dict[str, Any]: + # Simulate botocore ClientError for ValidationException + error_response = {"Error": {"Code": "ValidationException", "Message": "Invalid field outputConfig"}} + raise ClientError(error_response, "Converse") + + client = BedrockChatClient( + model="us.anthropic.claude-v2", + region="us-east-1", + client=_FailingStubBedrockRuntime(), + ) + + with pytest.raises(ValueError) as exc: + await client.get_response( + messages=_user_messages(), + options={"response_format": WeatherReport}, + ) + + assert "does not support structured output via outputConfig.textFormat" in str(exc.value) + assert "Check the model's Bedrock Converse outputConfig/textFormat support." in str(exc.value) + + +def test_invalid_response_format_type_raises() -> None: + """Non-dict, non-BaseModel response_format should raise TypeError.""" + client, _ = _make_client() + with pytest.raises(TypeError, match="Pydantic BaseModel subclass"): + client._prepare_output_config("not_a_valid_format") + + +# endregion