diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index b56d7e62fb..c5e9f08a73 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -227,6 +227,7 @@ async def reset( request_max_retries: int | None = None, tool_result_overflow_dir: str | None = None, read_tool: FunctionTool | None = None, + overflow_file_writer: T.Callable[[str, str], T.Awaitable[str]] | None = None, **kwargs: T.Any, ) -> None: self.req = request @@ -241,6 +242,7 @@ async def reset( self.request_max_retries = request_max_retries self.tool_result_overflow_dir = tool_result_overflow_dir self.read_tool = read_tool + self._overflow_file_writer = overflow_file_writer self._tool_result_token_counter = EstimateTokenCounter() self.request_context_manager_config = ContextConfig( # <=0 disables token-based guarding. @@ -369,6 +371,9 @@ async def _write_tool_result_overflow_file( tool_call_id: str, content: str, ) -> str: + if self._overflow_file_writer is not None: + return await self._overflow_file_writer(content, tool_call_id) + if self.tool_result_overflow_dir is None: raise ValueError("tool_result_overflow_dir is not configured") diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index af3ac71322..482bcddc0c 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -1533,6 +1533,18 @@ async def build_main_agent( elif config.computer_use_runtime == "local": _apply_local_env_tools(req, plugin_context) + overflow_file_writer = None + if ( + config.computer_use_runtime == "sandbox" + and req.func_tool + and req.func_tool.get_tool("astrbot_file_read_tool") + ): + from astrbot.core.computer.computer_client import make_sandbox_overflow_writer + + overflow_file_writer = make_sandbox_overflow_writer( + plugin_context, event.unified_msg_origin + ) + agent_runner = AgentRunner() astr_agent_ctx = AstrAgentContext( context=plugin_context, @@ -1625,6 +1637,7 @@ async def build_main_agent( read_tool=( req.func_tool.get_tool("astrbot_file_read_tool") if req.func_tool else None ), + overflow_file_writer=overflow_file_writer, ) if apply_reset: diff --git a/astrbot/core/computer/computer_client.py b/astrbot/core/computer/computer_client.py index 9be646265e..80d0066ded 100644 --- a/astrbot/core/computer/computer_client.py +++ b/astrbot/core/computer/computer_client.py @@ -539,6 +539,40 @@ async def _sync_skills_to_sandbox(booter: ComputerBooter) -> None: logger.warning(f"Failed to remove temp skills zip: {zip_path}") +def make_sandbox_overflow_writer( + context: Context, + unified_msg_origin: str, +): + """Build a callback that writes tool-result overflow content directly into the sandbox. + + The returned callable has the signature + ``(content: str, tool_call_id: str) -> Awaitable[str]`` and returns a + sandbox-relative path that ``astrbot_file_read_tool`` can resolve inside + the sandbox container. + + Bay's filesystem API requires relative paths, so we write to a file under + the sandbox working directory rather than an absolute ``/tmp/...`` path. + """ + + async def _write(content: str, tool_call_id: str) -> str: + safe_id = ( + "".join( + ch if ch.isalnum() or ch in {"-", "_", "."} else "_" + for ch in tool_call_id + ).strip("._") + or "tool_call" + ) + sandbox_path = f"astrbot_overflow_{safe_id}_{uuid.uuid4().hex[:8]}.txt" + booter = await get_booter(context, unified_msg_origin) + await booter.fs.write_file(sandbox_path, content) + logger.debug( + "[Computer] Overflow file written to sandbox: %s", sandbox_path + ) + return sandbox_path + + return _write + + async def get_booter( context: Context, session_id: str, diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 11fc00eec4..8a6075448a 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -305,6 +305,15 @@ async def tool_loop_agent( other_kwargs.setdefault( "read_tool", request.func_tool.get_tool("astrbot_file_read_tool") ) + if self._is_sandbox_runtime(event.unified_msg_origin): + from astrbot.core.computer.computer_client import ( + make_sandbox_overflow_writer, + ) + + other_kwargs.setdefault( + "overflow_file_writer", + make_sandbox_overflow_writer(self, event.unified_msg_origin), + ) await agent_runner.reset( provider=prov, @@ -503,6 +512,13 @@ def get_config(self, umo: str | None = None) -> AstrBotConfig: return self._config return self.astrbot_config_mgr.get_conf(umo) + def _is_sandbox_runtime(self, umo: str) -> bool: + cfg = self.get_config(umo=umo) + runtime = str( + cfg.get("provider_settings", {}).get("computer_use_runtime", "local") + ) + return runtime == "sandbox" + async def send_message( self, session: str | MessageSesion, diff --git a/tests/test_tool_loop_agent_runner.py b/tests/test_tool_loop_agent_runner.py index b4464680fb..5dcd77fb2a 100644 --- a/tests/test_tool_loop_agent_runner.py +++ b/tests/test_tool_loop_agent_runner.py @@ -1741,6 +1741,307 @@ async def test_follow_up_after_stop_not_merged_into_tool_result( assert ticket_before.resolved.is_set() +# --------------------------------------------------------------------------- +# Tests for tool-result overflow file writer (sandbox mode fix) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_overflow_file_writer_callback_is_used(tmp_path, monkeypatch): + """When overflow_file_writer is provided (via make_sandbox_overflow_writer), + the agent runner MUST upload overflow content into the sandbox and return a + sandbox-relative path — no file written to the host tool_result_overflow_dir.""" + from astrbot.core.computer.computer_client import make_sandbox_overflow_writer + + tool = FunctionTool( + name="test_tool", + description="test", + parameters={"type": "object", "properties": {"query": {"type": "string"}}}, + handler=AsyncMock(), + ) + read_tool = FunctionTool( + name="astrbot_file_read_tool", + description="read file", + parameters={"type": "object", "properties": {"path": {"type": "string"}}}, + handler=AsyncMock(), + ) + tool_set = ToolSet(tools=[tool, read_tool]) + provider = SingleToolThenFinalProvider(tool.name, {"query": "large"}) + request = ProviderRequest(prompt="run tool", func_tool=tool_set, contexts=[]) + runner = ToolLoopAgentRunner() + + # ---- Fake booter to record sandbox writes -------------------------------- + write_calls: list[dict] = [] + + class _FakeFS: + async def write_file(self, path: str, content: str) -> None: + write_calls.append({"path": path, "content": content}) + + class _FakeBooter: + def __init__(self): + self.fs = _FakeFS() + + _fake_booter = _FakeBooter() + + async def _fake_get_booter(context, umo): + return _fake_booter + + monkeypatch.setattr( + "astrbot.core.computer.computer_client.get_booter", + _fake_get_booter, + ) + + sandbox_writer = make_sandbox_overflow_writer( + context=SimpleNamespace(), # type: ignore[arg-type] + unified_msg_origin="test_umo", + ) + + # ---- Run agent with sandbox writer --------------------------------------- + await runner.reset( + provider=provider, + request=request, + run_context=ContextWrapper(context=None), + tool_executor=cast( + Any, LargeTextToolExecutor.from_text(_make_large_tool_result_text()) + ), + agent_hooks=MockHooks(), + streaming=False, + tool_result_overflow_dir=str(tmp_path), + read_tool=read_tool, + overflow_file_writer=sandbox_writer, + ) + + async for _ in runner.step_until_done(3): + pass + + # ---- Assert: file was uploaded to sandbox via booter --------------------- + assert len(write_calls) == 1, ( + f"Expected 1 booter.fs.write_file call, got {len(write_calls)}" + ) + assert write_calls[0]["content"] == _make_large_tool_result_text(), ( + "Overflow content uploaded to sandbox does not match original" + ) + sandbox_path = write_calls[0]["path"] + assert sandbox_path.startswith("astrbot_overflow_"), ( + f"Expected sandbox-relative path (astrbot_overflow_*), got: {sandbox_path}" + ) + assert sandbox_path.endswith(".txt") + + # ---- Assert: NOTICE contains the sandbox path, not the host dir ---------- + tool_messages = [m for m in runner.run_context.messages if m.role == "tool"] + assert len(tool_messages) == 1 + tool_message_content = str(tool_messages[0].content) + assert sandbox_path in tool_message_content, ( + f"Expected sandbox path '{sandbox_path}' in notice, " + f"got: ...{tool_message_content[-200:]}" + ) + assert "Truncated tool output preview shown above." in tool_message_content + assert "`astrbot_file_read_tool`" in tool_message_content + + # ---- Assert: NO file was written to the host tool_result_overflow_dir ---- + overflow_files = list(Path(tmp_path).glob("call_large_result_*.txt")) + assert len(overflow_files) == 0, ( + f"Sandbox writer was used but file leaked to host dir: {overflow_files}" + ) + + +@pytest.mark.asyncio +async def test_overflow_file_writer_none_uses_disk_fallback(tmp_path): + """When overflow_file_writer is None (default), the existing disk-based + overflow path MUST work exactly as before — no regression.""" + tool = FunctionTool( + name="test_tool", + description="test", + parameters={"type": "object", "properties": {"query": {"type": "string"}}}, + handler=AsyncMock(), + ) + read_tool = FunctionTool( + name="astrbot_file_read_tool", + description="read file", + parameters={"type": "object", "properties": {"path": {"type": "string"}}}, + handler=AsyncMock(), + ) + tool_set = ToolSet(tools=[tool, read_tool]) + provider = SingleToolThenFinalProvider(tool.name, {"query": "large"}) + request = ProviderRequest(prompt="run tool", func_tool=tool_set, contexts=[]) + runner = ToolLoopAgentRunner() + + await runner.reset( + provider=provider, + request=request, + run_context=ContextWrapper(context=None), + tool_executor=cast( + Any, LargeTextToolExecutor.from_text(_make_large_tool_result_text()) + ), + agent_hooks=MockHooks(), + streaming=False, + tool_result_overflow_dir=str(tmp_path), + read_tool=read_tool, + # overflow_file_writer NOT passed — should default to None + ) + + async for _ in runner.step_until_done(3): + pass + + # Disk-based overflow MUST still work + tool_messages = [m for m in runner.run_context.messages if m.role == "tool"] + assert len(tool_messages) == 1 + tool_message_content = str(tool_messages[0].content) + assert "Truncated tool output preview shown above." in tool_message_content + assert "`astrbot_file_read_tool`" in tool_message_content + + overflow_files = list(Path(tmp_path).glob("call_large_result_*.txt")) + assert len(overflow_files) == 1 + assert ( + overflow_files[0].read_text(encoding="utf-8") == _make_large_tool_result_text() + ) + + +def test_make_sandbox_overflow_writer_returns_callable(): + """make_sandbox_overflow_writer must return an async callable.""" + from astrbot.core.computer.computer_client import make_sandbox_overflow_writer + + writer = make_sandbox_overflow_writer( + context=SimpleNamespace(), # type: ignore[arg-type] + unified_msg_origin="test_umo", + ) + assert callable(writer) + assert asyncio.iscoroutinefunction(writer) + + +@pytest.mark.asyncio +async def test_make_sandbox_overflow_writer_writes_via_booter(monkeypatch): + """The writer returned by make_sandbox_overflow_writer MUST write to the + sandbox filesystem via booter.fs.write_file and return a relative path.""" + from astrbot.core.computer.computer_client import make_sandbox_overflow_writer + + # Fake booter that records write_file calls + write_calls: list[dict] = [] + + class _FakeFS: + async def write_file(self, path: str, content: str) -> None: + write_calls.append({"path": path, "content": content}) + + class _FakeBooter: + def __init__(self): + self.fs = _FakeFS() + + _fake_booter = _FakeBooter() + + async def _fake_get_booter(context, umo): + return _fake_booter + + monkeypatch.setattr( + "astrbot.core.computer.computer_client.get_booter", + _fake_get_booter, + ) + + writer = make_sandbox_overflow_writer( + context=SimpleNamespace(), # type: ignore[arg-type] + unified_msg_origin="test_umo", + ) + + result_path = await writer("hello sandbox", "call_abc123") + + # Must return a sandbox-relative path + assert result_path.startswith("astrbot_overflow_"), ( + f"Expected sandbox-relative path (astrbot_overflow_*), got: {result_path}" + ) + assert result_path.endswith(".txt") + + # Must have called write_file on the booter + assert len(write_calls) == 1 + assert write_calls[0]["path"] == result_path + assert write_calls[0]["content"] == "hello sandbox" + + +@pytest.mark.asyncio +async def test_overflow_notice_contains_sandbox_path_not_host_path(monkeypatch): + """End-to-end: when make_sandbox_overflow_writer is wired through the agent + runner, the tool-message notice MUST contain the sandbox-relative path + and MUST NOT leak the host's tool_result_overflow_dir.""" + from astrbot.core.computer.computer_client import make_sandbox_overflow_writer + + tool = FunctionTool( + name="test_tool", + description="test", + parameters={"type": "object", "properties": {"query": {"type": "string"}}}, + handler=AsyncMock(), + ) + read_tool = FunctionTool( + name="astrbot_file_read_tool", + description="read file", + parameters={"type": "object", "properties": {"path": {"type": "string"}}}, + handler=AsyncMock(), + ) + tool_set = ToolSet(tools=[tool, read_tool]) + provider = SingleToolThenFinalProvider(tool.name, {"query": "large"}) + request = ProviderRequest(prompt="run tool", func_tool=tool_set, contexts=[]) + runner = ToolLoopAgentRunner() + + # ---- Fake booter to capture the sandbox path ----------------------------- + sandbox_written_path: str | None = None + + class _FakeFS: + async def write_file(self, path: str, _content: str) -> None: + nonlocal sandbox_written_path + sandbox_written_path = path + + class _FakeBooter: + def __init__(self): + self.fs = _FakeFS() + + async def _fake_get_booter(_context, _umo): + return _FakeBooter() + + monkeypatch.setattr( + "astrbot.core.computer.computer_client.get_booter", + _fake_get_booter, + ) + + sandbox_writer = make_sandbox_overflow_writer( + context=SimpleNamespace(), # type: ignore[arg-type] + unified_msg_origin="test_umo", + ) + + await runner.reset( + provider=provider, + request=request, + run_context=ContextWrapper(context=None), + tool_executor=cast( + Any, LargeTextToolExecutor.from_text(_make_large_tool_result_text()) + ), + agent_hooks=MockHooks(), + streaming=False, + tool_result_overflow_dir="/tmp/.astrbot", # host path — must NOT leak + read_tool=read_tool, + overflow_file_writer=sandbox_writer, + ) + + async for _ in runner.step_until_done(3): + pass + + # ---- Assert: content was uploaded to the sandbox ------------------------- + assert sandbox_written_path is not None, ( + "Expected booter.fs.write_file to be called" + ) + assert sandbox_written_path.startswith("astrbot_overflow_"), ( + f"Expected sandbox-relative path (astrbot_overflow_*), got: {sandbox_written_path}" + ) + + # ---- Assert: NOTICE uses sandbox path, host path does NOT leak ----------- + tool_messages = [m for m in runner.run_context.messages if m.role == "tool"] + assert len(tool_messages) == 1 + tool_message_content = str(tool_messages[0].content) + + assert sandbox_written_path in tool_message_content, ( + f"Expected sandbox path {sandbox_written_path!r} in notice" + ) + assert "/tmp/.astrbot" not in tool_message_content, ( + "Host tool_result_overflow_dir path leaked into sandbox-mode notice" + ) + + if __name__ == "__main__": # 运行测试 pytest.main([__file__, "-v"]) diff --git a/tests/unit/test_computer.py b/tests/unit/test_computer.py index e667f98a6c..68c4062c70 100644 --- a/tests/unit/test_computer.py +++ b/tests/unit/test_computer.py @@ -777,3 +777,81 @@ async def test_sync_skills_success(self): ): # Should not raise await computer_client._sync_skills_to_sandbox(mock_booter) + + @pytest.mark.asyncio + async def test_make_sandbox_overflow_writer(self, monkeypatch): + """make_sandbox_overflow_writer returns a callback that writes to the + sandbox via booter.fs.write_file with a relative path.""" + from astrbot.core.computer import computer_client + + write_calls: list[dict] = [] + + class _FakeFS: + async def write_file(self, path: str, content: str): + write_calls.append({"path": path, "content": content}) + + class _FakeBooter: + fs = _FakeFS() + + async def _fake_get_booter(context, umo): + return _FakeBooter() + + monkeypatch.setattr(computer_client, "get_booter", _fake_get_booter) + + writer = computer_client.make_sandbox_overflow_writer( + context=object(), # type: ignore[arg-type] + unified_msg_origin="test-umo", + ) + + result = await writer("overflow content", "tool-call-001") + + # Must return a sandbox-relative path + assert result.startswith("astrbot_overflow_"), ( + f"Expected sandbox-relative path (astrbot_overflow_*), got: {result}" + ) + assert result.endswith(".txt") + assert "tool_call_001" in result or "tool-call-001" in result + + # Must have called booter.fs.write_file + assert len(write_calls) == 1 + assert write_calls[0]["path"] == result + assert write_calls[0]["content"] == "overflow content" + + @pytest.mark.asyncio + async def test_make_sandbox_overflow_writer_sanitizes_tool_call_id( + self, monkeypatch, + ): + """The sandbox overflow writer must sanitize special characters in + the tool_call_id for use in a safe filename.""" + from astrbot.core.computer import computer_client + + write_calls: list[dict] = [] + + class _FakeFS: + async def write_file(self, path: str, content: str): + write_calls.append({"path": path, "content": content}) + + class _FakeBooter: + fs = _FakeFS() + + async def _fake_get_booter(context, umo): + return _FakeBooter() + + monkeypatch.setattr(computer_client, "get_booter", _fake_get_booter) + + writer = computer_client.make_sandbox_overflow_writer( + context=object(), # type: ignore[arg-type] + unified_msg_origin="test-umo", + ) + + # Tool call IDs from various providers may contain special chars + result = await writer("data", "chatcmpl-tool_abc:123/456") + + # Path must be a valid relative filename (no directory separators) + assert result.startswith("astrbot_overflow_") + # Must not contain : or / + assert ":" not in result + assert "/" not in result + # Must still have written via the booter + assert len(write_calls) == 1 + assert write_calls[0]["content"] == "data"