Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 144 additions & 12 deletions python/packages/bedrock/agent_framework_bedrock/_chat_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

import asyncio
import copy
import json
import logging
import sys
Expand Down Expand Up @@ -36,6 +37,7 @@
from boto3.session import Session as Boto3Session
from botocore.client import BaseClient
from botocore.config import Config as BotoConfig
from botocore.exceptions import ClientError
from pydantic import BaseModel

if sys.version_info >= (3, 13):
Expand Down Expand Up @@ -115,13 +117,20 @@ class BedrockChatOptions(ChatOptions[ResponseModelT], Generic[ResponseModelT], t
translates to ``toolConfig.tools``.
tool_choice: How the model should use tools,
translates to ``toolConfig.toolChoice``.
response_format: Structured output format. Accepts a Pydantic BaseModel
subclass or an OpenAI-style dict schema
(``{"json_schema": {"name": ..., "schema": ...}}``).
When provided, the Converse API request includes
``outputConfig.textFormat`` with the schema serialized as a JSON
string. ``ChatResponse.value`` will be populated with the parsed
model instance. Only supported on models that support
``outputConfig.textFormat``. Unsupported models raise a ValueError.

# Options not supported in Bedrock Converse API:
seed: Not supported.
frequency_penalty: Not supported.
presence_penalty: Not supported.
allow_multiple_tool_calls: Not supported (models handle parallel calls automatically).
response_format: Not directly supported (use model-specific prompting).
user: Not supported.
store: Not supported.
logit_bias: Not supported.
Expand Down Expand Up @@ -161,9 +170,6 @@ class BedrockChatOptions(ChatOptions[ResponseModelT], Generic[ResponseModelT], t
allow_multiple_tool_calls: None # type: ignore[misc]
"""Not supported. Bedrock models handle parallel tool calls automatically."""

response_format: None # type: ignore[misc]
"""Not directly supported. Use model-specific prompting for JSON output."""

user: None # type: ignore[misc]
"""Not supported in Bedrock Converse API."""

Expand Down Expand Up @@ -324,10 +330,28 @@ def _create_session(settings: BedrockSettings) -> Boto3Session:
return Boto3Session(**session_kwargs)

def _invoke_converse(self, request: Mapping[str, Any]) -> dict[str, Any]:
response = self._bedrock_client.converse(**request)
if not isinstance(response, Mapping):
raise ChatClientInvalidResponseException("Bedrock converse response must be a mapping.")
return response
try:
response = self._bedrock_client.converse(**request)
if not isinstance(response, Mapping):
raise ChatClientInvalidResponseException("Bedrock converse response must be a mapping.")
return response
except ClientError as e:
error_details = e.response.get("Error", {})
error_code = error_details.get("Code", "")
error_message = error_details.get("Message", "")
# "outputConfig" in error_message catches cases where Bedrock explicitly
# rejects the outputConfig field (unsupported model). Other ValidationExceptions
# (e.g. malformed schema shape, invalid property values) will not mention
# "outputConfig" and will bubble up as raw ClientError without being misdiagnosed.
if error_code == "ValidationException" and (
"outputconfig" in error_message.lower() or "outputconfig" in str(e).lower()
):
raise ValueError(
f"Model '{self.model}' does not support structured output via outputConfig.textFormat. "
"Check the model's Bedrock Converse outputConfig/textFormat support. "
f"AWS error Code: {error_code}. AWS error Message: {error_message}"
) from e
Comment on lines +342 to +353
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checked the existing MAF exception hierarchy — there is no UnsupportedFeature-style exception in the codebase. Two options: use the existing ChatClientInvalidRequestException which semantically fits ("the model rejected this request configuration"), or keep ValueError since it's standard Python for bad argument values and is consistent with how other validation errors are surfaced across MAF. Flagging for human reviewer input before making this call — happy to go either direction.

raise
Comment thread
karthik-0306 marked this conversation as resolved.
Comment thread
karthik-0306 marked this conversation as resolved.

@override
def _inner_get_response(
Expand All @@ -344,7 +368,7 @@ def _inner_get_response(
# Streaming mode - simulate streaming by yielding a single update
async def _stream() -> AsyncIterable[ChatResponseUpdate]:
response = await asyncio.to_thread(self._invoke_converse, request)
parsed_response = self._process_converse_response(response)
parsed_response = self._process_converse_response(response, options)
contents = list(parsed_response.messages[0].contents if parsed_response.messages else [])
if parsed_response.usage_details:
contents.append(Content.from_usage(usage_details=parsed_response.usage_details)) # type: ignore[arg-type]
Expand All @@ -360,12 +384,12 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]:
raw_representation=parsed_response.raw_representation,
)

return self._build_response_stream(_stream())
return self._build_response_stream(_stream(), response_format=options.get("response_format"))

# Non-streaming mode
async def _get_response() -> ChatResponse:
raw_response = await asyncio.to_thread(self._invoke_converse, request)
return self._process_converse_response(raw_response)
return self._process_converse_response(raw_response, options)

return _get_response()

Expand Down Expand Up @@ -430,6 +454,9 @@ def _prepare_options(
if tool_config:
run_options["toolConfig"] = tool_config

if output_config := self._prepare_output_config(options.get("response_format")):
run_options["outputConfig"] = output_config
Comment thread
karthik-0306 marked this conversation as resolved.

return run_options

def _prepare_bedrock_messages(
Expand Down Expand Up @@ -628,7 +655,9 @@ def _prepare_tools(self, tools: list[FunctionTool | MutableMapping[str, Any]] |
def _generate_tool_call_id() -> str:
return f"tool-call-{uuid4().hex}"

def _process_converse_response(self, response: dict[str, Any]) -> ChatResponse:
def _process_converse_response(
self, response: dict[str, Any], options: Mapping[str, Any] | None = None
) -> ChatResponse:
"""Convert Bedrock Converse API response to ChatResponse."""
output = response.get("output") or {}
message = output.get("message") or {}
Expand All @@ -646,6 +675,7 @@ def _process_converse_response(self, response: dict[str, Any]) -> ChatResponse:
usage_details=usage_details,
model=model,
finish_reason=finish_reason,
response_format=options.get("response_format") if options else None,
raw_representation=response,
)

Expand Down Expand Up @@ -728,6 +758,108 @@ def _map_finish_reason(self, reason: str | None) -> FinishReasonLiteral | None:
return None
return FINISH_REASON_MAP.get(reason.lower())

def _prepare_output_config(self, response_format: Any | None) -> dict[str, Any] | None:
"""Convert response_format into the AWS Bedrock outputConfig wire format.

Args:
response_format: A Pydantic model class or a dict schema, or None.

Returns:
A dict for the Converse API ``outputConfig`` parameter, or None if
response_format is not set.
"""
if response_format is None:
return None

if isinstance(response_format, dict):
if "json_schema" in response_format:
# Shape A — OpenAI-style wrapper
json_schema_config = response_format["json_schema"]
schema_src = json_schema_config.get("schema", {})
name = json_schema_config.get("name", "output_schema")
elif "schema" in response_format:
# Shape B — inner shape directly {"name": ..., "schema": ...}
schema_src = response_format["schema"]
name = response_format.get("name", "output_schema")
else:
# Shape C — assume entire dict is the raw schema
logger.warning(
"response_format dict has no 'json_schema' or 'schema' key; "
"treating entire dict as raw JSON schema."
)
schema_src = response_format
name = "output_schema"

if isinstance(schema_src, str):
schema_src = json.loads(schema_src)
schema = copy.deepcopy(schema_src)
else:
Comment thread
karthik-0306 marked this conversation as resolved.
if not isinstance(response_format, type) or not issubclass(response_format, BaseModel):
raise TypeError(
"response_format must be None, a dict JSON schema, "
"or a Pydantic BaseModel subclass."
)
# response_format is a Pydantic model class
schema = response_format.model_json_schema()
name = response_format.__name__

self._set_additional_properties_false(schema)
Comment thread
karthik-0306 marked this conversation as resolved.

json_schema: dict[str, Any] = {
"name": name,
"schema": json.dumps(schema),
}

description = getattr(response_format, "__doc__", None) if not isinstance(response_format, dict) else None
if description and isinstance(description, str) and description.strip():
json_schema["description"] = description.strip()

return {
"textFormat": {
"type": "json_schema",
"structure": {
"jsonSchema": json_schema
},
}
}

def _set_additional_properties_false(self, schema: dict[str, Any]) -> None:
"""Recursively set additionalProperties: false on all object types in a JSON schema.

AWS requires strict schema enforcement. This mirrors the approach used by
AnthropicChatClient._prepare_response_format().

Args:
schema: The JSON schema dict to modify in-place.
"""
visited: set[int] = set()

def walk(node: Any) -> None:
if isinstance(node, dict):
node_id = id(node)
if node_id in visited:
return
visited.add(node_id)
if node.get("type") == "object" or (
"properties" in node and "type" not in node
):
existing = node.get("additionalProperties")
if existing is None or existing is True:
node["additionalProperties"] = False
for value in node.values():
if isinstance(value, (dict, list)):
walk(value)
Comment thread
karthik-0306 marked this conversation as resolved.
elif isinstance(node, list):
node_id = id(node)
if node_id in visited:
return
visited.add(node_id)
for item in node:
if isinstance(item, (dict, list)):
walk(item)

walk(schema)

def service_url(self) -> str:
"""Returns the service URL for the Bedrock runtime in the configured AWS region.

Expand Down
Loading