diff --git a/agentex/src/api/routes/channels.py b/agentex/src/api/routes/channels.py index 3d31a8b0..a3e7668d 100644 --- a/agentex/src/api/routes/channels.py +++ b/agentex/src/api/routes/channels.py @@ -6,12 +6,23 @@ `params` dict forwarded verbatim to task/create (agentex does not interpret it — the bound agent does). -Route bindings live in the CHANNELS_WEBHOOK_ROUTES env var (JSON) for now — the -seam where a real per-config store (resolving a saved agent_config_id) plugs in later: - - CHANNELS_WEBHOOK_ROUTES='{"demo": {"secret": "", - "agent_name": "golden-agent", - "params": {"system_prompt": "You are ...", "mcps": []}}}' +Route bindings live in the CHANNELS_WEBHOOK_ROUTES env var (JSON) for now. A route +supplies the turn's params one of two ways: + +- remote (`params_source`): a URL the channel GETs at dispatch time to obtain the + params. The source owns whatever produces them; the channel just forwards the + result. Auth headers for the fetch are configured generically via + CHANNELS_PARAMS_SOURCE_HEADERS (a JSON object of header name -> value). +- inline (`params`): an opaque dict passed directly, for one-off routes. + + CHANNELS_WEBHOOK_ROUTES='{ + "pr-review": {"secret": "", "agent_name": "", + "channel": "github_pr", + "params_source": "https:///"}, + "demo": {"secret": "", "agent_name": "", + "params": {"system_prompt": "You are ..."}}}' + +(The env-backed store is the current seam; a DB-backed route store replaces it later.) """ from __future__ import annotations @@ -21,7 +32,9 @@ from fastapi import APIRouter, HTTPException, Query, Request -from src.domain.channels.base import ChannelBinding +from src.domain.channels.base import Channel, ChannelBinding +from src.domain.channels.github_pr import GitHubPRChannel +from src.domain.channels.params_source import ParamsSourceError, resolve_binding_params from src.domain.channels.router import ChannelRouter from src.domain.channels.webhook import MAX_BODY_BYTES, WebhookChannel from src.domain.services.task_message_service import DTaskMessageService @@ -33,8 +46,12 @@ router = APIRouter(prefix="/channels", tags=["Channels"]) -# Channel registry — add "slack": SlackChannel() here when it lands. -_WEBHOOK = WebhookChannel() +# Channel registry — keyed by ChannelBinding.channel. Add "slack": SlackChannel() here +# when it lands. All entries reach the same ingress endpoint below; the binding selects. +_CHANNELS: dict[str, Channel] = { + "webhook": WebhookChannel(), + "github_pr": GitHubPRChannel(), +} def _webhook_binding(route_id: str) -> ChannelBinding | None: @@ -50,13 +67,18 @@ def _webhook_binding(route_id: str) -> ChannelBinding | None: cfg = routes.get(route_id) if not isinstance(cfg, dict) or not cfg.get("secret") or not cfg.get("agent_name"): return None - # `params` is opaque — forwarded verbatim to task/create; agentex does not - # interpret it (an agent like golden-agent reads system_prompt/mcps/etc. there). + # A binding supplies its params either remotely (params_source URL, fetched at + # dispatch time) or inline (the opaque `params` dict, forwarded verbatim to + # task/create; agentex does not interpret it — the bound agent does). + params_source = cfg.get("params_source") params = cfg.get("params") + channel = cfg.get("channel") return ChannelBinding( secret=cfg["secret"], agent_name=cfg["agent_name"], + channel=channel if isinstance(channel, str) and channel else "webhook", params=params if isinstance(params, dict) else {}, + params_source=params_source if isinstance(params_source, str) else None, ) @@ -82,26 +104,41 @@ async def handle_webhook( binding = _webhook_binding(route_id) if binding is None: raise HTTPException(status_code=404, detail="unknown route") + channel = _CHANNELS.get(binding.channel) + if channel is None: + logger.error("[channels] route %s bound to unknown channel %r", route_id, binding.channel) + raise HTTPException(status_code=500, detail="misconfigured route channel") raw = await request.body() if len(raw) > MAX_BODY_BYTES: raise HTTPException(status_code=413, detail="payload too large") - if not _WEBHOOK.authenticate(binding, request, raw): + if not channel.authenticate(binding, request, raw): raise HTTPException(status_code=401, detail="unauthorized") if "application/json" not in request.headers.get("content-type", ""): raise HTTPException(status_code=400, detail="expected application/json") try: body = json.loads(raw) except json.JSONDecodeError: - raise HTTPException(status_code=400, detail="invalid json") + raise HTTPException(status_code=400, detail="invalid json") from None if not isinstance(body, dict): raise HTTPException(status_code=400, detail="json body must be an object") + # Remote params: if the route has a params_source, fetch the params now (no-op for + # inline-params bindings). Done after auth so an unauthenticated request never + # triggers an outbound fetch. + try: + binding = await resolve_binding_params(binding) + except ParamsSourceError as exc: + logger.error( + "[channels] params source resolution failed for route %s: %s", route_id, exc + ) + raise HTTPException(status_code=500, detail="params resolution failed") from exc + # Resolve the agent's ACP type so the router picks the right turn method # (sync -> message/send returns the reply inline; async -> event/send). agent = await agents_use_case.get(name=binding.agent_name) - inbound = _WEBHOOK.to_inbound(route_id, body) + inbound = channel.to_inbound(route_id, body) router_ = ChannelRouter(agents_acp_use_case, task_message_service) result = await router_.dispatch(inbound, binding, agent.acp_type) @@ -115,7 +152,7 @@ async def handle_webhook( response = { "ok": True, - "channel": "webhook", + "channel": inbound.channel, "route_id": route_id, "task_id": result.task_id, } diff --git a/agentex/src/domain/channels/base.py b/agentex/src/domain/channels/base.py index 762a78b8..ff204ec9 100644 --- a/agentex/src/domain/channels/base.py +++ b/agentex/src/domain/channels/base.py @@ -43,15 +43,31 @@ def session_key(self, agent_name: str) -> str: class ChannelBinding: """A route's binding to one agent. + A binding provides the turn's task params one of two ways: + + - **inline** (`params` set): the params are given directly — a one-off with no + remote lookup. + - **remote** (`params_source` set): a URL the channel GETs at dispatch time to + obtain the params (see `domain.channels.params_source`). The source endpoint + owns whatever produces those params; the channel layer just forwards the result + and never interprets it. + `params` is an OPAQUE dict forwarded verbatim as the task/create params — the agentex platform does not interpret it. Whatever a given agent expects there - (e.g. golden-agent's system_prompt / mcps / harness / model) is that agent's - concern, not the channel layer's. Later this can be sourced from a saved config. + (system prompt, tools, model, …) is that agent's concern, not the channel layer's. """ secret: str agent_name: str + # Which channel implementation handles this route (registry key, e.g. "webhook", + # "github_pr", "slack"). Defaults to the generic webhook channel. + channel: str = "webhook" params: dict[str, Any] = field(default_factory=dict) + # When set, `params` is fetched from this URL at dispatch time (the source owns + # what they contain; the channel layer just forwards them). + params_source: str | None = None + # Extra metadata to stamp on the task (e.g. returned alongside remote params). + extra_task_metadata: dict[str, str] = field(default_factory=dict) # Headers the router forwards to the agent (auth/delegation). Empty for local/open. forward_headers: dict[str, str] = field(default_factory=dict) diff --git a/agentex/src/domain/channels/github_pr.py b/agentex/src/domain/channels/github_pr.py new file mode 100644 index 00000000..b61aeff5 --- /dev/null +++ b/agentex/src/domain/channels/github_pr.py @@ -0,0 +1,115 @@ +"""GitHubPRChannel: shape a GitHub/Gitea pull-request webhook into a clean prompt. + +A thin payload-shaper on top of the generic webhook channel. It reuses +`WebhookChannel`'s `sha256=` HMAC authentication (the GitHub/Gitea +`X-Hub-Signature-256` scheme) and only overrides normalization: a PR event becomes +a single review prompt (title + body + metadata, plus an inline diff when the caller +includes one), keyed on `repo#number` so repeated events for the same PR (opened, +synchronize, reopened, ...) fold into one task instead of spawning a new one each time. + +This keeps PR-specific shaping out of the generic `WebhookChannel` (which stays +source-agnostic by design). Posting the reply back as a PR comment is the outbound +half — a CI Action can call this endpoint with `?wait=true` and post the returned +review. +""" + +from __future__ import annotations + +from typing import Any + +from src.domain.channels.base import InboundMessage +from src.domain.channels.webhook import WebhookChannel + +# Keep the shaped prompt bounded — PR bodies and diffs can be large. +_MAX_BODY_CHARS = 4000 +_MAX_DIFF_CHARS = 30000 + + +class GitHubPRChannel(WebhookChannel): + name = "github_pr" + + def to_inbound(self, route_id: str, body: dict[str, Any]) -> InboundMessage: + pull_request = body.get("pull_request") + if not isinstance(pull_request, dict): + # Not a PR event (ping, issue comment, ...) — defer to generic rendering. + return super().to_inbound(route_id, body) + + return InboundMessage( + text=_render_pr_prompt(body, pull_request), + channel=self.name, + route_id=route_id, + peer_id=_pr_peer_id(body, pull_request) or route_id, + sender_id=_actor(body), + raw=body, + ) + + +def _repo_full_name(body: dict[str, Any]) -> str | None: + repo = body.get("repository") + if isinstance(repo, dict): + full_name = repo.get("full_name") + if isinstance(full_name, str) and full_name: + return full_name + return None + + +def _pr_peer_id(body: dict[str, Any], pull_request: dict[str, Any]) -> str | None: + """Stable per-PR conversation scope so repeat events fold into one task.""" + number = pull_request.get("number") + repo = _repo_full_name(body) + if repo and number is not None: + return f"{repo}#{number}" + if number is not None: + return f"pr#{number}" + return None + + +def _actor(body: dict[str, Any]) -> str: + sender = body.get("sender") + if isinstance(sender, dict): + login = sender.get("login") + if isinstance(login, str) and login: + return login + return "github" + + +def _inline_diff(body: dict[str, Any], pull_request: dict[str, Any]) -> str | None: + """A diff the caller chose to inline (webhook payloads don't carry it natively).""" + for source in (body, pull_request): + diff = source.get("diff") + if isinstance(diff, str) and diff.strip(): + return diff.strip() + return None + + +def _render_pr_prompt(body: dict[str, Any], pull_request: dict[str, Any]) -> str: + repo = _repo_full_name(body) + number = pull_request.get("number") + title = (pull_request.get("title") or "").strip() + action = (body.get("action") or "").strip() + description = (pull_request.get("body") or "").strip() + html_url = pull_request.get("html_url") or pull_request.get("url") + + header = "Pull request" + if repo and number is not None: + header = f"Pull request {repo}#{number}" + elif number is not None: + header = f"Pull request #{number}" + + lines = [f"{header}: {title}" if title else header] + if action: + lines.append(f"Action: {action}") + if html_url: + lines.append(f"URL: {html_url}") + if description: + lines.append("") + lines.append("Description:") + lines.append(description[:_MAX_BODY_CHARS]) + + diff = _inline_diff(body, pull_request) + if diff: + lines.append("") + lines.append("Diff:") + lines.append(diff[:_MAX_DIFF_CHARS]) + + return "\n".join(lines) diff --git a/agentex/src/domain/channels/params_source.py b/agentex/src/domain/channels/params_source.py new file mode 100644 index 00000000..d5ae8f3b --- /dev/null +++ b/agentex/src/domain/channels/params_source.py @@ -0,0 +1,102 @@ +"""Resolve a channel binding's task params from a configured remote source. + +A route binding can carry a ``params_source``: a URL the channel GETs at dispatch +time to obtain the opaque params forwarded to ``task/create``. This keeps the channel +layer generic — it fetches and forwards params without interpreting them. The source +endpoint owns whatever mapping produces those params; the channel never learns what +they mean. + +Response shape (lenient):: + + { "params": { ... }, "task_metadata": { ... } } # task_metadata optional + +A bare JSON object with no ``params`` key is treated as the params dict itself. +""" + +from __future__ import annotations + +import os +from collections.abc import Awaitable, Callable +from typing import Any + +from src.domain.channels.base import ChannelBinding +from src.utils.logging import make_logger + +logger = make_logger(__name__) + +# Injectable fetcher: url -> response JSON. Default uses httpx; tests inject a fake. +ParamsFetcher = Callable[[str], Awaitable[dict[str, Any]]] + +# Optional generic auth header sent when fetching a params source, configured via env +# so no credential is hard-coded. JSON object of header name -> value, e.g. +# CHANNELS_PARAMS_SOURCE_HEADERS='{"x-api-key": "...", "x-selected-account-id": "..."}'. +# The channel forwards these opaquely — it does not interpret what they mean. +_AUTH_HEADERS_ENV = "CHANNELS_PARAMS_SOURCE_HEADERS" + + +class ParamsSourceError(RuntimeError): + """Raised when a binding's params_source cannot be resolved.""" + + +async def _default_fetch(url: str) -> dict[str, Any]: + """GET the params source over HTTP. Imported lazily so inline-only bindings carry + no httpx dependency.""" + import json as _json + + import httpx + + headers = {"accept": "application/json"} + raw_headers = os.environ.get(_AUTH_HEADERS_ENV) + if raw_headers: + try: + extra = _json.loads(raw_headers) + except _json.JSONDecodeError: + raise ParamsSourceError(f"{_AUTH_HEADERS_ENV} is not valid JSON") from None + if isinstance(extra, dict): + headers.update({str(k): str(v) for k, v in extra.items()}) + + try: + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(url, headers=headers) + response.raise_for_status() + return response.json() + except httpx.HTTPError as exc: + # Covers connection/timeout (RequestError) and non-2xx (HTTPStatusError) so + # the route handler's ParamsSourceError catch logs + returns a clean 500. + raise ParamsSourceError(f"params source request failed: {exc}") from exc + except ValueError as exc: # json.JSONDecodeError subclasses ValueError + raise ParamsSourceError(f"params source returned invalid JSON: {exc}") from exc + + +async def resolve_binding_params( + binding: ChannelBinding, *, fetch: ParamsFetcher | None = None +) -> ChannelBinding: + """Populate ``binding.params`` from its ``params_source`` when set. + + Precedence: a binding with a ``params_source`` fetches its params remotely. A + binding with only inline ``params`` is returned untouched. Any ``task_metadata`` + the source returns is captured for stamping on the task. Mutates and returns the + same binding. + """ + if not binding.params_source: + return binding + + do_fetch = fetch or _default_fetch + payload = await do_fetch(binding.params_source) + if not isinstance(payload, dict): + raise ParamsSourceError("params source returned a non-object response") + + metadata = payload.get("task_metadata") + if isinstance(metadata, dict): + binding.extra_task_metadata = {str(k): str(v) for k, v in metadata.items()} + + params = payload.get("params") + if isinstance(params, dict): + binding.params = params + else: + # Lenient: a bare object with no "params" key is the params dict itself — + # minus a top-level task_metadata, which is captured above, not a param. + binding.params = {k: v for k, v in payload.items() if k != "task_metadata"} + + logger.info("[channels] resolved remote params for agent %s", binding.agent_name) + return binding diff --git a/agentex/src/domain/channels/router.py b/agentex/src/domain/channels/router.py index f07c5811..d4766767 100644 --- a/agentex/src/domain/channels/router.py +++ b/agentex/src/domain/channels/router.py @@ -62,17 +62,24 @@ async def dispatch( session_key = inbound.session_key(binding.agent_name) headers = binding.forward_headers or None + task_metadata: dict[str, str] = { + "channel": inbound.channel, + "route_id": inbound.route_id, + "peer_id": inbound.peer_id, + "sender_id": inbound.sender_id, + } + # Any extra metadata the binding carries (e.g. returned by a remote params + # source) is stamped for traceability — but never overrides the canonical + # channel/route_id/peer_id/sender_id fields (setdefault, not update). + for key, value in binding.extra_task_metadata.items(): + task_metadata.setdefault(key, value) + task = await self._acp.handle_rpc_request( method=AgentRPCMethod.TASK_CREATE, params=CreateTaskRequestEntity( name=session_key, params=binding.params, - task_metadata={ - "channel": inbound.channel, - "route_id": inbound.route_id, - "peer_id": inbound.peer_id, - "sender_id": inbound.sender_id, - }, + task_metadata=task_metadata, ), agent_name=binding.agent_name, request_headers=headers, diff --git a/agentex/tests/unit/domain/channels/test_github_pr.py b/agentex/tests/unit/domain/channels/test_github_pr.py new file mode 100644 index 00000000..8dba2ba9 --- /dev/null +++ b/agentex/tests/unit/domain/channels/test_github_pr.py @@ -0,0 +1,106 @@ +"""Unit tests for the GitHub/Gitea PR channel shaper.""" + +from __future__ import annotations + +import json + +import pytest +from src.domain.channels.github_pr import GitHubPRChannel + +_CHANNEL = GitHubPRChannel() + + +def _pr_body(**pr_overrides) -> dict: + pr = { + "number": 42, + "title": "Add config-by-id", + "body": "This PR wires config-by-id into the channel binding.", + "html_url": "https://example.com/org/repo/pull/42", + } + pr.update(pr_overrides) + return { + "action": "opened", + "repository": {"full_name": "org/repo"}, + "sender": {"login": "octocat"}, + "pull_request": pr, + } + + +class TestToInbound: + def test_shapes_pr_into_clean_prompt(self): + inbound = _CHANNEL.to_inbound("pr-review", _pr_body()) + assert inbound.channel == "github_pr" + assert "Pull request org/repo#42: Add config-by-id" in inbound.text + assert "Action: opened" in inbound.text + assert "https://example.com/org/repo/pull/42" in inbound.text + assert "This PR wires config-by-id" in inbound.text + + def test_peer_id_is_repo_and_number_so_repeat_events_fold(self): + # Same PR, different actions -> same peer_id -> same task (get-or-create). + opened = _CHANNEL.to_inbound("pr-review", _pr_body(number=7)) + synced = _CHANNEL.to_inbound( + "pr-review", {**_pr_body(number=7), "action": "synchronize"} + ) + assert opened.peer_id == "org/repo#7" + assert opened.peer_id == synced.peer_id + + def test_sender_is_the_pr_actor(self): + inbound = _CHANNEL.to_inbound("pr-review", _pr_body()) + assert inbound.sender_id == "octocat" + + def test_inline_diff_is_included_when_present(self): + body = _pr_body() + body["diff"] = "diff --git a/x b/x\n+added line" + inbound = _CHANNEL.to_inbound("pr-review", body) + assert "Diff:" in inbound.text + assert "+added line" in inbound.text + + def test_diff_is_truncated(self): + body = _pr_body() + body["diff"] = "x" * 50000 + inbound = _CHANNEL.to_inbound("pr-review", body) + # 30k cap + the surrounding prompt scaffolding, well under the raw 50k. + assert len(inbound.text) < 40000 + + def test_non_pr_payload_falls_back_to_generic_rendering(self): + # A ping / non-PR event has no pull_request; PR shaping is skipped and the + # generic webhook rendering applies (raw JSON), not a "Pull request ..." prompt. + body = {"zen": "Keep it logically awesome.", "hook_id": 1} + inbound = _CHANNEL.to_inbound("pr-review", body) + assert "Keep it logically awesome." in inbound.text + assert "Pull request" not in inbound.text + + def test_missing_repo_full_name_still_keys_on_number(self): + body = _pr_body() + body.pop("repository") + inbound = _CHANNEL.to_inbound("pr-review", body) + assert inbound.peer_id == "pr#42" + + +class TestAuthInheritedFromWebhook: + def test_uses_hmac_sha256_auth(self): + # GitHubPRChannel reuses WebhookChannel's sha256= HMAC verification. + import hashlib + import hmac + + from src.domain.channels.base import ChannelBinding + + secret = "topsecret" + raw = json.dumps(_pr_body()).encode() + sig = "sha256=" + hmac.new(secret.encode(), raw, hashlib.sha256).hexdigest() + + class _Req: + def __init__(self, headers): + self.headers = headers + + binding = ChannelBinding(secret=secret, agent_name="review-agent", channel="github_pr") + good = _Req({"x-hub-signature-256": sig}) + bad = _Req({"x-hub-signature-256": "sha256=deadbeef"}) + assert _CHANNEL.authenticate(binding, good, raw) is True + assert _CHANNEL.authenticate(binding, bad, raw) is False + + +@pytest.fixture(autouse=True) +def _clear_routes_env(monkeypatch: pytest.MonkeyPatch): + monkeypatch.delenv("CHANNELS_WEBHOOK_ROUTES", raising=False) + yield diff --git a/agentex/tests/unit/domain/channels/test_params_source.py b/agentex/tests/unit/domain/channels/test_params_source.py new file mode 100644 index 00000000..67119a00 --- /dev/null +++ b/agentex/tests/unit/domain/channels/test_params_source.py @@ -0,0 +1,214 @@ +"""Unit tests for remote params resolution + task_metadata stamping. + +The channel layer stays generic: a binding either carries inline params or a +`params_source` URL it GETs and forwards verbatim. These tests cover that resolution, +the route-store parsing, and metadata stamping on dispatch. +""" + +from __future__ import annotations + +import json + +import pytest +from src.domain.channels.base import ChannelBinding, InboundMessage +from src.domain.channels.params_source import ( + ParamsSourceError, + resolve_binding_params, +) +from src.domain.channels.router import ChannelRouter +from src.domain.entities.agents import ACPType +from src.domain.entities.agents_rpc import AgentRPCMethod + + +class TestResolveBindingParams: + async def test_remote_source_populates_params_and_metadata(self): + binding = ChannelBinding( + secret="s", agent_name="agent-1", params_source="https://host/resolve/abc" + ) + captured: list[str] = [] + + async def fake_fetch(url: str) -> dict: + captured.append(url) + return { + "params": {"system_prompt": "from source", "model": "some-model"}, + "task_metadata": {"trace": "xyz"}, + } + + resolved = await resolve_binding_params(binding, fetch=fake_fetch) + + assert captured == ["https://host/resolve/abc"] + assert resolved.params == { + "system_prompt": "from source", + "model": "some-model", + } + assert resolved.extra_task_metadata == {"trace": "xyz"} + + async def test_bare_object_response_is_treated_as_params(self): + binding = ChannelBinding( + secret="s", agent_name="agent-1", params_source="https://host/resolve/abc" + ) + + async def fake_fetch(_url: str) -> dict: + return {"system_prompt": "bare", "model": "m"} + + resolved = await resolve_binding_params(binding, fetch=fake_fetch) + assert resolved.params == {"system_prompt": "bare", "model": "m"} + assert resolved.extra_task_metadata == {} + + async def test_bare_object_still_captures_task_metadata_and_strips_it(self): + # A source that returns task_metadata without a "params" wrapper: the metadata + # is stamped, not leaked into params. + binding = ChannelBinding( + secret="s", agent_name="agent-1", params_source="https://host/resolve/abc" + ) + + async def fake_fetch(_url: str) -> dict: + return {"system_prompt": "bare", "task_metadata": {"trace": "xyz"}} + + resolved = await resolve_binding_params(binding, fetch=fake_fetch) + assert resolved.params == {"system_prompt": "bare"} + assert resolved.extra_task_metadata == {"trace": "xyz"} + + async def test_inline_binding_is_left_untouched_and_does_not_fetch(self): + binding = ChannelBinding( + secret="s", agent_name="agent-1", params={"system_prompt": "one-off"} + ) + + async def fail_fetch(_url: str) -> dict: + raise AssertionError("fetch must not be called for inline bindings") + + resolved = await resolve_binding_params(binding, fetch=fail_fetch) + assert resolved.params == {"system_prompt": "one-off"} + + async def test_non_object_response_raises(self): + binding = ChannelBinding( + secret="s", agent_name="agent-1", params_source="https://host/resolve/abc" + ) + + async def fake_fetch(_url: str): + return ["not", "an", "object"] + + with pytest.raises(ParamsSourceError): + await resolve_binding_params(binding, fetch=fake_fetch) + + +class TestWebhookBindingParse: + def test_parses_params_source_and_channel(self, monkeypatch: pytest.MonkeyPatch): + from src.api.routes.channels import _webhook_binding + + monkeypatch.setenv( + "CHANNELS_WEBHOOK_ROUTES", + json.dumps( + { + "pr-review": { + "secret": "shh", + "agent_name": "agent-1", + "channel": "github_pr", + "params_source": "https://host/resolve/abc", + } + } + ), + ) + binding = _webhook_binding("pr-review") + assert binding is not None + assert binding.channel == "github_pr" + assert binding.params_source == "https://host/resolve/abc" + assert binding.params == {} + + def test_inline_params_route_has_no_source(self, monkeypatch: pytest.MonkeyPatch): + from src.api.routes.channels import _webhook_binding + + monkeypatch.setenv( + "CHANNELS_WEBHOOK_ROUTES", + json.dumps( + { + "demo": { + "secret": "shh", + "agent_name": "agent-1", + "params": {"system_prompt": "hi"}, + } + } + ), + ) + binding = _webhook_binding("demo") + assert binding is not None + assert binding.params_source is None + assert binding.params == {"system_prompt": "hi"} + + +class _FakeACP: + """Records handle_rpc_request calls; returns a task on create, [] on message/send.""" + + def __init__(self, task_id: str = "task-1"): + self.calls: list[tuple] = [] + self._task_id = task_id + + async def handle_rpc_request(self, *, method, params, agent_name, request_headers): + self.calls.append((method, params)) + if method == AgentRPCMethod.TASK_CREATE: + return type("_Task", (), {"id": self._task_id})() + return [] + + def task_metadata_of_create(self) -> dict: + for method, params in self.calls: + if method == AgentRPCMethod.TASK_CREATE: + return params.task_metadata + raise AssertionError("no TASK_CREATE call recorded") + + +class TestDispatchStampsExtraMetadata: + async def test_extra_task_metadata_is_stamped_on_task(self): + acp = _FakeACP() + router = ChannelRouter(acp, task_message_service=object()) + binding = ChannelBinding( + secret="s", + agent_name="agent-1", + params={"system_prompt": "x"}, + extra_task_metadata={"trace": "xyz"}, + ) + inbound = InboundMessage( + text="hi", channel="webhook", route_id="r", peer_id="r" + ) + + await router.dispatch(inbound, binding, ACPType.SYNC) + + metadata = acp.task_metadata_of_create() + assert metadata["trace"] == "xyz" + assert metadata["channel"] == "webhook" + assert metadata["route_id"] == "r" + + async def test_binding_without_extra_metadata_omits_it(self): + acp = _FakeACP() + router = ChannelRouter(acp, task_message_service=object()) + binding = ChannelBinding( + secret="s", agent_name="agent-1", params={"system_prompt": "x"} + ) + inbound = InboundMessage( + text="hi", channel="webhook", route_id="r", peer_id="r" + ) + + await router.dispatch(inbound, binding, ACPType.SYNC) + + metadata = acp.task_metadata_of_create() + assert "trace" not in metadata + assert set(metadata) == {"channel", "route_id", "peer_id", "sender_id"} + + async def test_extra_metadata_cannot_override_canonical_keys(self): + acp = _FakeACP() + router = ChannelRouter(acp, task_message_service=object()) + # A malicious/misconfigured source tries to spoof the canonical channel field. + binding = ChannelBinding( + secret="s", + agent_name="agent-1", + params={"system_prompt": "x"}, + extra_task_metadata={"channel": "spoofed", "trace": "ok"}, + ) + inbound = InboundMessage( + text="hi", channel="webhook", route_id="r", peer_id="r" + ) + + await router.dispatch(inbound, binding, ACPType.SYNC) + + metadata = acp.task_metadata_of_create() + assert metadata["channel"] == "webhook" # canonical wins, not "spoofed" + assert metadata["trace"] == "ok" # non-colliding extras still stamped