diff --git a/pyproject.toml b/pyproject.toml index 833d19fbb..0c4cb62a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "uipath-langchain" -version = "0.13.18" +version = "0.13.19" description = "Python SDK that enables developers to build and deploy LangGraph agents to the UiPath Cloud Platform" readme = { file = "README.md", content-type = "text/markdown" } requires-python = ">=3.11" diff --git a/src/uipath_langchain/agent/tools/a2a/a2a_tool.py b/src/uipath_langchain/agent/tools/a2a/a2a_tool.py index c7ab64425..725eda178 100644 --- a/src/uipath_langchain/agent/tools/a2a/a2a_tool.py +++ b/src/uipath_langchain/agent/tools/a2a/a2a_tool.py @@ -49,6 +49,16 @@ logger = getLogger(__name__) +# The A2A terminal task states +_TERMINAL_TASK_STATES = frozenset( + { + TaskState.completed.value, + TaskState.canceled.value, + TaskState.failed.value, + TaskState.rejected.value, + } +) + class A2aToolInput(BaseModel): """Input schema for A2A agent tool.""" @@ -315,6 +325,11 @@ async def _a2a_wrapper( context_id=context_id, ) + # The server rejects messages to a terminal task, so start a new task + # next turn, keeping context_id to stay in the same conversation. + if task_state in _TERMINAL_TASK_STATES: + new_task_id = None + return Command( update={ "messages": [ diff --git a/tests/agent/tools/test_a2a_tool.py b/tests/agent/tools/test_a2a_tool.py index a256f0438..b27cb2936 100644 --- a/tests/agent/tools/test_a2a_tool.py +++ b/tests/agent/tools/test_a2a_tool.py @@ -11,7 +11,17 @@ import pytest from a2a.client import Client -from a2a.types import AgentCard, Message, Part, Role, TextPart +from a2a.types import ( + AgentCard, + Artifact, + Message, + Part, + Role, + Task, + TaskState, + TaskStatus, + TextPart, +) from opentelemetry import trace as otel_trace from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor @@ -318,6 +328,93 @@ async def _get(): assert "pong" in command.update["messages"][0].content +def _completed_task( + *, task_id: str = "task-1", context_id: str = "ctx-1", text: str = "done" +) -> Task: + return Task( + id=task_id, + context_id=context_id, + status=TaskStatus(state=TaskState.completed), + artifacts=[ + Artifact( + artifact_id="artifact-1", + parts=[Part(root=TextPart(text=text))], + ) + ], + ) + + +def _input_required_task( + *, task_id: str = "task-1", context_id: str = "ctx-1", text: str = "need more" +) -> Task: + return Task( + id=task_id, + context_id=context_id, + status=TaskStatus( + state=TaskState.input_required, + message=Message( + role=Role.agent, + parts=[Part(root=TextPart(text=text))], + message_id="status-msg", + ), + ), + ) + + +async def test_tool_wrapper_drops_task_id_after_terminal_state( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A completed (terminal) task is not reused: the next turn starts a new + task while keeping the conversation context.""" + resource = _make_resource(cached_agent_card=_cached_card()) + tools, clients = create_a2a_tools_and_clients([resource]) + tool = cast(A2aStructuredToolWithWrapper, tools[0]) + fake = _FakeA2aClient( + [(_completed_task(task_id="task-1", context_id="ctx-1"), None)] + ) + + async def _get(): + return fake + + monkeypatch.setattr(clients[0], "get", _get) + wrapper: Any = tool.awrapper + assert wrapper is not None + + call = {"name": tool.name, "args": {"message": "ping"}, "id": "call-1"} + command = await wrapper(tool, call, AgentGraphState()) + + stored = command.update["inner_state"]["tools_storage"][tool.name] + assert stored["task_id"] is None + assert stored["context_id"] == "ctx-1" + + +async def test_tool_wrapper_keeps_task_id_when_not_terminal( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A non-terminal task (input-required) keeps its task_id so the next turn + continues the same task.""" + resource = _make_resource(cached_agent_card=_cached_card()) + tools, clients = create_a2a_tools_and_clients([resource]) + tool = cast(A2aStructuredToolWithWrapper, tools[0]) + fake = _FakeA2aClient( + [(_input_required_task(task_id="task-1", context_id="ctx-1"), None)] + ) + + async def _get(): + return fake + + monkeypatch.setattr(clients[0], "get", _get) + wrapper: Any = tool.awrapper + assert wrapper is not None + + call = {"name": tool.name, "args": {"message": "ping"}, "id": "call-1"} + command = await wrapper(tool, call, AgentGraphState()) + + stored = command.update["inner_state"]["tools_storage"][tool.name] + assert stored["task_id"] == "task-1" + assert stored["context_id"] == "ctx-1" + + def test_a2a_sdk_telemetry_suppressed_by_default() -> None: """Importing the a2a package disables the a2a-sdk's own OTel transport spans. diff --git a/uv.lock b/uv.lock index 208a31def..edee85254 100644 --- a/uv.lock +++ b/uv.lock @@ -4439,7 +4439,7 @@ wheels = [ [[package]] name = "uipath-langchain" -version = "0.13.18" +version = "0.13.19" source = { editable = "." } dependencies = [ { name = "a2a-sdk" },