From 9416577e85d0ff3ec2795e975bac89f55794e6e6 Mon Sep 17 00:00:00 2001 From: Noah Horton Date: Wed, 15 Apr 2026 12:30:48 -0600 Subject: [PATCH 1/7] feat: add tool requirements policy enforcement system Introduces a PreToolUse hook-based policy system that evaluates tool calls against RFC 2119-style requirements defined in .deepwork/tool_requirements/*.yml. Policies are checked via an HTTP sidecar server (spawned alongside the MCP server) using Haiku for semantic evaluation. Failed checks can be appealed via a new appeal_tool_requirement MCP tool. Approvals are cached with a 1-hour TTL. Key features: - Policy files with tools, match (param regex), requirements, extends (inheritance) - no_exception rules that cannot be appealed - Fail-closed: hook denies if MCP sidecar is unreachable - Loop prevention: appeal tool calls skip the hook - Multi-instance support via PID-keyed + session-keyed port files - Evaluator encapsulated behind ABC for future swap to direct API calls Co-Authored-By: Claude Opus 4.6 (1M context) --- plugins/claude/hooks/hooks.json | 11 + plugins/claude/hooks/tool_requirements.sh | 15 + src/deepwork/cli/serve.py | 23 ++ src/deepwork/hooks/tool_requirements.py | 126 ++++++++ src/deepwork/jobs/mcp/server.py | 94 ++++++ .../schemas/tool_requirements_schema.json | 62 ++++ src/deepwork/tool_requirements/__init__.py | 1 + src/deepwork/tool_requirements/cache.py | 58 ++++ src/deepwork/tool_requirements/config.py | 106 +++++++ src/deepwork/tool_requirements/discovery.py | 104 +++++++ src/deepwork/tool_requirements/engine.py | 186 ++++++++++++ src/deepwork/tool_requirements/evaluator.py | 256 ++++++++++++++++ src/deepwork/tool_requirements/matcher.py | 78 +++++ src/deepwork/tool_requirements/sidecar.py | 274 ++++++++++++++++++ tests/unit/test_tool_requirements_hook.py | 126 ++++++++ tests/unit/tool_requirements/__init__.py | 0 tests/unit/tool_requirements/test_cache.py | 56 ++++ tests/unit/tool_requirements/test_config.py | 119 ++++++++ 
.../unit/tool_requirements/test_discovery.py | 116 ++++++++ tests/unit/tool_requirements/test_engine.py | 205 +++++++++++++ .../unit/tool_requirements/test_evaluator.py | 102 +++++++ tests/unit/tool_requirements/test_matcher.py | 81 ++++++ 22 files changed, 2199 insertions(+) create mode 100755 plugins/claude/hooks/tool_requirements.sh create mode 100644 src/deepwork/hooks/tool_requirements.py create mode 100644 src/deepwork/schemas/tool_requirements_schema.json create mode 100644 src/deepwork/tool_requirements/__init__.py create mode 100644 src/deepwork/tool_requirements/cache.py create mode 100644 src/deepwork/tool_requirements/config.py create mode 100644 src/deepwork/tool_requirements/discovery.py create mode 100644 src/deepwork/tool_requirements/engine.py create mode 100644 src/deepwork/tool_requirements/evaluator.py create mode 100644 src/deepwork/tool_requirements/matcher.py create mode 100644 src/deepwork/tool_requirements/sidecar.py create mode 100644 tests/unit/test_tool_requirements_hook.py create mode 100644 tests/unit/tool_requirements/__init__.py create mode 100644 tests/unit/tool_requirements/test_cache.py create mode 100644 tests/unit/tool_requirements/test_config.py create mode 100644 tests/unit/tool_requirements/test_discovery.py create mode 100644 tests/unit/tool_requirements/test_engine.py create mode 100644 tests/unit/tool_requirements/test_evaluator.py create mode 100644 tests/unit/tool_requirements/test_matcher.py diff --git a/plugins/claude/hooks/hooks.json b/plugins/claude/hooks/hooks.json index 80ebdff4..96de9734 100644 --- a/plugins/claude/hooks/hooks.json +++ b/plugins/claude/hooks/hooks.json @@ -1,5 +1,16 @@ { "hooks": { + "PreToolUse": [ + { + "matcher": "", + "hooks": [ + { + "type": "command", + "command": "${CLAUDE_PLUGIN_ROOT}/hooks/tool_requirements.sh" + } + ] + } + ], "SessionStart": [ { "matcher": "", diff --git a/plugins/claude/hooks/tool_requirements.sh b/plugins/claude/hooks/tool_requirements.sh new file mode 100755 index 
00000000..2eb4bc05 --- /dev/null +++ b/plugins/claude/hooks/tool_requirements.sh @@ -0,0 +1,15 @@ +#!/usr/bin/env bash +# tool_requirements.sh - PreToolUse hook for tool requirements enforcement +# +# Fires before every tool call. Delegates to the Python hook which contacts +# the MCP sidecar to check policies. +# +# Input (stdin): JSON from Claude Code PreToolUse hook +# Output (stdout): JSON with hookSpecificOutput.permissionDecision +# Exit codes: +# 0 - Always (decision encoded in JSON output) + +INPUT=$(cat) +export DEEPWORK_HOOK_PLATFORM="claude" +echo "${INPUT}" | deepwork hook tool_requirements +exit $? diff --git a/src/deepwork/cli/serve.py b/src/deepwork/cli/serve.py index 77307865..de9ac47d 100644 --- a/src/deepwork/cli/serve.py +++ b/src/deepwork/cli/serve.py @@ -122,6 +122,9 @@ def _serve_mcp( "# Ignore everything in this directory\n*\n# But keep this .gitignore\n!.gitignore\n" ) + # Start tool requirements sidecar (if policies exist) + _start_tool_requirements_sidecar(project_path) + # Create and run server from deepwork.jobs.mcp.server import create_server @@ -135,3 +138,23 @@ def _serve_mcp( server.run(transport="stdio") else: server.run(transport="sse", port=port) + + +def _start_tool_requirements_sidecar(project_path: Path) -> None: + """Start the tool requirements sidecar if policy files exist.""" + policy_dir = project_path / ".deepwork" / "tool_requirements" + if not policy_dir.is_dir(): + return + if not any(policy_dir.glob("*.yml")): + return + + try: + from deepwork.tool_requirements.sidecar import start_sidecar + + start_sidecar(project_path) + except Exception: + import logging + + logging.getLogger("deepwork.tool_requirements").warning( + "Failed to start tool requirements sidecar", exc_info=True + ) diff --git a/src/deepwork/hooks/tool_requirements.py b/src/deepwork/hooks/tool_requirements.py new file mode 100644 index 00000000..b34066ac --- /dev/null +++ b/src/deepwork/hooks/tool_requirements.py @@ -0,0 +1,126 @@ +"""PreToolUse hook for 
tool requirements policy enforcement. + +Fires before every tool call. Contacts the MCP sidecar server to check +whether the call complies with policies defined in +.deepwork/tool_requirements/*.yml. + +Fail-closed: if the sidecar is unreachable, the hook denies the call +with a message to restart the MCP server. +""" + +from __future__ import annotations + +import http.client +import json +import os +import sys +from pathlib import Path +from typing import Any + +from deepwork.hooks.wrapper import ( + HookInput, + HookOutput, + NormalizedEvent, + Platform, + output_hook_error, + run_hook, +) +from deepwork.tool_requirements.sidecar import discover_sidecar + +# Tool name substrings to skip (loop prevention) +_SKIP_TOOLS = ("appeal_tool_requirement",) + + +def tool_requirements_hook(hook_input: HookInput) -> HookOutput: + """Pre-tool hook: check tool call against requirement policies.""" + if hook_input.event != NormalizedEvent.BEFORE_TOOL: + return HookOutput() + + # Loop prevention: skip the appeal MCP tool itself + raw_tool = hook_input.raw_input.get("tool_name", "") + for skip in _SKIP_TOOLS: + if skip in raw_tool: + return HookOutput() + + cwd = hook_input.cwd or os.getcwd() + session_id = hook_input.session_id or "" + + # Discover sidecar + sidecar = discover_sidecar(Path(cwd), session_id) + if sidecar is None: + return _deny( + "DeepWork Tool Requirements: MCP server is not running. " + "The tool_requirements system requires the MCP server to be active. " + "Please restart the MCP server." + ) + + # Send check request to sidecar + try: + response = _http_post(sidecar["port"], "/check", { + "tool_name": hook_input.tool_name, + "tool_input": hook_input.tool_input, + "raw_tool_name": raw_tool, + "session_id": session_id, + }) + except Exception as e: + return _deny( + f"DeepWork Tool Requirements: Failed to reach MCP server sidecar: {e}. " + "Please restart the MCP server." 
+ ) + + if response.get("decision") == "allow": + return HookOutput() + + if response.get("decision") == "deny": + reason = response.get("reason", "Tool call blocked by policy") + return _deny(reason) + + # Unexpected response — allow (fail-open only for malformed responses + # from an actually-running sidecar, not for missing sidecars) + return HookOutput() + + +def _deny(reason: str) -> HookOutput: + """Create a deny output for PreToolUse with proper Claude Code format.""" + return HookOutput( + raw_output={ + "hookSpecificOutput": { + "hookEventName": "PreToolUse", + "permissionDecision": "deny", + "permissionDecisionReason": reason, + } + } + ) + + +def _http_post(port: int, path: str, body: dict[str, Any]) -> dict[str, Any]: + """Send an HTTP POST to the sidecar on localhost.""" + conn = http.client.HTTPConnection("127.0.0.1", port, timeout=30) + try: + payload = json.dumps(body).encode("utf-8") + conn.request( + "POST", + path, + body=payload, + headers={"Content-Type": "application/json"}, + ) + response = conn.getresponse() + data = response.read() + result: dict[str, Any] = json.loads(data) + return result + finally: + conn.close() + + +def main() -> int: + """Entry point for the hook CLI.""" + platform = Platform(os.environ.get("DEEPWORK_HOOK_PLATFORM", "claude")) + return run_hook(tool_requirements_hook, platform) + + +if __name__ == "__main__": + try: + sys.exit(main()) + except Exception as e: + output_hook_error(e, context="tool_requirements hook") + sys.exit(0) diff --git a/src/deepwork/jobs/mcp/server.py b/src/deepwork/jobs/mcp/server.py index 910166b5..af43fac9 100644 --- a/src/deepwork/jobs/mcp/server.py +++ b/src/deepwork/jobs/mcp/server.py @@ -149,6 +149,21 @@ def _log_tool_call( log_data["params"] = params logger.info("MCP tool call: %s", log_data) + # Track whether session has been registered for tool requirements sidecar + _registered_sessions: set[str] = set() + + def _maybe_register_session(session_id: str | None) -> None: + """Register 
session with the tool requirements sidecar on first tool call.""" + if not session_id or session_id in _registered_sessions: + return + _registered_sessions.add(session_id) + try: + from deepwork.tool_requirements.sidecar import register_session + + register_session(project_path, session_id) + except Exception: + pass # Best-effort — sidecar may not be running + @mcp.tool( description=( "List all available DeepWork workflows. " @@ -186,6 +201,7 @@ async def start_workflow( agent_id: str | None = None, ) -> dict[str, Any]: """Start a workflow and get first step instructions.""" + _maybe_register_session(session_id) _log_tool_call( "start_workflow", { @@ -505,6 +521,84 @@ async def mark_review_as_passed(review_id: str, ctx: Context) -> str: except ValueError as e: return f"Validation error: {e}" + # ---- Tool Requirements: appeal tool ---- + + @mcp.tool( + description=( + "Appeal a tool requirement policy denial. When a tool call is blocked " + "by a tool requirement policy, call this to appeal specific failed " + "checks by providing justifications. " + "Required: tool_name (the normalized tool name that was blocked), " + "tool_input (the exact tool_input that was blocked), " + "policy_justification (dict mapping each failed check name to a " + "justification string explaining why the check should pass). " + "Optional: session_id (CLAUDE_CODE_SESSION_ID). " + "Some checks are marked no_exception and cannot be appealed. " + "If the appeal succeeds, the tool call is cached as approved and " + "you can retry the original tool call." 
+ ) + ) + async def appeal_tool_requirement( + tool_name: str, + tool_input: dict[str, Any], + policy_justification: dict[str, str], + ctx: Context, + session_id: str | None = None, + ) -> dict[str, Any]: + """Appeal a tool requirement denial with justifications.""" + _log_tool_call( + "appeal_tool_requirement", + { + "tool_name": tool_name, + "justification_keys": list(policy_justification.keys()), + }, + session_id=session_id, + ) + _maybe_register_session(session_id) + + root = await root_resolver.get_root(ctx) + + # Delegate to sidecar (same process) or engine directly + try: + from deepwork.tool_requirements.sidecar import discover_sidecar + + sidecar = discover_sidecar(root, session_id or "") + if sidecar is None: + return { + "passed": False, + "reason": "Tool requirements sidecar is not running. " + "Please restart the MCP server.", + } + + import http.client + import json as json_mod + + conn = http.client.HTTPConnection( + "127.0.0.1", sidecar["port"], timeout=60 + ) + try: + payload = json_mod.dumps({ + "tool_name": tool_name, + "tool_input": tool_input, + "policy_justification": policy_justification, + }).encode("utf-8") + conn.request( + "POST", + "/appeal", + body=payload, + headers={"Content-Type": "application/json"}, + ) + response = conn.getresponse() + return json_mod.loads(response.read()) + finally: + conn.close() + except Exception as e: + logger.exception("Error in appeal_tool_requirement") + return { + "passed": False, + "reason": f"Appeal failed: {e}", + } + return mcp diff --git a/src/deepwork/schemas/tool_requirements_schema.json b/src/deepwork/schemas/tool_requirements_schema.json new file mode 100644 index 00000000..3e898b23 --- /dev/null +++ b/src/deepwork/schemas/tool_requirements_schema.json @@ -0,0 +1,62 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "Tool Requirements Policy", + "description": "Schema for .deepwork/tool_requirements/*.yml policy files that define RFC 2119-style rules for AI agent tool 
calls.", + "type": "object", + "required": ["tools", "requirements"], + "additionalProperties": false, + "properties": { + "summary": { + "type": "string", + "description": "Human-readable summary of what this policy enforces." + }, + "tools": { + "type": "array", + "description": "Normalized tool names (shell, write_file, edit_file, etc.) or MCP tool names (mcp__server__tool) this policy applies to.", + "items": { + "type": "string" + }, + "minItems": 1 + }, + "match": { + "type": "object", + "description": "Optional parameter-level filtering. Keys are tool_input parameter names, values are regex patterns. Policy only applies when at least one pattern matches.", + "patternProperties": { + "^[a-zA-Z0-9_-]+$": { + "type": "string" + } + }, + "additionalProperties": false + }, + "extends": { + "type": "array", + "description": "List of policy file stems to inherit requirements from.", + "items": { + "type": "string" + } + }, + "requirements": { + "type": "object", + "description": "RFC 2119 keyed requirements. Keys are requirement identifiers, values define the rule and exception policy.", + "patternProperties": { + "^[a-zA-Z0-9_-]+$": { + "type": "object", + "required": ["rule"], + "additionalProperties": false, + "properties": { + "rule": { + "type": "string", + "description": "RFC 2119 statement (using MUST, SHOULD, MAY, etc.)." + }, + "no_exception": { + "type": "boolean", + "description": "If true, this requirement cannot be appealed. 
Defaults to false.", + "default": false + } + } + } + }, + "additionalProperties": false + } + } +} diff --git a/src/deepwork/tool_requirements/__init__.py b/src/deepwork/tool_requirements/__init__.py new file mode 100644 index 00000000..471880cf --- /dev/null +++ b/src/deepwork/tool_requirements/__init__.py @@ -0,0 +1 @@ +"""Tool Requirements — policy enforcement for AI agent tool calls.""" diff --git a/src/deepwork/tool_requirements/cache.py b/src/deepwork/tool_requirements/cache.py new file mode 100644 index 00000000..49ce5001 --- /dev/null +++ b/src/deepwork/tool_requirements/cache.py @@ -0,0 +1,58 @@ +"""In-memory TTL cache for tool requirements pass decisions. + +Caches approved tool calls so that repeated identical calls +don't require re-evaluation within the TTL window. +""" + +from __future__ import annotations + +import hashlib +import json +import time +from dataclasses import dataclass +from typing import Any + + +@dataclass +class _CacheEntry: + """A cached approval decision.""" + + approved_at: float + ttl_seconds: float = 3600.0 # 1 hour + + @property + def is_valid(self) -> bool: + return (time.time() - self.approved_at) < self.ttl_seconds + + +class ToolRequirementsCache: + """In-memory TTL cache for approved tool calls.""" + + def __init__(self, ttl_seconds: float = 3600.0) -> None: + self._cache: dict[str, _CacheEntry] = {} + self._ttl = ttl_seconds + + def make_key(self, tool_name: str, tool_input: dict[str, Any]) -> str: + """Create a deterministic cache key from tool name and input.""" + content = json.dumps({"tool": tool_name, "input": tool_input}, sort_keys=True) + return hashlib.sha256(content.encode()).hexdigest() + + def is_approved(self, key: str) -> bool: + """Check if a tool call has a valid cached approval.""" + entry = self._cache.get(key) + if entry is not None and entry.is_valid: + return True + if entry is not None: + del self._cache[key] + return False + + def approve(self, key: str) -> None: + """Cache an approval for a tool 
call.""" + self._cache[key] = _CacheEntry(approved_at=time.time(), ttl_seconds=self._ttl) + + def clear(self) -> None: + """Clear all cached entries.""" + self._cache.clear() + + def __len__(self) -> int: + return len(self._cache) diff --git a/src/deepwork/tool_requirements/config.py b/src/deepwork/tool_requirements/config.py new file mode 100644 index 00000000..a5e1f174 --- /dev/null +++ b/src/deepwork/tool_requirements/config.py @@ -0,0 +1,106 @@ +"""Configuration parsing for tool requirements policy files. + +Parses .deepwork/tool_requirements/*.yml files into ToolPolicy dataclasses, +validating against the JSON schema. +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from deepwork.utils.validation import ValidationError, validate_against_schema +from deepwork.utils.yaml_utils import YAMLError, load_yaml + +# Load the JSON schema once at module level +_SCHEMA_PATH = Path(__file__).parent.parent / "schemas" / "tool_requirements_schema.json" +_SCHEMA: dict[str, Any] | None = None + + +def _get_schema() -> dict[str, Any]: + global _SCHEMA + if _SCHEMA is None: + _SCHEMA = json.loads(_SCHEMA_PATH.read_text(encoding="utf-8")) + return _SCHEMA + + +class ToolRequirementsError(Exception): + """Exception raised for tool requirements configuration errors.""" + + pass + + +@dataclass +class Requirement: + """A single RFC 2119 requirement within a policy.""" + + rule: str + no_exception: bool = False + + +@dataclass +class ToolPolicy: + """A parsed tool requirements policy definition.""" + + name: str + source_path: Path + summary: str = "" + tools: list[str] = field(default_factory=list) + match: dict[str, str] = field(default_factory=dict) + requirements: dict[str, Requirement] = field(default_factory=dict) + extends: list[str] = field(default_factory=list) + + +def parse_policy_file(filepath: Path) -> ToolPolicy: + """Parse a tool requirements YAML file into a 
ToolPolicy. + + Args: + filepath: Path to the YAML file. + + Returns: + A ToolPolicy object. + + Raises: + ToolRequirementsError: If the file cannot be parsed or fails validation. + """ + try: + data = load_yaml(filepath) + except YAMLError as e: + raise ToolRequirementsError(f"Failed to parse {filepath}: {e}") from e + + if data is None: + raise ToolRequirementsError(f"File not found: {filepath}") + + if not data: + raise ToolRequirementsError(f"Empty policy file: {filepath}") + + try: + validate_against_schema(data, _get_schema()) + except ValidationError as e: + raise ToolRequirementsError(f"Schema validation failed for {filepath}: {e}") from e + + return _build_policy(data, filepath) + + +def _build_policy(data: dict[str, Any], filepath: Path) -> ToolPolicy: + """Build a ToolPolicy from validated YAML data.""" + name = filepath.stem + + requirements: dict[str, Requirement] = {} + for req_id, req_data in data.get("requirements", {}).items(): + requirements[req_id] = Requirement( + rule=req_data["rule"], + no_exception=req_data.get("no_exception", False), + ) + + return ToolPolicy( + name=name, + source_path=filepath, + summary=data.get("summary", ""), + tools=data.get("tools", []), + match=data.get("match", {}), + requirements=requirements, + extends=data.get("extends", []), + ) diff --git a/src/deepwork/tool_requirements/discovery.py b/src/deepwork/tool_requirements/discovery.py new file mode 100644 index 00000000..99ac7511 --- /dev/null +++ b/src/deepwork/tool_requirements/discovery.py @@ -0,0 +1,104 @@ +"""Discovery and loading for tool requirements policy files. + +Scans .deepwork/tool_requirements/ for YAML policy files, parses them, +and resolves inheritance via the extends field. 
+""" + +from __future__ import annotations + +import logging +from pathlib import Path + +from deepwork.tool_requirements.config import ( + Requirement, + ToolPolicy, + ToolRequirementsError, + parse_policy_file, +) + +logger = logging.getLogger("deepwork.tool_requirements") + + +def load_all_policies(project_root: Path) -> list[ToolPolicy]: + """Load all tool requirement policies from the project. + + Scans .deepwork/tool_requirements/ for *.yml files, parses each, + and resolves inheritance. Skips files that fail to parse (logged as warnings). + + Args: + project_root: Path to the project root directory. + + Returns: + List of parsed and inheritance-resolved ToolPolicy objects. + """ + policy_dir = project_root / ".deepwork" / "tool_requirements" + if not policy_dir.is_dir(): + return [] + + policies: list[ToolPolicy] = [] + for yml_path in sorted(policy_dir.glob("*.yml")): + try: + policy = parse_policy_file(yml_path) + policies.append(policy) + except ToolRequirementsError as e: + logger.warning("Skipping policy %s: %s", yml_path.name, e) + + if policies: + policies = _resolve_inheritance(policies) + + return policies + + +def _resolve_inheritance(policies: list[ToolPolicy]) -> list[ToolPolicy]: + """Resolve extends inheritance for all policies. + + Parent requirements are merged into children. Child requirements + override parent requirements on key conflict. + + Args: + policies: List of parsed policies. + + Returns: + List of policies with inherited requirements merged in. 
+ """ + by_name: dict[str, ToolPolicy] = {p.name: p for p in policies} + resolved: dict[str, ToolPolicy] = {} + + def resolve(name: str, visited: set[str]) -> ToolPolicy: + if name in resolved: + return resolved[name] + + if name not in by_name: + logger.warning("Policy '%s' extends unknown policy '%s'", name, name) + return by_name.get(name, ToolPolicy(name=name, source_path=Path())) + + policy = by_name[name] + + if name in visited: + logger.warning("Circular inheritance detected for policy '%s'", name) + return policy + + visited.add(name) + + # Merge parent requirements (parent first, child overrides) + merged_requirements: dict[str, Requirement] = {} + for parent_name in policy.extends: + if parent_name not in by_name: + logger.warning( + "Policy '%s' extends unknown policy '%s'", policy.name, parent_name + ) + continue + parent = resolve(parent_name, visited) + merged_requirements.update(parent.requirements) + + # Child requirements override parent + merged_requirements.update(policy.requirements) + policy.requirements = merged_requirements + + resolved[name] = policy + return policy + + for policy in policies: + resolve(policy.name, set()) + + return policies diff --git a/src/deepwork/tool_requirements/engine.py b/src/deepwork/tool_requirements/engine.py new file mode 100644 index 00000000..7616b304 --- /dev/null +++ b/src/deepwork/tool_requirements/engine.py @@ -0,0 +1,186 @@ +"""Tool Requirements Engine — orchestrates check and appeal flows. + +This is the central coordinator: loads policies, matches them to tool calls, +evaluates requirements via the LLM evaluator, and manages the cache. 
+""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from deepwork.tool_requirements.cache import ToolRequirementsCache +from deepwork.tool_requirements.config import Requirement +from deepwork.tool_requirements.discovery import load_all_policies +from deepwork.tool_requirements.evaluator import RequirementEvaluator +from deepwork.tool_requirements.matcher import match_policies, merge_requirements + +logger = logging.getLogger("deepwork.tool_requirements") + + +@dataclass +class CheckResult: + """Result of checking a tool call against policies.""" + + allowed: bool + reason: str + failed_checks: list[str] = field(default_factory=list) + + +@dataclass +class AppealResult: + """Result of an appeal attempt.""" + + passed: bool + reason: str + no_exception_blocked: list[str] = field(default_factory=list) + + +class ToolRequirementsEngine: + """Orchestrates tool requirement checking and appeals.""" + + def __init__( + self, + project_root: Path, + evaluator: RequirementEvaluator, + cache: ToolRequirementsCache | None = None, + ) -> None: + self.project_root = project_root + self.evaluator = evaluator + self.cache = cache or ToolRequirementsCache() + self._policies = load_all_policies(project_root) + + def reload_policies(self) -> None: + """Reload policies from disk.""" + self._policies = load_all_policies(self.project_root) + + async def check(self, tool_name: str, tool_input: dict[str, Any]) -> CheckResult: + """Check a tool call against all matching policies. + + Args: + tool_name: Normalized tool name. + tool_input: Tool call parameters. + + Returns: + CheckResult with allowed=True if the call is permitted. 
+ """ + # Check cache first + cache_key = self.cache.make_key(tool_name, tool_input) + if self.cache.is_approved(cache_key): + return CheckResult(allowed=True, reason="Previously approved (cached)") + + # Find matching policies + matching = match_policies(tool_name, tool_input, self._policies) + if not matching: + return CheckResult(allowed=True, reason="No policies match this tool call") + + # Merge all requirements + all_requirements = merge_requirements(matching) + if not all_requirements: + return CheckResult(allowed=True, reason="No requirements to check") + + # Evaluate + verdicts = await self.evaluator.evaluate(all_requirements, tool_name, tool_input) + failures = [v for v in verdicts if not v.passed] + + if not failures: + self.cache.approve(cache_key) + return CheckResult(allowed=True, reason="All requirements passed") + + # Build error message with ALL failures + error_lines = ["Tool call blocked by the following policy violations:\n"] + for f in failures: + req = all_requirements.get(f.requirement_id) + no_exc = "" + if req and req.no_exception: + no_exc = " [NO EXCEPTION - cannot be appealed]" + error_lines.append(f"- **{f.requirement_id}**{no_exc}: {f.explanation}") + + error_lines.append( + "\nTo appeal, call the `appeal_tool_requirement` MCP tool with:\n" + "- `tool_name`: the tool that was blocked\n" + "- `tool_input`: the exact tool_input that was blocked\n" + "- `policy_justification`: a dict mapping each failed check name " + "to your justification string" + ) + + return CheckResult( + allowed=False, + reason="\n".join(error_lines), + failed_checks=[f.requirement_id for f in failures], + ) + + async def appeal( + self, + tool_name: str, + tool_input: dict[str, Any], + justifications: dict[str, str], + ) -> AppealResult: + """Appeal a tool requirement denial. + + Args: + tool_name: Normalized tool name. + tool_input: Tool call parameters. + justifications: Failed check IDs mapped to justification strings. 
+ + Returns: + AppealResult with passed=True if the appeal succeeds. + """ + if not justifications: + return AppealResult( + passed=False, + reason="No justifications provided. Provide a justification for each failed check.", + ) + + # Find matching policies and merge requirements + matching = match_policies(tool_name, tool_input, self._policies) + all_requirements = merge_requirements(matching) + + # Check for no_exception rules + no_exception_blocked: list[str] = [] + appealable: dict[str, Requirement] = {} + + for req_id in justifications: + if req_id not in all_requirements: + continue + req = all_requirements[req_id] + if req.no_exception: + no_exception_blocked.append(req_id) + else: + appealable[req_id] = req + + if no_exception_blocked: + return AppealResult( + passed=False, + reason=( + "Cannot appeal no_exception requirements: " + + ", ".join(no_exception_blocked) + ), + no_exception_blocked=no_exception_blocked, + ) + + if not appealable: + return AppealResult( + passed=False, + reason="No valid appealable requirements found in justifications.", + ) + + # Re-evaluate with justifications + verdicts = await self.evaluator.evaluate( + appealable, tool_name, tool_input, justifications=justifications + ) + failures = [v for v in verdicts if not v.passed] + + if not failures: + # Cache the approval so the retried tool call passes + cache_key = self.cache.make_key(tool_name, tool_input) + self.cache.approve(cache_key) + return AppealResult(passed=True, reason="Appeal accepted — you may retry the tool call.") + + error_lines = ["Appeal denied. 
The following checks still fail:\n"] + for f in failures: + error_lines.append(f"- **{f.requirement_id}**: {f.explanation}") + + return AppealResult(passed=False, reason="\n".join(error_lines)) diff --git a/src/deepwork/tool_requirements/evaluator.py b/src/deepwork/tool_requirements/evaluator.py new file mode 100644 index 00000000..202c4548 --- /dev/null +++ b/src/deepwork/tool_requirements/evaluator.py @@ -0,0 +1,256 @@ +"""LLM-based requirement evaluation for tool calls. + +Provides an abstract interface for evaluating RFC 2119 requirements +against tool calls, with a concrete Haiku subprocess implementation. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any + +from deepwork.tool_requirements.config import Requirement + +logger = logging.getLogger("deepwork.tool_requirements") + +# Maximum characters of tool_input to include in the prompt +_MAX_INPUT_CHARS = 8000 + + +@dataclass +class RequirementVerdict: + """Result of evaluating a single requirement.""" + + requirement_id: str + passed: bool + explanation: str + + +class RequirementEvaluator(ABC): + """Abstract interface for requirement evaluation.""" + + @abstractmethod + async def evaluate( + self, + requirements: dict[str, Requirement], + tool_name: str, + tool_input: dict[str, Any], + justifications: dict[str, str] | None = None, + ) -> list[RequirementVerdict]: + """Evaluate requirements against a tool call. + + Args: + requirements: Requirements to check (id -> Requirement). + tool_name: Normalized tool name. + tool_input: Tool call parameters. + justifications: Optional justifications for appealed requirements. + + Returns: + List of verdicts, one per requirement. + """ + ... 
+ + +class HaikuSubprocessEvaluator(RequirementEvaluator): + """Evaluates requirements using Claude Code subprocess with Haiku.""" + + async def evaluate( + self, + requirements: dict[str, Requirement], + tool_name: str, + tool_input: dict[str, Any], + justifications: dict[str, str] | None = None, + ) -> list[RequirementVerdict]: + if not requirements: + return [] + + prompt = _build_prompt(requirements, tool_name, tool_input, justifications) + raw_result = await self._call_haiku(prompt) + return _parse_result(raw_result, requirements) + + async def _call_haiku(self, prompt: str) -> str: + """Call Claude Code in print mode with Haiku model.""" + proc = await asyncio.create_subprocess_exec( + "claude", + "--model", + "haiku", + "--output-format", + "stream-json", + "-p", + prompt, + "--max-turns", + "1", + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await proc.communicate() + + if proc.returncode != 0: + error_msg = stderr.decode("utf-8", errors="replace").strip() + raise RuntimeError(f"Haiku subprocess failed (exit {proc.returncode}): {error_msg}") + + return stdout.decode("utf-8", errors="replace") + + +def _build_prompt( + requirements: dict[str, Requirement], + tool_name: str, + tool_input: dict[str, Any], + justifications: dict[str, str] | None = None, +) -> str: + """Build the evaluation prompt for Haiku.""" + input_str = json.dumps(tool_input, indent=2, default=str) + if len(input_str) > _MAX_INPUT_CHARS: + half = _MAX_INPUT_CHARS // 2 + input_str = input_str[:half] + "\n... 
[truncated] ...\n" + input_str[-half:] + + req_lines = [] + for req_id, req in requirements.items(): + req_lines.append(f"- {req_id}: {req.rule}") + + parts = [ + "You are evaluating whether a tool call complies with a set of requirements.", + "", + f"Tool: {tool_name}", + f"Tool Input:\n```json\n{input_str}\n```", + "", + "Requirements to check:", + *req_lines, + ] + + if justifications: + parts.append("") + parts.append("The agent has provided justifications for why certain requirements should pass:") + for req_id, justification in justifications.items(): + parts.append(f"- {req_id}: {justification}") + + parts.extend([ + "", + "For each requirement, determine if the tool call PASSES or FAILS.", + "Consider RFC 2119 keywords:", + "- MUST/MUST NOT: strict pass/fail — any violation is a failure", + "- SHOULD/SHOULD NOT: fail only if the violation is clear and easily avoidable", + "- MAY: always pass (informational only)", + "", + "If justifications are provided, consider them when making your determination.", + "A good justification can override a SHOULD violation but not a MUST violation.", + "", + "Return ONLY a JSON array with no other text:", + '[{"requirement_id": "...", "passed": true/false, "explanation": "..."}]', + ]) + + return "\n".join(parts) + + +def _parse_result( + raw_output: str, + requirements: dict[str, Requirement], +) -> list[RequirementVerdict]: + """Parse Haiku's streaming JSON output into verdicts.""" + # stream-json format: one JSON object per line + # We need to find the result message content + result_text = "" + for line in raw_output.strip().splitlines(): + line = line.strip() + if not line: + continue + try: + obj = json.loads(line) + if not isinstance(obj, dict): + # Raw JSON array — use as-is + result_text = line + continue + # stream-json emits objects with "type" field + if obj.get("type") == "result": + result_text = obj.get("result", "") + break + # Also check for assistant message content + if obj.get("type") == "content": + 
content = obj.get("content", "") + if isinstance(content, str): + result_text = content + except json.JSONDecodeError: + continue + + if not result_text: + # Fall back to trying to extract JSON from the raw output + result_text = raw_output + + # Extract JSON array from the text + verdicts = _extract_json_array(result_text) + if verdicts is None: + # If we can't parse, fail-closed: all requirements fail + logger.warning("Could not parse evaluator output, failing all requirements") + return [ + RequirementVerdict( + requirement_id=req_id, + passed=False, + explanation="Failed to parse evaluator response", + ) + for req_id in requirements + ] + + result: list[RequirementVerdict] = [] + seen: set[str] = set() + for item in verdicts: + req_id = item.get("requirement_id", "") + if req_id not in requirements or req_id in seen: + continue + seen.add(req_id) + result.append( + RequirementVerdict( + requirement_id=req_id, + passed=bool(item.get("passed", False)), + explanation=str(item.get("explanation", "")), + ) + ) + + # Any requirements not in the response fail-closed + for req_id in requirements: + if req_id not in seen: + result.append( + RequirementVerdict( + requirement_id=req_id, + passed=False, + explanation="Requirement not evaluated by the evaluator", + ) + ) + + return result + + +def _extract_json_array(text: str) -> list[dict[str, Any]] | None: + """Extract a JSON array from text that may contain surrounding prose.""" + # Try direct parse first + try: + parsed = json.loads(text.strip()) + if isinstance(parsed, list): + return parsed + except json.JSONDecodeError: + pass + + # Try to find [...] 
in the text + start = text.find("[") + if start == -1: + return None + + # Find matching closing bracket + depth = 0 + for i in range(start, len(text)): + if text[i] == "[": + depth += 1 + elif text[i] == "]": + depth -= 1 + if depth == 0: + try: + parsed = json.loads(text[start : i + 1]) + if isinstance(parsed, list): + return parsed + except json.JSONDecodeError: + return None + + return None diff --git a/src/deepwork/tool_requirements/matcher.py b/src/deepwork/tool_requirements/matcher.py new file mode 100644 index 00000000..54eea3c7 --- /dev/null +++ b/src/deepwork/tool_requirements/matcher.py @@ -0,0 +1,78 @@ +"""Matching logic for tool requirements policies. + +Determines which policies apply to a given tool call based on +tool name and optional parameter-level regex filtering. +""" + +from __future__ import annotations + +import re +from typing import Any + +from deepwork.tool_requirements.config import Requirement, ToolPolicy + + +def match_policies( + tool_name: str, + tool_input: dict[str, Any], + policies: list[ToolPolicy], +) -> list[ToolPolicy]: + """Find all policies that match a tool call. + + Matching is two-step: + 1. Tool name must be in the policy's tools list + 2. If the policy has a match dict, at least one parameter regex must match + + Args: + tool_name: Normalized tool name (e.g., "shell", "write_file"). + tool_input: The tool's input parameters. + policies: All loaded policies. + + Returns: + List of matching policies. + """ + matched: list[ToolPolicy] = [] + for policy in policies: + if tool_name not in policy.tools: + continue + if policy.match and not _param_matches(tool_input, policy.match): + continue + matched.append(policy) + return matched + + +def merge_requirements(policies: list[ToolPolicy]) -> dict[str, Requirement]: + """Merge requirements from all matching policies. + + If the same requirement key appears in multiple policies, the first + occurrence wins. + + Args: + policies: List of matched policies. 
+ + Returns: + Merged requirements dict. + """ + merged: dict[str, Requirement] = {} + for policy in policies: + for req_id, req in policy.requirements.items(): + if req_id not in merged: + merged[req_id] = req + return merged + + +def _param_matches(tool_input: dict[str, Any], match: dict[str, str]) -> bool: + """Check if any parameter regex matches the tool input. + + Returns True if at least one match entry matches a tool input value. + """ + for param_name, pattern in match.items(): + value = tool_input.get(param_name) + if value is None: + continue + try: + if re.search(pattern, str(value)): + return True + except re.error: + continue + return False diff --git a/src/deepwork/tool_requirements/sidecar.py b/src/deepwork/tool_requirements/sidecar.py new file mode 100644 index 00000000..6aad5a96 --- /dev/null +++ b/src/deepwork/tool_requirements/sidecar.py @@ -0,0 +1,274 @@ +"""HTTP sidecar server for tool requirements enforcement. + +Runs as a daemon thread alongside the MCP server, providing HTTP endpoints +that the PreToolUse hook calls to check/appeal tool requirements. + +Uses stdlib http.server — no external dependencies. 
+""" + +from __future__ import annotations + +import atexit +import json +import logging +import os +import threading +from http.server import BaseHTTPRequestHandler, HTTPServer +from pathlib import Path +from typing import Any + +from deepwork.tool_requirements.cache import ToolRequirementsCache +from deepwork.tool_requirements.engine import ToolRequirementsEngine +from deepwork.tool_requirements.evaluator import HaikuSubprocessEvaluator + +logger = logging.getLogger("deepwork.tool_requirements") + + +class _SidecarHandler(BaseHTTPRequestHandler): + """HTTP request handler for sidecar endpoints.""" + + engine: ToolRequirementsEngine # Set via partial/class attr + + def do_POST(self) -> None: + try: + content_length = int(self.headers.get("Content-Length", 0)) + body = self.rfile.read(content_length) + data = json.loads(body) if body else {} + except (json.JSONDecodeError, ValueError): + self._respond(400, {"error": "Invalid JSON body"}) + return + + if self.path == "/check": + self._handle_check(data) + elif self.path == "/appeal": + self._handle_appeal(data) + else: + self._respond(404, {"error": f"Unknown endpoint: {self.path}"}) + + def _handle_check(self, data: dict[str, Any]) -> None: + tool_name = data.get("tool_name", "") + tool_input = data.get("tool_input", {}) + + if not tool_name: + self._respond(400, {"error": "Missing tool_name"}) + return + + import asyncio + + try: + loop = asyncio.new_event_loop() + result = loop.run_until_complete( + self.engine.check(tool_name, tool_input) + ) + loop.close() + except Exception as e: + logger.exception("Error checking tool requirements") + self._respond(500, { + "decision": "deny", + "reason": f"Tool requirements evaluation error: {e}", + }) + return + + self._respond(200, { + "decision": "allow" if result.allowed else "deny", + "reason": result.reason, + "failed_checks": result.failed_checks, + }) + + def _handle_appeal(self, data: dict[str, Any]) -> None: + tool_name = data.get("tool_name", "") + tool_input = 
data.get("tool_input", {}) + justifications = data.get("policy_justification", {}) + + if not tool_name: + self._respond(400, {"error": "Missing tool_name"}) + return + if not justifications: + self._respond(400, {"error": "Missing policy_justification"}) + return + + import asyncio + + try: + loop = asyncio.new_event_loop() + result = loop.run_until_complete( + self.engine.appeal(tool_name, tool_input, justifications) + ) + loop.close() + except Exception as e: + logger.exception("Error processing appeal") + self._respond(500, { + "passed": False, + "reason": f"Appeal evaluation error: {e}", + }) + return + + self._respond(200, { + "passed": result.passed, + "reason": result.reason, + "no_exception_blocked": result.no_exception_blocked, + }) + + def _respond(self, status: int, body: dict[str, Any]) -> None: + self.send_response(status) + self.send_header("Content-Type", "application/json") + response = json.dumps(body).encode("utf-8") + self.send_header("Content-Length", str(len(response))) + self.end_headers() + self.wfile.write(response) + + def log_message(self, format: str, *args: Any) -> None: + """Suppress default stderr logging — use our logger instead.""" + logger.debug("Sidecar: %s", format % args) + + +class SidecarInfo: + """Information about a running sidecar server.""" + + def __init__(self, pid: int, port: int, port_file: Path) -> None: + self.pid = pid + self.port = port + self.port_file = port_file + + +def start_sidecar(project_root: Path) -> SidecarInfo: + """Start the sidecar HTTP server in a daemon thread. + + Writes a port file to .deepwork/tmp/tool_req_sidecar/.json + so the hook can discover and connect. + + Args: + project_root: Project root directory. + + Returns: + SidecarInfo with pid, port, and port file path. 
+ """ + # Build the engine + engine = ToolRequirementsEngine( + project_root=project_root, + evaluator=HaikuSubprocessEvaluator(), + cache=ToolRequirementsCache(), + ) + + # Create handler class with engine reference + handler_class = type( + "_BoundSidecarHandler", + (_SidecarHandler,), + {"engine": engine}, + ) + + # Bind to random port on localhost + server = HTTPServer(("127.0.0.1", 0), handler_class) + port = server.server_address[1] + pid = os.getpid() + + # Write port file + sidecar_dir = project_root / ".deepwork" / "tmp" / "tool_req_sidecar" + sidecar_dir.mkdir(parents=True, exist_ok=True) + port_file = sidecar_dir / f"{pid}.json" + port_file.write_text(json.dumps({"pid": pid, "port": port})) + + logger.info("Tool requirements sidecar started on port %d (PID %d)", port, pid) + + # Start server in daemon thread + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + + # Cleanup on exit + def cleanup() -> None: + server.shutdown() + port_file.unlink(missing_ok=True) + # Clean up any session mapping files that point to this PID + for f in sidecar_dir.glob("session_*.json"): + try: + data = json.loads(f.read_text()) + if data.get("pid") == pid: + f.unlink(missing_ok=True) + except (json.JSONDecodeError, OSError): + pass + + atexit.register(cleanup) + + return SidecarInfo(pid=pid, port=port, port_file=port_file) + + +def register_session(project_root: Path, session_id: str) -> None: + """Register a session-to-sidecar mapping. + + Called when the MCP server receives its first tool call with a session_id. + Creates a session_.json file pointing to this PID's sidecar. + + Args: + project_root: Project root directory. + session_id: The Claude Code session ID. 
+ """ + pid = os.getpid() + sidecar_dir = project_root / ".deepwork" / "tmp" / "tool_req_sidecar" + + # Read our own port file + port_file = sidecar_dir / f"{pid}.json" + if not port_file.exists(): + return + + try: + data = json.loads(port_file.read_text()) + except (json.JSONDecodeError, OSError): + return + + # Write session mapping + session_file = sidecar_dir / f"session_{session_id}.json" + session_file.write_text(json.dumps({"pid": pid, "port": data["port"]})) + + +def discover_sidecar(project_root: Path, session_id: str) -> dict[str, Any] | None: + """Discover the sidecar server for a given session. + + Looks for a session-specific mapping first, then falls back to + scanning PID-keyed port files for live processes. + + Args: + project_root: Project root directory. + session_id: The Claude Code session ID. + + Returns: + Dict with "port" and "pid" keys, or None if no sidecar found. + """ + sidecar_dir = Path(project_root) / ".deepwork" / "tmp" / "tool_req_sidecar" + if not sidecar_dir.is_dir(): + return None + + # Try session-specific mapping first + if session_id: + session_file = sidecar_dir / f"session_{session_id}.json" + info = _read_and_validate_port_file(session_file) + if info is not None: + return info + + # Fall back to scanning PID files + for port_file in sidecar_dir.glob("[0-9]*.json"): + info = _read_and_validate_port_file(port_file) + if info is not None: + return info + + return None + + +def _read_and_validate_port_file(port_file: Path) -> dict[str, Any] | None: + """Read a port file and check if the PID is alive.""" + if not port_file.exists(): + return None + + try: + data = json.loads(port_file.read_text()) + pid = data.get("pid") + port = data.get("port") + if not pid or not port: + return None + + # Check if PID is alive + os.kill(pid, 0) + return {"pid": pid, "port": port} + except (json.JSONDecodeError, OSError): + # PID is dead or file is corrupt — clean up + port_file.unlink(missing_ok=True) + return None diff --git 
a/tests/unit/test_tool_requirements_hook.py b/tests/unit/test_tool_requirements_hook.py new file mode 100644 index 00000000..f8ab7024 --- /dev/null +++ b/tests/unit/test_tool_requirements_hook.py @@ -0,0 +1,126 @@ +"""Tests for the tool requirements PreToolUse hook.""" + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from deepwork.hooks.wrapper import HookInput, HookOutput, NormalizedEvent, Platform + + +class TestToolRequirementsHook: + def _make_input( + self, + tool_name: str = "shell", + tool_input: dict | None = None, + raw_tool_name: str = "Bash", + session_id: str = "test-session", + cwd: str = "/test/project", + ) -> HookInput: + return HookInput( + platform=Platform.CLAUDE, + event=NormalizedEvent.BEFORE_TOOL, + session_id=session_id, + cwd=cwd, + tool_name=tool_name, + tool_input=tool_input or {"command": "ls"}, + raw_input={ + "hook_event_name": "PreToolUse", + "tool_name": raw_tool_name, + "session_id": session_id, + }, + ) + + def test_skips_non_before_tool_events(self) -> None: + from deepwork.hooks.tool_requirements import tool_requirements_hook + + hook_input = HookInput( + platform=Platform.CLAUDE, + event=NormalizedEvent.AFTER_TOOL, + tool_name="shell", + ) + result = tool_requirements_hook(hook_input) + assert result.decision == "" + + def test_loop_prevention_skips_appeal_tool(self) -> None: + from deepwork.hooks.tool_requirements import tool_requirements_hook + + hook_input = self._make_input( + raw_tool_name="mcp__plugin_deepwork_deepwork__appeal_tool_requirement" + ) + result = tool_requirements_hook(hook_input) + assert result.decision == "" + + def test_loop_prevention_skips_all_appeal_prefixes(self) -> None: + from deepwork.hooks.tool_requirements import tool_requirements_hook + + for prefix in [ + "mcp__deepwork__appeal_tool_requirement", + "mcp__deepwork-dev__appeal_tool_requirement", + "mcp__plugin_deepwork_deepwork__appeal_tool_requirement", + ]: + hook_input = self._make_input(raw_tool_name=prefix) + 
result = tool_requirements_hook(hook_input) + assert result.decision == "", f"Failed for {prefix}" + + @patch("deepwork.hooks.tool_requirements.discover_sidecar") + def test_fail_closed_when_no_sidecar(self, mock_discover: MagicMock) -> None: + from deepwork.hooks.tool_requirements import tool_requirements_hook + + mock_discover.return_value = None + hook_input = self._make_input() + result = tool_requirements_hook(hook_input) + + assert "permissionDecision" in str(result.raw_output) + assert result.raw_output["hookSpecificOutput"]["permissionDecision"] == "deny" + assert "not running" in result.raw_output["hookSpecificOutput"]["permissionDecisionReason"] + + @patch("deepwork.hooks.tool_requirements._http_post") + @patch("deepwork.hooks.tool_requirements.discover_sidecar") + def test_allow_on_sidecar_allow( + self, mock_discover: MagicMock, mock_post: MagicMock + ) -> None: + from deepwork.hooks.tool_requirements import tool_requirements_hook + + mock_discover.return_value = {"pid": 123, "port": 9999} + mock_post.return_value = {"decision": "allow", "reason": "OK"} + + hook_input = self._make_input() + result = tool_requirements_hook(hook_input) + assert result.decision == "" + assert result.raw_output == {} + + @patch("deepwork.hooks.tool_requirements._http_post") + @patch("deepwork.hooks.tool_requirements.discover_sidecar") + def test_deny_on_sidecar_deny( + self, mock_discover: MagicMock, mock_post: MagicMock + ) -> None: + from deepwork.hooks.tool_requirements import tool_requirements_hook + + mock_discover.return_value = {"pid": 123, "port": 9999} + mock_post.return_value = { + "decision": "deny", + "reason": "Policy violation: r1", + } + + hook_input = self._make_input() + result = tool_requirements_hook(hook_input) + + assert result.raw_output["hookSpecificOutput"]["permissionDecision"] == "deny" + assert "Policy violation" in result.raw_output["hookSpecificOutput"]["permissionDecisionReason"] + + @patch("deepwork.hooks.tool_requirements._http_post") + 
@patch("deepwork.hooks.tool_requirements.discover_sidecar") + def test_fail_closed_on_connection_error( + self, mock_discover: MagicMock, mock_post: MagicMock + ) -> None: + from deepwork.hooks.tool_requirements import tool_requirements_hook + + mock_discover.return_value = {"pid": 123, "port": 9999} + mock_post.side_effect = ConnectionRefusedError("Connection refused") + + hook_input = self._make_input() + result = tool_requirements_hook(hook_input) + + assert result.raw_output["hookSpecificOutput"]["permissionDecision"] == "deny" + assert "Failed to reach" in result.raw_output["hookSpecificOutput"]["permissionDecisionReason"] diff --git a/tests/unit/tool_requirements/__init__.py b/tests/unit/tool_requirements/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/tool_requirements/test_cache.py b/tests/unit/tool_requirements/test_cache.py new file mode 100644 index 00000000..57c4d64b --- /dev/null +++ b/tests/unit/tool_requirements/test_cache.py @@ -0,0 +1,56 @@ +"""Tests for tool requirements TTL cache.""" + +import time +from unittest.mock import patch + +from deepwork.tool_requirements.cache import ToolRequirementsCache + + +class TestToolRequirementsCache: + def test_make_key_deterministic(self) -> None: + cache = ToolRequirementsCache() + k1 = cache.make_key("shell", {"command": "ls"}) + k2 = cache.make_key("shell", {"command": "ls"}) + assert k1 == k2 + + def test_make_key_different_for_different_input(self) -> None: + cache = ToolRequirementsCache() + k1 = cache.make_key("shell", {"command": "ls"}) + k2 = cache.make_key("shell", {"command": "rm"}) + assert k1 != k2 + + def test_make_key_different_for_different_tool(self) -> None: + cache = ToolRequirementsCache() + k1 = cache.make_key("shell", {"command": "ls"}) + k2 = cache.make_key("write_file", {"command": "ls"}) + assert k1 != k2 + + def test_approve_and_check(self) -> None: + cache = ToolRequirementsCache() + key = cache.make_key("shell", {"command": "ls"}) + assert not 
cache.is_approved(key) + cache.approve(key) + assert cache.is_approved(key) + + def test_ttl_expiry(self) -> None: + cache = ToolRequirementsCache(ttl_seconds=1.0) + key = cache.make_key("shell", {"command": "ls"}) + cache.approve(key) + assert cache.is_approved(key) + + with patch("deepwork.tool_requirements.cache.time") as mock_time: + mock_time.time.return_value = time.time() + 2.0 + assert not cache.is_approved(key) + + def test_clear(self) -> None: + cache = ToolRequirementsCache() + cache.approve(cache.make_key("shell", {"command": "ls"})) + assert len(cache) == 1 + cache.clear() + assert len(cache) == 0 + + def test_sorted_keys_for_consistency(self) -> None: + cache = ToolRequirementsCache() + k1 = cache.make_key("shell", {"b": "2", "a": "1"}) + k2 = cache.make_key("shell", {"a": "1", "b": "2"}) + assert k1 == k2 diff --git a/tests/unit/tool_requirements/test_config.py b/tests/unit/tool_requirements/test_config.py new file mode 100644 index 00000000..9b5d18e8 --- /dev/null +++ b/tests/unit/tool_requirements/test_config.py @@ -0,0 +1,119 @@ +"""Tests for tool requirements config parsing.""" + +from pathlib import Path + +import pytest +import yaml + +from deepwork.tool_requirements.config import ( + Requirement, + ToolPolicy, + ToolRequirementsError, + parse_policy_file, +) + + +@pytest.fixture() +def policy_dir(tmp_path: Path) -> Path: + d = tmp_path / ".deepwork" / "tool_requirements" + d.mkdir(parents=True) + return d + + +def _write_policy(path: Path, data: dict) -> Path: + path.write_text(yaml.safe_dump(data, sort_keys=False)) + return path + + +class TestParsePolicy: + def test_basic_policy(self, policy_dir: Path) -> None: + data = { + "summary": "Test policy", + "tools": ["shell"], + "requirements": { + "no-rm-rf": { + "rule": "MUST NOT use rm -rf /", + "no_exception": True, + }, + "prefer-safe": { + "rule": "SHOULD prefer safe alternatives", + }, + }, + } + path = _write_policy(policy_dir / "test.yml", data) + policy = parse_policy_file(path) + + 
assert policy.name == "test" + assert policy.summary == "Test policy" + assert policy.tools == ["shell"] + assert len(policy.requirements) == 2 + assert policy.requirements["no-rm-rf"].rule == "MUST NOT use rm -rf /" + assert policy.requirements["no-rm-rf"].no_exception is True + assert policy.requirements["prefer-safe"].no_exception is False + + def test_policy_with_match(self, policy_dir: Path) -> None: + data = { + "tools": ["shell"], + "match": {"command": "rm "}, + "requirements": {"r1": {"rule": "MUST check"}}, + } + path = _write_policy(policy_dir / "match.yml", data) + policy = parse_policy_file(path) + + assert policy.match == {"command": "rm "} + + def test_policy_with_extends(self, policy_dir: Path) -> None: + data = { + "tools": ["shell"], + "extends": ["common"], + "requirements": {"r1": {"rule": "MUST check"}}, + } + path = _write_policy(policy_dir / "child.yml", data) + policy = parse_policy_file(path) + + assert policy.extends == ["common"] + + def test_missing_file_raises(self, tmp_path: Path) -> None: + with pytest.raises(ToolRequirementsError, match="File not found"): + parse_policy_file(tmp_path / "nonexistent.yml") + + def test_empty_file_raises(self, policy_dir: Path) -> None: + path = policy_dir / "empty.yml" + path.write_text("") + with pytest.raises(ToolRequirementsError, match="Empty policy"): + parse_policy_file(path) + + def test_missing_required_fields_raises(self, policy_dir: Path) -> None: + data = {"summary": "no tools or requirements"} + path = _write_policy(policy_dir / "bad.yml", data) + with pytest.raises(ToolRequirementsError, match="Schema validation failed"): + parse_policy_file(path) + + def test_invalid_requirement_structure_raises(self, policy_dir: Path) -> None: + data = { + "tools": ["shell"], + "requirements": {"r1": "just a string, not a dict"}, + } + path = _write_policy(policy_dir / "bad2.yml", data) + with pytest.raises(ToolRequirementsError, match="Schema validation failed"): + parse_policy_file(path) + + def 
test_no_exception_defaults_false(self, policy_dir: Path) -> None: + data = { + "tools": ["write_file"], + "requirements": {"r1": {"rule": "SHOULD do something"}}, + } + path = _write_policy(policy_dir / "defaults.yml", data) + policy = parse_policy_file(path) + + assert policy.requirements["r1"].no_exception is False + + def test_multiple_tools(self, policy_dir: Path) -> None: + data = { + "tools": ["shell", "write_file", "edit_file"], + "requirements": {"r1": {"rule": "MUST check"}}, + } + path = _write_policy(policy_dir / "multi.yml", data) + policy = parse_policy_file(path) + + assert policy.tools == ["shell", "write_file", "edit_file"] diff --git a/tests/unit/tool_requirements/test_discovery.py b/tests/unit/tool_requirements/test_discovery.py new file mode 100644 index 00000000..62a95d85 --- /dev/null +++ b/tests/unit/tool_requirements/test_discovery.py @@ -0,0 +1,116 @@ +"""Tests for tool requirements discovery and inheritance.""" + +from pathlib import Path + +import pytest +import yaml + +from deepwork.tool_requirements.discovery import load_all_policies + + +@pytest.fixture() +def project(tmp_path: Path) -> Path: + (tmp_path / ".deepwork" / "tool_requirements").mkdir(parents=True) + return tmp_path + + +def _write_policy(project: Path, name: str, data: dict) -> Path: + path = project / ".deepwork" / "tool_requirements" / f"{name}.yml" + path.write_text(yaml.safe_dump(data, sort_keys=False)) + return path + + +class TestLoadAllPolicies: + def test_no_policy_dir(self, tmp_path: Path) -> None: + assert load_all_policies(tmp_path) == [] + + def test_empty_dir(self, project: Path) -> None: + assert load_all_policies(project) == [] + + def test_single_policy(self, project: Path) -> None: + _write_policy(project, "bash_safety", { + "tools": ["shell"], + "requirements": {"r1": {"rule": "MUST check"}}, + }) + policies = load_all_policies(project) + assert len(policies) == 1 + assert policies[0].name == "bash_safety" + + def test_multiple_policies(self, project: 
Path) -> None: + _write_policy(project, "bash_safety", { + "tools": ["shell"], + "requirements": {"r1": {"rule": "MUST check"}}, + }) + _write_policy(project, "write_safety", { + "tools": ["write_file"], + "requirements": {"r2": {"rule": "SHOULD verify"}}, + }) + policies = load_all_policies(project) + assert len(policies) == 2 + names = {p.name for p in policies} + assert names == {"bash_safety", "write_safety"} + + def test_bad_file_skipped(self, project: Path) -> None: + _write_policy(project, "good", { + "tools": ["shell"], + "requirements": {"r1": {"rule": "MUST check"}}, + }) + # Write invalid YAML + bad = project / ".deepwork" / "tool_requirements" / "bad.yml" + bad.write_text("tools: !!invalid") + + policies = load_all_policies(project) + assert len(policies) == 1 + assert policies[0].name == "good" + + +class TestInheritance: + def test_extends_merges_requirements(self, project: Path) -> None: + _write_policy(project, "parent", { + "tools": ["shell"], + "requirements": { + "parent-req": {"rule": "MUST do parent thing"}, + }, + }) + _write_policy(project, "child", { + "tools": ["shell"], + "extends": ["parent"], + "requirements": { + "child-req": {"rule": "MUST do child thing"}, + }, + }) + policies = load_all_policies(project) + child = next(p for p in policies if p.name == "child") + assert "parent-req" in child.requirements + assert "child-req" in child.requirements + + def test_child_overrides_parent(self, project: Path) -> None: + _write_policy(project, "parent", { + "tools": ["shell"], + "requirements": { + "shared": {"rule": "Parent version", "no_exception": True}, + }, + }) + _write_policy(project, "child", { + "tools": ["shell"], + "extends": ["parent"], + "requirements": { + "shared": {"rule": "Child version"}, + }, + }) + policies = load_all_policies(project) + child = next(p for p in policies if p.name == "child") + assert child.requirements["shared"].rule == "Child version" + assert child.requirements["shared"].no_exception is False + + def 
test_unknown_parent_ignored(self, project: Path) -> None: + _write_policy(project, "child", { + "tools": ["shell"], + "extends": ["nonexistent"], + "requirements": { + "r1": {"rule": "MUST check"}, + }, + }) + policies = load_all_policies(project) + assert len(policies) == 1 + assert "r1" in policies[0].requirements diff --git a/tests/unit/tool_requirements/test_engine.py b/tests/unit/tool_requirements/test_engine.py new file mode 100644 index 00000000..00766eff --- /dev/null +++ b/tests/unit/tool_requirements/test_engine.py @@ -0,0 +1,205 @@ +"""Tests for tool requirements engine.""" + +from pathlib import Path +from unittest.mock import AsyncMock + +import pytest +import yaml + +from deepwork.tool_requirements.cache import ToolRequirementsCache +from deepwork.tool_requirements.config import Requirement +from deepwork.tool_requirements.engine import ToolRequirementsEngine +from deepwork.tool_requirements.evaluator import RequirementEvaluator, RequirementVerdict + + +class MockEvaluator(RequirementEvaluator): + """Test evaluator that returns preconfigured verdicts.""" + + def __init__(self, verdicts: list[RequirementVerdict] | None = None) -> None: + self._verdicts = verdicts or [] + self.call_count = 0 + self.last_justifications: dict[str, str] | None = None + + async def evaluate(self, requirements, tool_name, tool_input, justifications=None): + self.call_count += 1 + self.last_justifications = justifications + if self._verdicts: + return self._verdicts + # Default: all pass + return [ + RequirementVerdict(req_id, True, "OK") + for req_id in requirements + ] + + +def _setup_project(tmp_path: Path, policies: dict[str, dict]) -> Path: + policy_dir = tmp_path / ".deepwork" / "tool_requirements" + policy_dir.mkdir(parents=True) + for name, data in policies.items(): + (policy_dir / f"{name}.yml").write_text(yaml.safe_dump(data, sort_keys=False)) + return tmp_path + + +class TestEngineCheck: + @pytest.mark.asyncio() + async def test_no_policies_allows(self, tmp_path: 
Path) -> None: + engine = ToolRequirementsEngine(tmp_path, MockEvaluator()) + result = await engine.check("shell", {"command": "ls"}) + assert result.allowed is True + + @pytest.mark.asyncio() + async def test_no_matching_policies_allows(self, tmp_path: Path) -> None: + project = _setup_project(tmp_path, { + "write_rules": { + "tools": ["write_file"], + "requirements": {"r1": {"rule": "MUST check"}}, + } + }) + engine = ToolRequirementsEngine(project, MockEvaluator()) + result = await engine.check("shell", {"command": "ls"}) + assert result.allowed is True + + @pytest.mark.asyncio() + async def test_all_pass_allows_and_caches(self, tmp_path: Path) -> None: + project = _setup_project(tmp_path, { + "bash_rules": { + "tools": ["shell"], + "requirements": {"r1": {"rule": "MUST check"}}, + } + }) + evaluator = MockEvaluator() + engine = ToolRequirementsEngine(project, evaluator) + + result = await engine.check("shell", {"command": "ls"}) + assert result.allowed is True + assert evaluator.call_count == 1 + + # Second call should be cached + result2 = await engine.check("shell", {"command": "ls"}) + assert result2.allowed is True + assert evaluator.call_count == 1 # Not called again + + @pytest.mark.asyncio() + async def test_failure_denies_with_all_errors(self, tmp_path: Path) -> None: + project = _setup_project(tmp_path, { + "rules": { + "tools": ["shell"], + "requirements": { + "r1": {"rule": "MUST do A"}, + "r2": {"rule": "MUST do B"}, + }, + } + }) + evaluator = MockEvaluator(verdicts=[ + RequirementVerdict("r1", False, "Failed A"), + RequirementVerdict("r2", False, "Failed B"), + ]) + engine = ToolRequirementsEngine(project, evaluator) + + result = await engine.check("shell", {"command": "bad"}) + assert result.allowed is False + assert "r1" in result.reason + assert "r2" in result.reason + assert "Failed A" in result.reason + assert "Failed B" in result.reason + assert set(result.failed_checks) == {"r1", "r2"} + + @pytest.mark.asyncio() + async def 
test_no_exception_label_in_error(self, tmp_path: Path) -> None: + project = _setup_project(tmp_path, { + "rules": { + "tools": ["shell"], + "requirements": { + "r1": {"rule": "MUST check", "no_exception": True}, + }, + } + }) + evaluator = MockEvaluator(verdicts=[ + RequirementVerdict("r1", False, "Blocked"), + ]) + engine = ToolRequirementsEngine(project, evaluator) + + result = await engine.check("shell", {"command": "bad"}) + assert "NO EXCEPTION" in result.reason + + +class TestEngineAppeal: + @pytest.mark.asyncio() + async def test_successful_appeal_caches(self, tmp_path: Path) -> None: + project = _setup_project(tmp_path, { + "rules": { + "tools": ["shell"], + "requirements": { + "r1": {"rule": "SHOULD check"}, + }, + } + }) + evaluator = MockEvaluator() # All pass by default + engine = ToolRequirementsEngine(project, evaluator) + + result = await engine.appeal( + "shell", {"command": "rm file"}, + justifications={"r1": "It's a temp file"}, + ) + assert result.passed is True + assert evaluator.last_justifications == {"r1": "It's a temp file"} + + # Should be cached now + check = await engine.check("shell", {"command": "rm file"}) + assert check.allowed is True + + @pytest.mark.asyncio() + async def test_no_exception_blocks_appeal(self, tmp_path: Path) -> None: + project = _setup_project(tmp_path, { + "rules": { + "tools": ["shell"], + "requirements": { + "r1": {"rule": "MUST NOT", "no_exception": True}, + }, + } + }) + engine = ToolRequirementsEngine(project, MockEvaluator()) + + result = await engine.appeal( + "shell", {"command": "bad"}, + justifications={"r1": "Please?"}, + ) + assert result.passed is False + assert "no_exception" in result.reason.lower() + assert "r1" in result.no_exception_blocked + + @pytest.mark.asyncio() + async def test_empty_justifications_rejected(self, tmp_path: Path) -> None: + project = _setup_project(tmp_path, { + "rules": { + "tools": ["shell"], + "requirements": {"r1": {"rule": "MUST check"}}, + } + }) + engine = 
ToolRequirementsEngine(project, MockEvaluator()) + + result = await engine.appeal("shell", {"command": "bad"}, justifications={}) + assert result.passed is False + + @pytest.mark.asyncio() + async def test_failed_appeal_not_cached(self, tmp_path: Path) -> None: + project = _setup_project(tmp_path, { + "rules": { + "tools": ["shell"], + "requirements": {"r1": {"rule": "MUST check"}}, + } + }) + evaluator = MockEvaluator(verdicts=[ + RequirementVerdict("r1", False, "Still bad"), + ]) + engine = ToolRequirementsEngine(project, evaluator) + + result = await engine.appeal( + "shell", {"command": "bad"}, + justifications={"r1": "Please"}, + ) + assert result.passed is False + + # Should NOT be cached + cache_key = engine.cache.make_key("shell", {"command": "bad"}) + assert not engine.cache.is_approved(cache_key) diff --git a/tests/unit/tool_requirements/test_evaluator.py b/tests/unit/tool_requirements/test_evaluator.py new file mode 100644 index 00000000..33f3ffc3 --- /dev/null +++ b/tests/unit/tool_requirements/test_evaluator.py @@ -0,0 +1,102 @@ +"""Tests for tool requirements LLM evaluator.""" + +import json + +import pytest + +from deepwork.tool_requirements.config import Requirement +from deepwork.tool_requirements.evaluator import ( + RequirementVerdict, + _build_prompt, + _extract_json_array, + _parse_result, +) + + +class TestBuildPrompt: + def test_includes_tool_info(self) -> None: + reqs = {"r1": Requirement(rule="MUST check")} + prompt = _build_prompt(reqs, "shell", {"command": "rm -rf /"}) + assert "shell" in prompt + assert "rm -rf /" in prompt + assert "r1" in prompt + assert "MUST check" in prompt + + def test_includes_justifications_when_present(self) -> None: + reqs = {"r1": Requirement(rule="MUST check")} + prompt = _build_prompt( + reqs, "shell", {"command": "ls"}, + justifications={"r1": "This is safe because..."}, + ) + assert "This is safe because..." 
in prompt + assert "justification" in prompt.lower() + + def test_truncates_large_input(self) -> None: + large_input = {"content": "x" * 20000} + reqs = {"r1": Requirement(rule="MUST check")} + prompt = _build_prompt(reqs, "write_file", large_input) + assert "[truncated]" in prompt + + +class TestExtractJsonArray: + def test_direct_array(self) -> None: + text = '[{"a": 1}]' + assert _extract_json_array(text) == [{"a": 1}] + + def test_array_with_surrounding_text(self) -> None: + text = 'Here is my analysis:\n[{"a": 1}]\nDone.' + assert _extract_json_array(text) == [{"a": 1}] + + def test_no_array(self) -> None: + assert _extract_json_array("no json here") is None + + def test_empty_array(self) -> None: + assert _extract_json_array("[]") == [] + + +class TestParseResult: + def test_parses_stream_json_result(self) -> None: + verdicts = [ + {"requirement_id": "r1", "passed": True, "explanation": "OK"}, + {"requirement_id": "r2", "passed": False, "explanation": "Bad"}, + ] + # Simulate stream-json output + lines = [ + json.dumps({"type": "content", "content": json.dumps(verdicts)}), + ] + raw = "\n".join(lines) + reqs = { + "r1": Requirement(rule="MUST do A"), + "r2": Requirement(rule="MUST do B"), + } + result = _parse_result(raw, reqs) + assert len(result) == 2 + assert result[0].requirement_id == "r1" + assert result[0].passed is True + assert result[1].requirement_id == "r2" + assert result[1].passed is False + + def test_missing_requirement_fails_closed(self) -> None: + verdicts = [{"requirement_id": "r1", "passed": True, "explanation": "OK"}] + raw = json.dumps(verdicts) + reqs = { + "r1": Requirement(rule="Rule 1"), + "r2": Requirement(rule="Rule 2"), + } + result = _parse_result(raw, reqs) + r2 = next(v for v in result if v.requirement_id == "r2") + assert r2.passed is False + assert "not evaluated" in r2.explanation + + def test_unparseable_output_fails_all(self) -> None: + reqs = {"r1": Requirement(rule="Rule 1")} + result = _parse_result("garbage output", 
reqs) + assert len(result) == 1 + assert result[0].passed is False + + def test_result_type_message(self) -> None: + verdicts = [{"requirement_id": "r1", "passed": True, "explanation": "OK"}] + raw = json.dumps({"type": "result", "result": json.dumps(verdicts)}) + reqs = {"r1": Requirement(rule="Rule 1")} + result = _parse_result(raw, reqs) + assert result[0].passed is True diff --git a/tests/unit/tool_requirements/test_matcher.py b/tests/unit/tool_requirements/test_matcher.py new file mode 100644 index 00000000..62392af8 --- /dev/null +++ b/tests/unit/tool_requirements/test_matcher.py @@ -0,0 +1,81 @@ +"""Tests for tool requirements policy matching.""" + +from pathlib import Path + +import pytest + +from deepwork.tool_requirements.config import Requirement, ToolPolicy +from deepwork.tool_requirements.matcher import match_policies, merge_requirements + + +def _policy( + name: str = "test", + tools: list[str] | None = None, + match: dict[str, str] | None = None, + requirements: dict[str, Requirement] | None = None, +) -> ToolPolicy: + return ToolPolicy( + name=name, + source_path=Path(f"/fake/{name}.yml"), + tools=tools or ["shell"], + match=match or {}, + requirements=requirements or {"r1": Requirement(rule="MUST check")}, + ) + + +class TestMatchPolicies: + def test_matches_by_tool_name(self) -> None: + policies = [_policy(tools=["shell"]), _policy(name="other", tools=["write_file"])] + result = match_policies("shell", {"command": "ls"}, policies) + assert len(result) == 1 + assert result[0].name == "test" + + def test_no_match_returns_empty(self) -> None: + policies = [_policy(tools=["write_file"])] + assert match_policies("shell", {"command": "ls"}, policies) == [] + + def test_match_with_param_regex(self) -> None: + policies = [_policy(tools=["shell"], match={"command": "rm "})] + assert len(match_policies("shell", {"command": "rm -rf /"}, policies)) == 1 + assert len(match_policies("shell", {"command": "ls -la"}, policies)) == 0 + + def 
test_match_requires_at_least_one_param_hit(self) -> None: + policies = [_policy(tools=["shell"], match={"command": "rm", "extra": "foo"})] + # command matches even though extra doesn't + assert len(match_policies("shell", {"command": "rm file"}, policies)) == 1 + + def test_match_no_params_present(self) -> None: + policies = [_policy(tools=["shell"], match={"command": "rm"})] + assert len(match_policies("shell", {}, policies)) == 0 + + def test_no_match_dict_means_always_match(self) -> None: + policies = [_policy(tools=["shell"], match={})] + assert len(match_policies("shell", {"command": "anything"}, policies)) == 1 + + def test_invalid_regex_skipped(self) -> None: + policies = [_policy(tools=["shell"], match={"command": "[invalid"})] + # Invalid regex is skipped, no match + assert len(match_policies("shell", {"command": "test"}, policies)) == 0 + + def test_multiple_tools_in_policy(self) -> None: + policies = [_policy(tools=["shell", "write_file"])] + assert len(match_policies("shell", {}, policies)) == 1 + assert len(match_policies("write_file", {}, policies)) == 1 + assert len(match_policies("read_file", {}, policies)) == 0 + + +class TestMergeRequirements: + def test_merge_distinct(self) -> None: + p1 = _policy(name="a", requirements={"r1": Requirement(rule="Rule 1")}) + p2 = _policy(name="b", requirements={"r2": Requirement(rule="Rule 2")}) + merged = merge_requirements([p1, p2]) + assert set(merged.keys()) == {"r1", "r2"} + + def test_first_wins_on_conflict(self) -> None: + p1 = _policy(name="a", requirements={"r1": Requirement(rule="First")}) + p2 = _policy(name="b", requirements={"r1": Requirement(rule="Second")}) + merged = merge_requirements([p1, p2]) + assert merged["r1"].rule == "First" + + def test_empty_policies(self) -> None: + assert merge_requirements([]) == {} From 291ab6ca50abf11e115ca9d6c213100863ded893 Mon Sep 17 00:00:00 2001 From: Noah Horton Date: Wed, 15 Apr 2026 12:36:07 -0600 Subject: [PATCH 2/7] fix: address code review findings - 
engine.py: rename loop variable `f` to `failure` for clarity - sidecar.py: move `import asyncio` to module level, fix event loop leak with try/finally, fix inaccurate comment, add session_id validation - evaluator.py: change `continue` to `break` on raw JSON array parse, filter non-dict items in _extract_json_array - discovery.py: fix double-name warning message, remove dead code - test_engine.py: add type hints to MockEvaluator.evaluate, remove unused imports - test_tool_requirements_hook.py: remove redundant test Co-Authored-By: Claude Opus 4.6 (1M context) --- CHANGELOG.md | 4 + src/deepwork/hooks/tool_requirements.py | 16 +- src/deepwork/jobs/mcp/server.py | 20 +- src/deepwork/tool_requirements/discovery.py | 8 +- src/deepwork/tool_requirements/engine.py | 17 +- src/deepwork/tool_requirements/evaluator.py | 42 ++-- src/deepwork/tool_requirements/sidecar.py | 84 +++++--- tests/unit/test_tool_requirements_hook.py | 31 +-- tests/unit/tool_requirements/test_config.py | 2 - .../unit/tool_requirements/test_discovery.py | 124 ++++++++---- tests/unit/tool_requirements/test_engine.py | 189 +++++++++++------- .../unit/tool_requirements/test_evaluator.py | 7 +- tests/unit/tool_requirements/test_matcher.py | 2 - 13 files changed, 316 insertions(+), 230 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 08e5b807..2e559132 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Tool requirements policy enforcement system: define RFC 2119-style rules in `.deepwork/tool_requirements/*.yml` to govern AI agent tool calls, with LLM-based semantic evaluation, appeal mechanism, and 1-hour TTL caching + ### Changed - Renamed default reviewer agent from `reviewer` to `deepwork:reviewer` (plugin-namespaced) in review instructions output @@ -16,6 +18,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Settings schema missing `if`, 
`asyncRewake`, `once`, `shell` fields on hook definitions and missing `StopFailure`, `PermissionDenied`, `TaskCreated`, `FileChanged`, `CwdChanged` hook event types + ### Removed ## [0.13.7] - 2026-04-14 diff --git a/src/deepwork/hooks/tool_requirements.py b/src/deepwork/hooks/tool_requirements.py index b34066ac..20c5ba8b 100644 --- a/src/deepwork/hooks/tool_requirements.py +++ b/src/deepwork/hooks/tool_requirements.py @@ -56,12 +56,16 @@ def tool_requirements_hook(hook_input: HookInput) -> HookOutput: # Send check request to sidecar try: - response = _http_post(sidecar["port"], "/check", { - "tool_name": hook_input.tool_name, - "tool_input": hook_input.tool_input, - "raw_tool_name": raw_tool, - "session_id": session_id, - }) + response = _http_post( + sidecar["port"], + "/check", + { + "tool_name": hook_input.tool_name, + "tool_input": hook_input.tool_input, + "raw_tool_name": raw_tool, + "session_id": session_id, + }, + ) except Exception as e: return _deny( f"DeepWork Tool Requirements: Failed to reach MCP server sidecar: {e}. 
" diff --git a/src/deepwork/jobs/mcp/server.py b/src/deepwork/jobs/mcp/server.py index af43fac9..e19b769b 100644 --- a/src/deepwork/jobs/mcp/server.py +++ b/src/deepwork/jobs/mcp/server.py @@ -16,7 +16,7 @@ import logging import shutil from pathlib import Path -from typing import Any +from typing import Any, cast from fastmcp import Context, FastMCP @@ -573,15 +573,15 @@ async def appeal_tool_requirement( import http.client import json as json_mod - conn = http.client.HTTPConnection( - "127.0.0.1", sidecar["port"], timeout=60 - ) + conn = http.client.HTTPConnection("127.0.0.1", sidecar["port"], timeout=60) try: - payload = json_mod.dumps({ - "tool_name": tool_name, - "tool_input": tool_input, - "policy_justification": policy_justification, - }).encode("utf-8") + payload = json_mod.dumps( + { + "tool_name": tool_name, + "tool_input": tool_input, + "policy_justification": policy_justification, + } + ).encode("utf-8") conn.request( "POST", "/appeal", @@ -589,7 +589,7 @@ async def appeal_tool_requirement( headers={"Content-Type": "application/json"}, ) response = conn.getresponse() - return json_mod.loads(response.read()) + return cast(dict[str, Any], json_mod.loads(response.read())) finally: conn.close() except Exception as e: diff --git a/src/deepwork/tool_requirements/discovery.py b/src/deepwork/tool_requirements/discovery.py index 99ac7511..a549d82e 100644 --- a/src/deepwork/tool_requirements/discovery.py +++ b/src/deepwork/tool_requirements/discovery.py @@ -69,8 +69,8 @@ def resolve(name: str, visited: set[str]) -> ToolPolicy: return resolved[name] if name not in by_name: - logger.warning("Policy '%s' extends unknown policy '%s'", name, name) - return by_name.get(name, ToolPolicy(name=name, source_path=Path())) + logger.warning("Policy '%s' not found", name) + return ToolPolicy(name=name, source_path=Path()) policy = by_name[name] @@ -84,9 +84,7 @@ def resolve(name: str, visited: set[str]) -> ToolPolicy: merged_requirements: dict[str, Requirement] = {} for 
parent_name in policy.extends: if parent_name not in by_name: - logger.warning( - "Policy '%s' extends unknown policy '%s'", policy.name, parent_name - ) + logger.warning("Policy '%s' extends unknown policy '%s'", policy.name, parent_name) continue parent = resolve(parent_name, visited) merged_requirements.update(parent.requirements) diff --git a/src/deepwork/tool_requirements/engine.py b/src/deepwork/tool_requirements/engine.py index 7616b304..b7c83aae 100644 --- a/src/deepwork/tool_requirements/engine.py +++ b/src/deepwork/tool_requirements/engine.py @@ -91,12 +91,12 @@ async def check(self, tool_name: str, tool_input: dict[str, Any]) -> CheckResult # Build error message with ALL failures error_lines = ["Tool call blocked by the following policy violations:\n"] - for f in failures: - req = all_requirements.get(f.requirement_id) + for failure in failures: + req = all_requirements.get(failure.requirement_id) no_exc = "" if req and req.no_exception: no_exc = " [NO EXCEPTION - cannot be appealed]" - error_lines.append(f"- **{f.requirement_id}**{no_exc}: {f.explanation}") + error_lines.append(f"- **{failure.requirement_id}**{no_exc}: {failure.explanation}") error_lines.append( "\nTo appeal, call the `appeal_tool_requirement` MCP tool with:\n" @@ -155,8 +155,7 @@ async def appeal( return AppealResult( passed=False, reason=( - "Cannot appeal no_exception requirements: " - + ", ".join(no_exception_blocked) + "Cannot appeal no_exception requirements: " + ", ".join(no_exception_blocked) ), no_exception_blocked=no_exception_blocked, ) @@ -177,10 +176,12 @@ async def appeal( # Cache the approval so the retried tool call passes cache_key = self.cache.make_key(tool_name, tool_input) self.cache.approve(cache_key) - return AppealResult(passed=True, reason="Appeal accepted — you may retry the tool call.") + return AppealResult( + passed=True, reason="Appeal accepted — you may retry the tool call." + ) error_lines = ["Appeal denied. 
The following checks still fail:\n"] - for f in failures: - error_lines.append(f"- **{f.requirement_id}**: {f.explanation}") + for failure in failures: + error_lines.append(f"- **{failure.requirement_id}**: {failure.explanation}") return AppealResult(passed=False, reason="\n".join(error_lines)) diff --git a/src/deepwork/tool_requirements/evaluator.py b/src/deepwork/tool_requirements/evaluator.py index 202c4548..ec822fdf 100644 --- a/src/deepwork/tool_requirements/evaluator.py +++ b/src/deepwork/tool_requirements/evaluator.py @@ -124,24 +124,28 @@ def _build_prompt( if justifications: parts.append("") - parts.append("The agent has provided justifications for why certain requirements should pass:") + parts.append( + "The agent has provided justifications for why certain requirements should pass:" + ) for req_id, justification in justifications.items(): parts.append(f"- {req_id}: {justification}") - parts.extend([ - "", - "For each requirement, determine if the tool call PASSES or FAILS.", - "Consider RFC 2119 keywords:", - "- MUST/MUST NOT: strict pass/fail — any violation is a failure", - "- SHOULD/SHOULD NOT: fail only if the violation is clear and easily avoidable", - "- MAY: always pass (informational only)", - "", - "If justifications are provided, consider them when making your determination.", - "A good justification can override a SHOULD violation but not a MUST violation.", - "", - "Return ONLY a JSON array with no other text:", - '[{"requirement_id": "...", "passed": true/false, "explanation": "..."}]', - ]) + parts.extend( + [ + "", + "For each requirement, determine if the tool call PASSES or FAILS.", + "Consider RFC 2119 keywords:", + "- MUST/MUST NOT: strict pass/fail — any violation is a failure", + "- SHOULD/SHOULD NOT: fail only if the violation is clear and easily avoidable", + "- MAY: always pass (informational only)", + "", + "If justifications are provided, consider them when making your determination.", + "A good justification can override a 
SHOULD violation but not a MUST violation.", + "", + "Return ONLY a JSON array with no other text:", + '[{"requirement_id": "...", "passed": true/false, "explanation": "..."}]', + ] + ) return "\n".join(parts) @@ -163,7 +167,7 @@ def _parse_result( if not isinstance(obj, dict): # Raw JSON array — use as-is result_text = line - continue + break # stream-json emits objects with "type" field if obj.get("type") == "result": result_text = obj.get("result", "") @@ -224,12 +228,12 @@ def _parse_result( def _extract_json_array(text: str) -> list[dict[str, Any]] | None: - """Extract a JSON array from text that may contain surrounding prose.""" + """Extract a JSON array of dicts from text that may contain surrounding prose.""" # Try direct parse first try: parsed = json.loads(text.strip()) if isinstance(parsed, list): - return parsed + return [item for item in parsed if isinstance(item, dict)] except json.JSONDecodeError: pass @@ -249,7 +253,7 @@ def _extract_json_array(text: str) -> list[dict[str, Any]] | None: try: parsed = json.loads(text[start : i + 1]) if isinstance(parsed, list): - return parsed + return [item for item in parsed if isinstance(item, dict)] except json.JSONDecodeError: return None diff --git a/src/deepwork/tool_requirements/sidecar.py b/src/deepwork/tool_requirements/sidecar.py index 6aad5a96..8febe878 100644 --- a/src/deepwork/tool_requirements/sidecar.py +++ b/src/deepwork/tool_requirements/sidecar.py @@ -8,10 +8,12 @@ from __future__ import annotations +import asyncio import atexit import json import logging import os +import re import threading from http.server import BaseHTTPRequestHandler, HTTPServer from pathlib import Path @@ -27,7 +29,7 @@ class _SidecarHandler(BaseHTTPRequestHandler): """HTTP request handler for sidecar endpoints.""" - engine: ToolRequirementsEngine # Set via partial/class attr + engine: ToolRequirementsEngine # Injected by start_sidecar via type() def do_POST(self) -> None: try: @@ -45,6 +47,14 @@ def do_POST(self) -> None: 
else: self._respond(404, {"error": f"Unknown endpoint: {self.path}"}) + def _run_async(self, coro: Any) -> Any: + """Run an async coroutine in a new event loop, ensuring cleanup.""" + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(coro) + finally: + loop.close() + def _handle_check(self, data: dict[str, Any]) -> None: tool_name = data.get("tool_name", "") tool_input = data.get("tool_input", {}) @@ -53,27 +63,27 @@ def _handle_check(self, data: dict[str, Any]) -> None: self._respond(400, {"error": "Missing tool_name"}) return - import asyncio - try: - loop = asyncio.new_event_loop() - result = loop.run_until_complete( - self.engine.check(tool_name, tool_input) - ) - loop.close() + result = self._run_async(self.engine.check(tool_name, tool_input)) except Exception as e: logger.exception("Error checking tool requirements") - self._respond(500, { - "decision": "deny", - "reason": f"Tool requirements evaluation error: {e}", - }) + self._respond( + 500, + { + "decision": "deny", + "reason": f"Tool requirements evaluation error: {e}", + }, + ) return - self._respond(200, { - "decision": "allow" if result.allowed else "deny", - "reason": result.reason, - "failed_checks": result.failed_checks, - }) + self._respond( + 200, + { + "decision": "allow" if result.allowed else "deny", + "reason": result.reason, + "failed_checks": result.failed_checks, + }, + ) def _handle_appeal(self, data: dict[str, Any]) -> None: tool_name = data.get("tool_name", "") @@ -87,27 +97,29 @@ def _handle_appeal(self, data: dict[str, Any]) -> None: self._respond(400, {"error": "Missing policy_justification"}) return - import asyncio - try: - loop = asyncio.new_event_loop() - result = loop.run_until_complete( + result = self._run_async( self.engine.appeal(tool_name, tool_input, justifications) ) - loop.close() except Exception as e: logger.exception("Error processing appeal") - self._respond(500, { - "passed": False, - "reason": f"Appeal evaluation error: {e}", - }) + 
self._respond( + 500, + { + "passed": False, + "reason": f"Appeal evaluation error: {e}", + }, + ) return - self._respond(200, { - "passed": result.passed, - "reason": result.reason, - "no_exception_blocked": result.no_exception_blocked, - }) + self._respond( + 200, + { + "passed": result.passed, + "reason": result.reason, + "no_exception_blocked": result.no_exception_blocked, + }, + ) def _respond(self, status: int, body: dict[str, Any]) -> None: self.send_response(status) @@ -192,6 +204,11 @@ def cleanup() -> None: return SidecarInfo(pid=pid, port=port, port_file=port_file) +def _is_safe_session_id(session_id: str) -> bool: + """Validate that session_id is safe for use in filenames.""" + return bool(re.match(r"^[a-zA-Z0-9_-]+$", session_id)) + + def register_session(project_root: Path, session_id: str) -> None: """Register a session-to-sidecar mapping. @@ -202,6 +219,9 @@ def register_session(project_root: Path, session_id: str) -> None: project_root: Project root directory. session_id: The Claude Code session ID. 
""" + if not _is_safe_session_id(session_id): + return + pid = os.getpid() sidecar_dir = project_root / ".deepwork" / "tmp" / "tool_req_sidecar" @@ -238,7 +258,7 @@ def discover_sidecar(project_root: Path, session_id: str) -> dict[str, Any] | No return None # Try session-specific mapping first - if session_id: + if session_id and _is_safe_session_id(session_id): session_file = sidecar_dir / f"session_{session_id}.json" info = _read_and_validate_port_file(session_file) if info is not None: diff --git a/tests/unit/test_tool_requirements_hook.py b/tests/unit/test_tool_requirements_hook.py index f8ab7024..7dcb4bad 100644 --- a/tests/unit/test_tool_requirements_hook.py +++ b/tests/unit/test_tool_requirements_hook.py @@ -1,11 +1,8 @@ """Tests for the tool requirements PreToolUse hook.""" -import json from unittest.mock import MagicMock, patch -import pytest - -from deepwork.hooks.wrapper import HookInput, HookOutput, NormalizedEvent, Platform +from deepwork.hooks.wrapper import HookInput, NormalizedEvent, Platform class TestToolRequirementsHook: @@ -42,15 +39,6 @@ def test_skips_non_before_tool_events(self) -> None: result = tool_requirements_hook(hook_input) assert result.decision == "" - def test_loop_prevention_skips_appeal_tool(self) -> None: - from deepwork.hooks.tool_requirements import tool_requirements_hook - - hook_input = self._make_input( - raw_tool_name="mcp__plugin_deepwork_deepwork__appeal_tool_requirement" - ) - result = tool_requirements_hook(hook_input) - assert result.decision == "" - def test_loop_prevention_skips_all_appeal_prefixes(self) -> None: from deepwork.hooks.tool_requirements import tool_requirements_hook @@ -77,9 +65,7 @@ def test_fail_closed_when_no_sidecar(self, mock_discover: MagicMock) -> None: @patch("deepwork.hooks.tool_requirements._http_post") @patch("deepwork.hooks.tool_requirements.discover_sidecar") - def test_allow_on_sidecar_allow( - self, mock_discover: MagicMock, mock_post: MagicMock - ) -> None: + def 
test_allow_on_sidecar_allow(self, mock_discover: MagicMock, mock_post: MagicMock) -> None: from deepwork.hooks.tool_requirements import tool_requirements_hook mock_discover.return_value = {"pid": 123, "port": 9999} @@ -92,9 +78,7 @@ def test_allow_on_sidecar_allow( @patch("deepwork.hooks.tool_requirements._http_post") @patch("deepwork.hooks.tool_requirements.discover_sidecar") - def test_deny_on_sidecar_deny( - self, mock_discover: MagicMock, mock_post: MagicMock - ) -> None: + def test_deny_on_sidecar_deny(self, mock_discover: MagicMock, mock_post: MagicMock) -> None: from deepwork.hooks.tool_requirements import tool_requirements_hook mock_discover.return_value = {"pid": 123, "port": 9999} @@ -107,7 +91,10 @@ def test_deny_on_sidecar_deny( result = tool_requirements_hook(hook_input) assert result.raw_output["hookSpecificOutput"]["permissionDecision"] == "deny" - assert "Policy violation" in result.raw_output["hookSpecificOutput"]["permissionDecisionReason"] + assert ( + "Policy violation" + in result.raw_output["hookSpecificOutput"]["permissionDecisionReason"] + ) @patch("deepwork.hooks.tool_requirements._http_post") @patch("deepwork.hooks.tool_requirements.discover_sidecar") @@ -123,4 +110,6 @@ def test_fail_closed_on_connection_error( result = tool_requirements_hook(hook_input) assert result.raw_output["hookSpecificOutput"]["permissionDecision"] == "deny" - assert "Failed to reach" in result.raw_output["hookSpecificOutput"]["permissionDecisionReason"] + assert ( + "Failed to reach" in result.raw_output["hookSpecificOutput"]["permissionDecisionReason"] + ) diff --git a/tests/unit/tool_requirements/test_config.py b/tests/unit/tool_requirements/test_config.py index 9b5d18e8..97d18df4 100644 --- a/tests/unit/tool_requirements/test_config.py +++ b/tests/unit/tool_requirements/test_config.py @@ -6,8 +6,6 @@ import yaml from deepwork.tool_requirements.config import ( - Requirement, - ToolPolicy, ToolRequirementsError, parse_policy_file, ) diff --git 
a/tests/unit/tool_requirements/test_discovery.py b/tests/unit/tool_requirements/test_discovery.py index 62a95d85..052ddf8e 100644 --- a/tests/unit/tool_requirements/test_discovery.py +++ b/tests/unit/tool_requirements/test_discovery.py @@ -28,33 +28,49 @@ def test_empty_dir(self, project: Path) -> None: assert load_all_policies(project) == [] def test_single_policy(self, project: Path) -> None: - _write_policy(project, "bash_safety", { - "tools": ["shell"], - "requirements": {"r1": {"rule": "MUST check"}}, - }) + _write_policy( + project, + "bash_safety", + { + "tools": ["shell"], + "requirements": {"r1": {"rule": "MUST check"}}, + }, + ) policies = load_all_policies(project) assert len(policies) == 1 assert policies[0].name == "bash_safety" def test_multiple_policies(self, project: Path) -> None: - _write_policy(project, "bash_safety", { - "tools": ["shell"], - "requirements": {"r1": {"rule": "MUST check"}}, - }) - _write_policy(project, "write_safety", { - "tools": ["write_file"], - "requirements": {"r2": {"rule": "SHOULD verify"}}, - }) + _write_policy( + project, + "bash_safety", + { + "tools": ["shell"], + "requirements": {"r1": {"rule": "MUST check"}}, + }, + ) + _write_policy( + project, + "write_safety", + { + "tools": ["write_file"], + "requirements": {"r2": {"rule": "SHOULD verify"}}, + }, + ) policies = load_all_policies(project) assert len(policies) == 2 names = {p.name for p in policies} assert names == {"bash_safety", "write_safety"} def test_bad_file_skipped(self, project: Path) -> None: - _write_policy(project, "good", { - "tools": ["shell"], - "requirements": {"r1": {"rule": "MUST check"}}, - }) + _write_policy( + project, + "good", + { + "tools": ["shell"], + "requirements": {"r1": {"rule": "MUST check"}}, + }, + ) # Write invalid YAML bad = project / ".deepwork" / "tool_requirements" / "bad.yml" bad.write_text("tools: !!invalid") @@ -66,51 +82,71 @@ def test_bad_file_skipped(self, project: Path) -> None: class TestInheritance: def 
test_extends_merges_requirements(self, project: Path) -> None: - _write_policy(project, "parent", { - "tools": ["shell"], - "requirements": { - "parent-req": {"rule": "MUST do parent thing"}, + _write_policy( + project, + "parent", + { + "tools": ["shell"], + "requirements": { + "parent-req": {"rule": "MUST do parent thing"}, + }, }, - }) - _write_policy(project, "child", { - "tools": ["shell"], - "extends": ["parent"], - "requirements": { - "child-req": {"rule": "MUST do child thing"}, + ) + _write_policy( + project, + "child", + { + "tools": ["shell"], + "extends": ["parent"], + "requirements": { + "child-req": {"rule": "MUST do child thing"}, + }, }, - }) + ) policies = load_all_policies(project) child = next(p for p in policies if p.name == "child") assert "parent-req" in child.requirements assert "child-req" in child.requirements def test_child_overrides_parent(self, project: Path) -> None: - _write_policy(project, "parent", { - "tools": ["shell"], - "requirements": { - "shared": {"rule": "Parent version", "no_exception": True}, + _write_policy( + project, + "parent", + { + "tools": ["shell"], + "requirements": { + "shared": {"rule": "Parent version", "no_exception": True}, + }, }, - }) - _write_policy(project, "child", { - "tools": ["shell"], - "extends": ["parent"], - "requirements": { - "shared": {"rule": "Child version"}, + ) + _write_policy( + project, + "child", + { + "tools": ["shell"], + "extends": ["parent"], + "requirements": { + "shared": {"rule": "Child version"}, + }, }, - }) + ) policies = load_all_policies(project) child = next(p for p in policies if p.name == "child") assert child.requirements["shared"].rule == "Child version" assert child.requirements["shared"].no_exception is False def test_unknown_parent_ignored(self, project: Path) -> None: - _write_policy(project, "child", { - "tools": ["shell"], - "extends": ["nonexistent"], - "requirements": { - "r1": {"rule": "MUST check"}, + _write_policy( + project, + "child", + { + "tools": 
["shell"], + "extends": ["nonexistent"], + "requirements": { + "r1": {"rule": "MUST check"}, + }, }, - }) + ) policies = load_all_policies(project) assert len(policies) == 1 assert "r1" in policies[0].requirements diff --git a/tests/unit/tool_requirements/test_engine.py b/tests/unit/tool_requirements/test_engine.py index 00766eff..6ef09618 100644 --- a/tests/unit/tool_requirements/test_engine.py +++ b/tests/unit/tool_requirements/test_engine.py @@ -1,12 +1,11 @@ """Tests for tool requirements engine.""" from pathlib import Path -from unittest.mock import AsyncMock +from typing import Any import pytest import yaml -from deepwork.tool_requirements.cache import ToolRequirementsCache from deepwork.tool_requirements.config import Requirement from deepwork.tool_requirements.engine import ToolRequirementsEngine from deepwork.tool_requirements.evaluator import RequirementEvaluator, RequirementVerdict @@ -20,16 +19,19 @@ def __init__(self, verdicts: list[RequirementVerdict] | None = None) -> None: self.call_count = 0 self.last_justifications: dict[str, str] | None = None - async def evaluate(self, requirements, tool_name, tool_input, justifications=None): + async def evaluate( + self, + requirements: dict[str, Requirement], + tool_name: str, + tool_input: dict[str, Any], + justifications: dict[str, str] | None = None, + ) -> list[RequirementVerdict]: self.call_count += 1 self.last_justifications = justifications if self._verdicts: return self._verdicts # Default: all pass - return [ - RequirementVerdict(req_id, True, "OK") - for req_id in requirements - ] + return [RequirementVerdict(req_id, True, "OK") for req_id in requirements] def _setup_project(tmp_path: Path, policies: dict[str, dict]) -> Path: @@ -49,24 +51,30 @@ async def test_no_policies_allows(self, tmp_path: Path) -> None: @pytest.mark.asyncio() async def test_no_matching_policies_allows(self, tmp_path: Path) -> None: - project = _setup_project(tmp_path, { - "write_rules": { - "tools": ["write_file"], - 
"requirements": {"r1": {"rule": "MUST check"}}, - } - }) + project = _setup_project( + tmp_path, + { + "write_rules": { + "tools": ["write_file"], + "requirements": {"r1": {"rule": "MUST check"}}, + } + }, + ) engine = ToolRequirementsEngine(project, MockEvaluator()) result = await engine.check("shell", {"command": "ls"}) assert result.allowed is True @pytest.mark.asyncio() async def test_all_pass_allows_and_caches(self, tmp_path: Path) -> None: - project = _setup_project(tmp_path, { - "bash_rules": { - "tools": ["shell"], - "requirements": {"r1": {"rule": "MUST check"}}, - } - }) + project = _setup_project( + tmp_path, + { + "bash_rules": { + "tools": ["shell"], + "requirements": {"r1": {"rule": "MUST check"}}, + } + }, + ) evaluator = MockEvaluator() engine = ToolRequirementsEngine(project, evaluator) @@ -81,19 +89,24 @@ async def test_all_pass_allows_and_caches(self, tmp_path: Path) -> None: @pytest.mark.asyncio() async def test_failure_denies_with_all_errors(self, tmp_path: Path) -> None: - project = _setup_project(tmp_path, { - "rules": { - "tools": ["shell"], - "requirements": { - "r1": {"rule": "MUST do A"}, - "r2": {"rule": "MUST do B"}, - }, - } - }) - evaluator = MockEvaluator(verdicts=[ - RequirementVerdict("r1", False, "Failed A"), - RequirementVerdict("r2", False, "Failed B"), - ]) + project = _setup_project( + tmp_path, + { + "rules": { + "tools": ["shell"], + "requirements": { + "r1": {"rule": "MUST do A"}, + "r2": {"rule": "MUST do B"}, + }, + } + }, + ) + evaluator = MockEvaluator( + verdicts=[ + RequirementVerdict("r1", False, "Failed A"), + RequirementVerdict("r2", False, "Failed B"), + ] + ) engine = ToolRequirementsEngine(project, evaluator) result = await engine.check("shell", {"command": "bad"}) @@ -106,17 +119,22 @@ async def test_failure_denies_with_all_errors(self, tmp_path: Path) -> None: @pytest.mark.asyncio() async def test_no_exception_label_in_error(self, tmp_path: Path) -> None: - project = _setup_project(tmp_path, { - "rules": { - 
"tools": ["shell"], - "requirements": { - "r1": {"rule": "MUST check", "no_exception": True}, - }, - } - }) - evaluator = MockEvaluator(verdicts=[ - RequirementVerdict("r1", False, "Blocked"), - ]) + project = _setup_project( + tmp_path, + { + "rules": { + "tools": ["shell"], + "requirements": { + "r1": {"rule": "MUST check", "no_exception": True}, + }, + } + }, + ) + evaluator = MockEvaluator( + verdicts=[ + RequirementVerdict("r1", False, "Blocked"), + ] + ) engine = ToolRequirementsEngine(project, evaluator) result = await engine.check("shell", {"command": "bad"}) @@ -126,19 +144,23 @@ async def test_no_exception_label_in_error(self, tmp_path: Path) -> None: class TestEngineAppeal: @pytest.mark.asyncio() async def test_successful_appeal_caches(self, tmp_path: Path) -> None: - project = _setup_project(tmp_path, { - "rules": { - "tools": ["shell"], - "requirements": { - "r1": {"rule": "SHOULD check"}, - }, - } - }) + project = _setup_project( + tmp_path, + { + "rules": { + "tools": ["shell"], + "requirements": { + "r1": {"rule": "SHOULD check"}, + }, + } + }, + ) evaluator = MockEvaluator() # All pass by default engine = ToolRequirementsEngine(project, evaluator) result = await engine.appeal( - "shell", {"command": "rm file"}, + "shell", + {"command": "rm file"}, justifications={"r1": "It's a temp file"}, ) assert result.passed is True @@ -150,18 +172,22 @@ async def test_successful_appeal_caches(self, tmp_path: Path) -> None: @pytest.mark.asyncio() async def test_no_exception_blocks_appeal(self, tmp_path: Path) -> None: - project = _setup_project(tmp_path, { - "rules": { - "tools": ["shell"], - "requirements": { - "r1": {"rule": "MUST NOT", "no_exception": True}, - }, - } - }) + project = _setup_project( + tmp_path, + { + "rules": { + "tools": ["shell"], + "requirements": { + "r1": {"rule": "MUST NOT", "no_exception": True}, + }, + } + }, + ) engine = ToolRequirementsEngine(project, MockEvaluator()) result = await engine.appeal( - "shell", {"command": "bad"}, + 
"shell", + {"command": "bad"}, justifications={"r1": "Please?"}, ) assert result.passed is False @@ -170,12 +196,15 @@ async def test_no_exception_blocks_appeal(self, tmp_path: Path) -> None: @pytest.mark.asyncio() async def test_empty_justifications_rejected(self, tmp_path: Path) -> None: - project = _setup_project(tmp_path, { - "rules": { - "tools": ["shell"], - "requirements": {"r1": {"rule": "MUST check"}}, - } - }) + project = _setup_project( + tmp_path, + { + "rules": { + "tools": ["shell"], + "requirements": {"r1": {"rule": "MUST check"}}, + } + }, + ) engine = ToolRequirementsEngine(project, MockEvaluator()) result = await engine.appeal("shell", {"command": "bad"}, justifications={}) @@ -183,19 +212,25 @@ async def test_empty_justifications_rejected(self, tmp_path: Path) -> None: @pytest.mark.asyncio() async def test_failed_appeal_not_cached(self, tmp_path: Path) -> None: - project = _setup_project(tmp_path, { - "rules": { - "tools": ["shell"], - "requirements": {"r1": {"rule": "MUST check"}}, - } - }) - evaluator = MockEvaluator(verdicts=[ - RequirementVerdict("r1", False, "Still bad"), - ]) + project = _setup_project( + tmp_path, + { + "rules": { + "tools": ["shell"], + "requirements": {"r1": {"rule": "MUST check"}}, + } + }, + ) + evaluator = MockEvaluator( + verdicts=[ + RequirementVerdict("r1", False, "Still bad"), + ] + ) engine = ToolRequirementsEngine(project, evaluator) result = await engine.appeal( - "shell", {"command": "bad"}, + "shell", + {"command": "bad"}, justifications={"r1": "Please"}, ) assert result.passed is False diff --git a/tests/unit/tool_requirements/test_evaluator.py b/tests/unit/tool_requirements/test_evaluator.py index 33f3ffc3..39a22618 100644 --- a/tests/unit/tool_requirements/test_evaluator.py +++ b/tests/unit/tool_requirements/test_evaluator.py @@ -2,11 +2,8 @@ import json -import pytest - from deepwork.tool_requirements.config import Requirement from deepwork.tool_requirements.evaluator import ( - RequirementVerdict, 
_build_prompt, _extract_json_array, _parse_result, @@ -25,7 +22,9 @@ def test_includes_tool_info(self) -> None: def test_includes_justifications_when_present(self) -> None: reqs = {"r1": Requirement(rule="MUST check")} prompt = _build_prompt( - reqs, "shell", {"command": "ls"}, + reqs, + "shell", + {"command": "ls"}, justifications={"r1": "This is safe because..."}, ) assert "This is safe because..." in prompt diff --git a/tests/unit/tool_requirements/test_matcher.py b/tests/unit/tool_requirements/test_matcher.py index 62392af8..cb95246c 100644 --- a/tests/unit/tool_requirements/test_matcher.py +++ b/tests/unit/tool_requirements/test_matcher.py @@ -2,8 +2,6 @@ from pathlib import Path -import pytest - from deepwork.tool_requirements.config import Requirement, ToolPolicy from deepwork.tool_requirements.matcher import match_policies, merge_requirements From 0944b536ba3b5118b9d660e7e51675a3d1c818ec Mon Sep 17 00:00:00 2001 From: Noah Horton Date: Wed, 15 Apr 2026 12:39:13 -0600 Subject: [PATCH 3/7] docs: update documentation for tool requirements system - doc/mcp_interface.md: add appeal_tool_requirement as tool #12, bump count - doc/architecture.md: add tool_requirements/ package and hook to structure - CLAUDE.md: add tool_requirements/ and hook to project structure appendix - src/deepwork/hooks/README.md: add tool_requirements.py to files table - CHANGELOG.md: add tool requirements feature to Unreleased Co-Authored-By: Claude Opus 4.6 (1M context) --- doc/architecture.md | 13 +++++++++++-- doc/mcp_interface.md | 26 +++++++++++++++++++++++++- src/deepwork/hooks/README.md | 1 + 3 files changed, 37 insertions(+), 3 deletions(-) diff --git a/doc/architecture.md b/doc/architecture.md index 53292dc7..dc97aada 100644 --- a/doc/architecture.md +++ b/doc/architecture.md @@ -93,7 +93,16 @@ deepwork/ # DeepWork tool repository │ ├── schemas/ # Definition schemas │ │ ├── deepreview_schema.json │ │ ├── deepschema_schema.json -│ │ └── doc_spec_schema.py +│ │ ├── 
doc_spec_schema.py +│ │ └── tool_requirements_schema.json +│ ├── tool_requirements/ # Tool requirements policy enforcement +│ │ ├── cache.py # In-memory TTL cache for approved calls +│ │ ├── config.py # ToolPolicy/Requirement dataclasses, parser +│ │ ├── discovery.py # Load policies from .deepwork/tool_requirements/ +│ │ ├── engine.py # Check + appeal orchestration +│ │ ├── evaluator.py # LLM evaluator (Haiku) for requirement checking +│ │ ├── matcher.py # Match policies to tool calls +│ │ └── sidecar.py # HTTP sidecar server for hook communication │ └── utils/ │ ├── fs.py │ ├── git.py @@ -119,7 +128,7 @@ deepwork/ # DeepWork tool repository │ │ │ ├── new_user/SKILL.md │ │ │ ├── record/SKILL.md │ │ │ └── review/SKILL.md -│ │ ├── hooks/ # hooks.json, post_commit_reminder.sh, post_compact.sh, startup_context.sh, deepschema_write.sh +│ │ ├── hooks/ # hooks.json, post_commit_reminder.sh, post_compact.sh, startup_context.sh, deepschema_write.sh, tool_requirements.sh │ │ └── .mcp.json # MCP server config │ └── gemini/ # Gemini CLI extension │ └── skills/deepwork/SKILL.md diff --git a/doc/mcp_interface.md b/doc/mcp_interface.md index dace2763..9b949779 100644 --- a/doc/mcp_interface.md +++ b/doc/mcp_interface.md @@ -10,7 +10,7 @@ This document describes the Model Context Protocol (MCP) tools exposed by the De ## Tools -DeepWork exposes eleven MCP tools: +DeepWork exposes twelve MCP tools: ### 1. `get_workflows` @@ -308,6 +308,29 @@ Retrieve the YAML content of a session-scoped job definition previously register } ``` +### 12. `appeal_tool_requirement` + +Appeal a tool requirement policy denial. When a tool call is blocked by a tool requirement policy, call this to appeal specific failed checks by providing justifications. Some checks are marked `no_exception` and cannot be appealed. If the appeal succeeds, the tool call is cached as approved and you can retry the original tool call. 
+ +#### Parameters + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `tool_name` | `string` | Yes | The normalized tool name that was blocked | +| `tool_input` | `dict` | Yes | The exact tool_input that was blocked | +| `policy_justification` | `dict[string, string]` | Yes | Map of failed check names to justification strings | +| `session_id` | `string` | No | Session identifier (CLAUDE_CODE_SESSION_ID on Claude Code) | + +#### Returns + +```typescript +{ + passed: boolean; // Whether the appeal succeeded + reason: string; // Explanation of result + no_exception_blocked?: string[]; // Checks that cannot be appealed +} +``` + --- ## Shared Types @@ -491,6 +514,7 @@ Add to your `.mcp.json`: | Version | Changes | |---------|---------| +| 2.4.0 | Added `appeal_tool_requirement` tool for appealing tool requirement policy denials with justifications. | | 2.3.0 | Added `project_root` field to `ActiveStepInfo` — the absolute path to the MCP server's project root. Added `register_session_job` and `get_session_job` tools for transient session-scoped job definitions. Session jobs are discoverable by `start_workflow` via `session_id` lookup — they take priority over standard discovery. Added `deepplan` standard job with `create_deep_plan` workflow. | | 2.2.0 | `session_id` is now optional (`str | None`) on `start_workflow` only. On Claude Code (platform `"claude"`), the server raises `ToolError` if omitted. On other platforms, omitting it auto-generates a stable UUID; callers use the returned `begin_step.session_id` for all subsequent calls. `finished_step`, `abort_workflow`, and `go_to_step` continue to require `session_id`. Added `inputs` optional parameter to `start_workflow` for passing step argument values directly at workflow start. Added `issue_detected` optional field to all tool responses — present when the server detects configuration issues at startup; instructs agent to suggest repair to the user. 
| | 2.1.0 | Added `important_note` field to `StartWorkflowResponse` — instructs agents to clarify ambiguous user requests via `AskUserQuestion` when available. | diff --git a/src/deepwork/hooks/README.md b/src/deepwork/hooks/README.md index efec09ae..a8d171d1 100644 --- a/src/deepwork/hooks/README.md +++ b/src/deepwork/hooks/README.md @@ -140,6 +140,7 @@ pytest tests/shell_script_tests/test_hook_wrappers.py -v | `wrapper.py` | Cross-platform input/output normalization | | `deepschema_write.py` | DeepSchema write-time validation hook | | `post_commit_reminder.py` | Post-commit hook that nudges the agent to run `/review` (skips if all reviews already passed) | +| `tool_requirements.py` | PreToolUse hook for tool requirements policy enforcement | | `claude_hook.sh` | Shell wrapper for Claude Code | | `gemini_hook.sh` | Shell wrapper for Gemini CLI | | `.deepreview` | Review rule ensuring hooks use correct output routing (DW-REQ-006.6) | From feb32d5ab29fc7e96229607e0bf0f9c186a4cc9c Mon Sep 17 00:00:00 2001 From: Noah Horton Date: Wed, 15 Apr 2026 12:44:04 -0600 Subject: [PATCH 4/7] fix: address round-2 review findings - evaluator.py: fix comment accuracy, extract _filter_dicts to reduce DRY - discovery.py: fix diamond inheritance by copying visited set per parent - test_engine.py: remove redundant @pytest.mark.asyncio decorators, fix dict type annotation, replace internal cache access with call_count - test_evaluator.py: add tests for HaikuSubprocessEvaluator, deduplication, non-dict filtering, and invalid bracket JSON Co-Authored-By: Claude Opus 4.6 (1M context) --- AGENTS.md | 5 +- src/deepwork/tool_requirements/discovery.py | 2 +- src/deepwork/tool_requirements/evaluator.py | 10 ++- src/deepwork/tool_requirements/sidecar.py | 4 +- tests/unit/tool_requirements/test_engine.py | 19 ++---- .../unit/tool_requirements/test_evaluator.py | 61 +++++++++++++++++++ 6 files changed, 78 insertions(+), 23 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index a033ca05..ca3e5836 
100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -208,7 +208,8 @@ deepwork/ │ │ └── deepplan/ │ ├── standard_schemas/ # Built-in DeepSchema definitions │ ├── review/ # DeepWork Reviews system (.deepreview pipeline) -│ ├── schemas/ # Definition schemas (deepreview, deepschema, doc_spec) +│ ├── schemas/ # Definition schemas (deepreview, deepschema, doc_spec, tool_requirements) +│ ├── tool_requirements/ # Tool requirements policy enforcement (config, discovery, matcher, evaluator, engine, cache, sidecar) │ └── utils/ # Utilities (fs, git, yaml, validation) ├── platform/ # Shared platform-agnostic content │ └── skill-body.md # Canonical skill body (source of truth) @@ -227,7 +228,7 @@ deepwork/ │ │ │ ├── new_user/SKILL.md │ │ │ ├── record/SKILL.md │ │ │ └── review/SKILL.md -│ │ ├── hooks/ # hooks.json, post_commit_reminder.sh, post_compact.sh, startup_context.sh, deepschema_write.sh +│ │ ├── hooks/ # hooks.json, post_commit_reminder.sh, post_compact.sh, startup_context.sh, deepschema_write.sh, tool_requirements.sh │ │ └── .mcp.json # MCP server config │ └── gemini/ # Gemini CLI extension │ └── skills/deepwork/SKILL.md diff --git a/src/deepwork/tool_requirements/discovery.py b/src/deepwork/tool_requirements/discovery.py index a549d82e..ef67f34f 100644 --- a/src/deepwork/tool_requirements/discovery.py +++ b/src/deepwork/tool_requirements/discovery.py @@ -86,7 +86,7 @@ def resolve(name: str, visited: set[str]) -> ToolPolicy: if parent_name not in by_name: logger.warning("Policy '%s' extends unknown policy '%s'", policy.name, parent_name) continue - parent = resolve(parent_name, visited) + parent = resolve(parent_name, set(visited)) merged_requirements.update(parent.requirements) # Child requirements override parent diff --git a/src/deepwork/tool_requirements/evaluator.py b/src/deepwork/tool_requirements/evaluator.py index ec822fdf..c212137d 100644 --- a/src/deepwork/tool_requirements/evaluator.py +++ b/src/deepwork/tool_requirements/evaluator.py @@ -165,7 +165,7 @@ def 
_parse_result( try: obj = json.loads(line) if not isinstance(obj, dict): - # Raw JSON array — use as-is + # Non-object JSON value — try using as result text result_text = line break # stream-json emits objects with "type" field @@ -229,11 +229,15 @@ def _parse_result( def _extract_json_array(text: str) -> list[dict[str, Any]] | None: """Extract a JSON array of dicts from text that may contain surrounding prose.""" + + def _filter_dicts(items: list[Any]) -> list[dict[str, Any]]: + return [item for item in items if isinstance(item, dict)] + # Try direct parse first try: parsed = json.loads(text.strip()) if isinstance(parsed, list): - return [item for item in parsed if isinstance(item, dict)] + return _filter_dicts(parsed) except json.JSONDecodeError: pass @@ -253,7 +257,7 @@ def _extract_json_array(text: str) -> list[dict[str, Any]] | None: try: parsed = json.loads(text[start : i + 1]) if isinstance(parsed, list): - return [item for item in parsed if isinstance(item, dict)] + return _filter_dicts(parsed) except json.JSONDecodeError: return None diff --git a/src/deepwork/tool_requirements/sidecar.py b/src/deepwork/tool_requirements/sidecar.py index 8febe878..12b36248 100644 --- a/src/deepwork/tool_requirements/sidecar.py +++ b/src/deepwork/tool_requirements/sidecar.py @@ -98,9 +98,7 @@ def _handle_appeal(self, data: dict[str, Any]) -> None: return try: - result = self._run_async( - self.engine.appeal(tool_name, tool_input, justifications) - ) + result = self._run_async(self.engine.appeal(tool_name, tool_input, justifications)) except Exception as e: logger.exception("Error processing appeal") self._respond( diff --git a/tests/unit/tool_requirements/test_engine.py b/tests/unit/tool_requirements/test_engine.py index 6ef09618..0bf85551 100644 --- a/tests/unit/tool_requirements/test_engine.py +++ b/tests/unit/tool_requirements/test_engine.py @@ -3,7 +3,6 @@ from pathlib import Path from typing import Any -import pytest import yaml from deepwork.tool_requirements.config 
import Requirement @@ -34,7 +33,7 @@ async def evaluate( return [RequirementVerdict(req_id, True, "OK") for req_id in requirements] -def _setup_project(tmp_path: Path, policies: dict[str, dict]) -> Path: +def _setup_project(tmp_path: Path, policies: dict[str, Any]) -> Path: policy_dir = tmp_path / ".deepwork" / "tool_requirements" policy_dir.mkdir(parents=True) for name, data in policies.items(): @@ -43,13 +42,11 @@ def _setup_project(tmp_path: Path, policies: dict[str, dict]) -> Path: class TestEngineCheck: - @pytest.mark.asyncio() async def test_no_policies_allows(self, tmp_path: Path) -> None: engine = ToolRequirementsEngine(tmp_path, MockEvaluator()) result = await engine.check("shell", {"command": "ls"}) assert result.allowed is True - @pytest.mark.asyncio() async def test_no_matching_policies_allows(self, tmp_path: Path) -> None: project = _setup_project( tmp_path, @@ -64,7 +61,6 @@ async def test_no_matching_policies_allows(self, tmp_path: Path) -> None: result = await engine.check("shell", {"command": "ls"}) assert result.allowed is True - @pytest.mark.asyncio() async def test_all_pass_allows_and_caches(self, tmp_path: Path) -> None: project = _setup_project( tmp_path, @@ -87,7 +83,6 @@ async def test_all_pass_allows_and_caches(self, tmp_path: Path) -> None: assert result2.allowed is True assert evaluator.call_count == 1 # Not called again - @pytest.mark.asyncio() async def test_failure_denies_with_all_errors(self, tmp_path: Path) -> None: project = _setup_project( tmp_path, @@ -117,7 +112,6 @@ async def test_failure_denies_with_all_errors(self, tmp_path: Path) -> None: assert "Failed B" in result.reason assert set(result.failed_checks) == {"r1", "r2"} - @pytest.mark.asyncio() async def test_no_exception_label_in_error(self, tmp_path: Path) -> None: project = _setup_project( tmp_path, @@ -142,7 +136,6 @@ async def test_no_exception_label_in_error(self, tmp_path: Path) -> None: class TestEngineAppeal: - @pytest.mark.asyncio() async def 
test_successful_appeal_caches(self, tmp_path: Path) -> None: project = _setup_project( tmp_path, @@ -170,7 +163,6 @@ async def test_successful_appeal_caches(self, tmp_path: Path) -> None: check = await engine.check("shell", {"command": "rm file"}) assert check.allowed is True - @pytest.mark.asyncio() async def test_no_exception_blocks_appeal(self, tmp_path: Path) -> None: project = _setup_project( tmp_path, @@ -194,7 +186,6 @@ async def test_no_exception_blocks_appeal(self, tmp_path: Path) -> None: assert "no_exception" in result.reason.lower() assert "r1" in result.no_exception_blocked - @pytest.mark.asyncio() async def test_empty_justifications_rejected(self, tmp_path: Path) -> None: project = _setup_project( tmp_path, @@ -210,7 +201,6 @@ async def test_empty_justifications_rejected(self, tmp_path: Path) -> None: result = await engine.appeal("shell", {"command": "bad"}, justifications={}) assert result.passed is False - @pytest.mark.asyncio() async def test_failed_appeal_not_cached(self, tmp_path: Path) -> None: project = _setup_project( tmp_path, @@ -235,6 +225,7 @@ async def test_failed_appeal_not_cached(self, tmp_path: Path) -> None: ) assert result.passed is False - # Should NOT be cached - cache_key = engine.cache.make_key("shell", {"command": "bad"}) - assert not engine.cache.is_approved(cache_key) + # Verify not cached by checking evaluator is called again + result2 = await engine.check("shell", {"command": "bad"}) + assert result2.allowed is False + assert evaluator.call_count == 2 # Called again, not cached diff --git a/tests/unit/tool_requirements/test_evaluator.py b/tests/unit/tool_requirements/test_evaluator.py index 39a22618..a0f6c397 100644 --- a/tests/unit/tool_requirements/test_evaluator.py +++ b/tests/unit/tool_requirements/test_evaluator.py @@ -1,9 +1,13 @@ """Tests for tool requirements LLM evaluator.""" import json +from unittest.mock import AsyncMock, patch + +import pytest from deepwork.tool_requirements.config import Requirement from 
deepwork.tool_requirements.evaluator import ( + HaikuSubprocessEvaluator, _build_prompt, _extract_json_array, _parse_result, @@ -99,3 +103,60 @@ def test_result_type_message(self) -> None: reqs = {"r1": Requirement(rule="Rule 1")} result = _parse_result(raw, reqs) assert result[0].passed is True + + def test_duplicate_requirement_ids_deduplicated(self) -> None: + verdicts = [ + {"requirement_id": "r1", "passed": True, "explanation": "First"}, + {"requirement_id": "r1", "passed": False, "explanation": "Duplicate"}, + ] + raw = json.dumps(verdicts) + reqs = {"r1": Requirement(rule="Rule 1")} + result = _parse_result(raw, reqs) + assert len(result) == 1 + assert result[0].passed is True # First occurrence wins + + def test_non_dict_items_filtered(self) -> None: + result = _extract_json_array('[1, {"a": 1}, "str"]') + assert result == [{"a": 1}] + + def test_bracket_search_invalid_json(self) -> None: + result = _extract_json_array("prefix [not valid json] suffix") + assert result is None + + +class TestHaikuSubprocessEvaluator: + @pytest.mark.asyncio() + @patch("asyncio.create_subprocess_exec") + async def test_call_haiku_success(self, mock_exec: AsyncMock) -> None: + verdicts = [{"requirement_id": "r1", "passed": True, "explanation": "OK"}] + stdout = json.dumps(verdicts).encode() + + mock_proc = AsyncMock() + mock_proc.communicate.return_value = (stdout, b"") + mock_proc.returncode = 0 + mock_exec.return_value = mock_proc + + evaluator = HaikuSubprocessEvaluator() + reqs = {"r1": Requirement(rule="MUST check")} + result = await evaluator.evaluate(reqs, "shell", {"command": "ls"}) + assert len(result) == 1 + assert result[0].passed is True + + @pytest.mark.asyncio() + @patch("asyncio.create_subprocess_exec") + async def test_call_haiku_failure_raises(self, mock_exec: AsyncMock) -> None: + mock_proc = AsyncMock() + mock_proc.communicate.return_value = (b"", b"error") + mock_proc.returncode = 1 + mock_exec.return_value = mock_proc + + evaluator = 
HaikuSubprocessEvaluator() + reqs = {"r1": Requirement(rule="MUST check")} + with pytest.raises(RuntimeError, match="Haiku subprocess failed"): + await evaluator.evaluate(reqs, "shell", {"command": "ls"}) + + @pytest.mark.asyncio() + async def test_empty_requirements_returns_empty(self) -> None: + evaluator = HaikuSubprocessEvaluator() + result = await evaluator.evaluate({}, "shell", {"command": "ls"}) + assert result == [] From 2174e03a5d889450a9c920461db49dfd57762d4c Mon Sep 17 00:00:00 2001 From: Noah Horton Date: Wed, 15 Apr 2026 15:34:30 -0600 Subject: [PATCH 5/7] docs: add DW-REQ-012 requirement spec and test traceability - Create DW-REQ-012-tool-requirements.md with 12 sub-requirements covering policy format, discovery, inheritance, matching, evaluation, check flow, appeal, caching, hook, sidecar, multi-instance, and startup - Add PLUG-REQ-001.15 for the PreToolUse hook registration - Add requirement ID references to all test module docstrings - Add THIS TEST VALIDATES traceability comments to critical tests Co-Authored-By: Claude Opus 4.6 (1M context) --- .../deepwork/DW-REQ-012-tool-requirements.md | 104 ++++++++++++++++++ .../PLUG-REQ-001-claude-code-plugin.md | 6 + tests/unit/test_tool_requirements_hook.py | 6 +- tests/unit/tool_requirements/test_cache.py | 2 +- tests/unit/tool_requirements/test_config.py | 2 +- .../unit/tool_requirements/test_discovery.py | 2 +- tests/unit/tool_requirements/test_engine.py | 8 +- .../unit/tool_requirements/test_evaluator.py | 2 +- tests/unit/tool_requirements/test_matcher.py | 2 +- 9 files changed, 127 insertions(+), 7 deletions(-) create mode 100644 doc/specs/deepwork/DW-REQ-012-tool-requirements.md diff --git a/doc/specs/deepwork/DW-REQ-012-tool-requirements.md b/doc/specs/deepwork/DW-REQ-012-tool-requirements.md new file mode 100644 index 00000000..2421af95 --- /dev/null +++ b/doc/specs/deepwork/DW-REQ-012-tool-requirements.md @@ -0,0 +1,104 @@ +# DW-REQ-012: Tool Requirements Policy Enforcement + +## Overview + +The 
tool requirements system enforces RFC 2119-style policies on AI agent tool calls. Users define rules in `.deepwork/tool_requirements/*.yml` files. A PreToolUse hook checks these rules via an HTTP sidecar server using LLM-based semantic evaluation. Failed checks can be appealed via an MCP tool. Approved calls are cached with a TTL. + +## DW-REQ-012.1: Policy File Format + +1. Policy files MUST be YAML files located in `.deepwork/tool_requirements/` with a `.yml` extension. +2. Each policy file MUST be validated against the `tool_requirements_schema.json` JSON Schema. +3. The `tools` field MUST be a non-empty array of normalized tool names (e.g., `shell`, `write_file`, `edit_file`) or MCP tool names (e.g., `mcp__server__tool`). +4. The `requirements` field MUST be a mapping of requirement identifiers to objects containing a `rule` string (RFC 2119 statement) and an optional `no_exception` boolean (default: `false`). +5. The `match` field MAY be a mapping of tool_input parameter names to regex patterns for parameter-level filtering. +6. The `extends` field MAY be an array of policy file stems for inheritance. +7. The `summary` field MAY be a human-readable description of the policy. + +## DW-REQ-012.2: Policy Discovery + +1. The system MUST scan `.deepwork/tool_requirements/` for `*.yml` files (single-directory, no tree walk). +2. Files that fail to parse MUST be skipped with a warning logged — they MUST NOT prevent other policies from loading. +3. If the `.deepwork/tool_requirements/` directory does not exist, the system MUST return an empty policy list without error. + +## DW-REQ-012.3: Policy Inheritance + +1. When a policy lists `extends`, the system MUST merge parent requirements into the child. +2. Child requirements MUST override parent requirements with the same key. +3. Unknown parent names MUST be logged as warnings and skipped — they MUST NOT cause errors. +4. Circular inheritance MUST be detected and MUST NOT cause infinite loops. +5. 
Diamond inheritance (two parents sharing a common ancestor) MUST be handled correctly — the common ancestor's requirements MUST be included once. + +## DW-REQ-012.4: Policy Matching + +1. A policy MUST match a tool call if the tool's normalized name is in the policy's `tools` list. +2. If the policy has a `match` dict, the policy MUST match only when at least one parameter regex matches a value in `tool_input` (via `re.search`). +3. If the policy has no `match` dict, it MUST match all calls to the listed tools. +4. Multiple policies MAY match a single tool call; all matched requirements MUST be merged. +5. If the same requirement key appears in multiple matched policies, the first occurrence MUST win. +6. Invalid regex patterns in `match` MUST be skipped without error. + +## DW-REQ-012.5: Requirement Evaluation + +1. Requirements MUST be evaluated by an LLM evaluator (Haiku by default) that considers RFC 2119 keywords. +2. `MUST`/`MUST NOT` violations MUST always result in failure. +3. `SHOULD`/`SHOULD NOT` violations MUST result in failure only when the violation is clear and easily avoidable. +4. `MAY` requirements MUST always pass. +5. The evaluator MUST return a verdict for every requirement — requirements not evaluated MUST fail closed. +6. The evaluator MUST be encapsulated behind an abstract interface to allow implementation swapping. +7. Large `tool_input` values MUST be truncated to avoid exceeding LLM token limits. + +## DW-REQ-012.6: Check Flow + +1. When a tool call is checked, the system MUST first check the cache — if approved, it MUST allow immediately. +2. If no policies match the tool call, it MUST be allowed. +3. If evaluation passes all requirements, the result MUST be cached and the call MUST be allowed. +4. If any requirements fail, the response MUST include ALL failures (not one at a time). +5. Each failure MUST include the requirement ID and an explanation. +6. `no_exception` requirements MUST be labeled as such in the failure message. +7. 
The failure message MUST include instructions for how to appeal via the `appeal_tool_requirement` MCP tool. + +## DW-REQ-012.7: Appeal Mechanism + +1. The system MUST provide an `appeal_tool_requirement` MCP tool. +2. The tool MUST accept `tool_name`, `tool_input`, and `policy_justification` (a dict mapping failed check IDs to justification strings). +3. `no_exception` requirements MUST NOT be appealable — appeals for them MUST be rejected immediately. +4. For appealable requirements, the evaluator MUST re-evaluate considering the provided justifications. +5. If the appeal succeeds, the result MUST be cached so the retried tool call passes the hook. +6. If the appeal fails, the response MUST list all still-failing requirements. + +## DW-REQ-012.8: Caching + +1. Approved tool calls MUST be cached with a 1-hour TTL. +2. The cache key MUST be deterministic, derived from the tool name and tool input. +3. Expired cache entries MUST be evicted on lookup. +4. The cache MUST be in-memory within the sidecar server process. + +## DW-REQ-012.9: PreToolUse Hook + +1. The hook MUST fire on all PreToolUse events (empty matcher). +2. The hook MUST skip the `appeal_tool_requirement` MCP tool to prevent infinite loops (substring match on raw tool name). +3. If the sidecar is unreachable (port file missing or PID dead), the hook MUST deny with an error message instructing the user to restart the MCP server (fail-closed). +4. If communication with the sidecar fails, the hook MUST deny with an error message (fail-closed). +5. The hook MUST use `hookSpecificOutput.permissionDecision: "deny"` format for Claude Code PreToolUse events. +6. The hook MUST use the cross-platform wrapper system (`run_hook`, `HookInput`, `HookOutput`). + +## DW-REQ-012.10: Sidecar HTTP Server + +1. The sidecar MUST start as a daemon thread alongside the MCP server when policy files exist. +2. The sidecar MUST bind to `127.0.0.1` on a random port. +3. 
The sidecar MUST write a port file to `.deepwork/tmp/tool_req_sidecar/<pid>.json` containing `{"pid": <pid>, "port": <port>}`. +4. The sidecar MUST provide `POST /check` and `POST /appeal` endpoints. +5. The sidecar MUST clean up its port file and any session mapping files on exit. +6. Session IDs used in filenames MUST be validated against `^[a-zA-Z0-9_-]+$` to prevent path traversal. + +## DW-REQ-012.11: Multi-Instance Support + +1. When the first MCP tool call arrives with a `session_id`, the server MUST write a session mapping file at `.deepwork/tmp/tool_req_sidecar/session_<session_id>.json`. +2. The hook MUST look for a session-specific mapping file first, then fall back to scanning PID-keyed port files for live processes. +3. Stale port files (PID no longer alive) MUST be cleaned up during discovery. + +## DW-REQ-012.12: Sidecar Startup Gating + +1. The sidecar MUST NOT start if no `.deepwork/tool_requirements/` directory exists. +2. The sidecar MUST NOT start if the directory contains no `*.yml` files. +3. If sidecar startup fails, the MCP server MUST continue running — the failure MUST be logged as a warning. diff --git a/doc/specs/deepwork/cli_plugins/PLUG-REQ-001-claude-code-plugin.md b/doc/specs/deepwork/cli_plugins/PLUG-REQ-001-claude-code-plugin.md index 8a045800..a2a434df 100644 --- a/doc/specs/deepwork/cli_plugins/PLUG-REQ-001-claude-code-plugin.md +++ b/doc/specs/deepwork/cli_plugins/PLUG-REQ-001-claude-code-plugin.md @@ -99,6 +99,12 @@ The Claude Code plugin is the primary distribution mechanism for DeepWork on the 3. The skill MUST explain how DeepSchemas automatically generate synthetic review rules. 4. The skill MUST describe workflow quality gates and how `finished_step` triggers reviews on step outputs. +### PLUG-REQ-001.15: Tool Requirements PreToolUse Hook + +1. The plugin MUST register a PreToolUse hook in `plugins/claude/hooks/hooks.json` with an empty matcher (matches all tool calls). +2. 
The hook MUST delegate to `deepwork hook tool_requirements` via a shell wrapper at `plugins/claude/hooks/tool_requirements.sh`. +3. The hook MUST skip the `appeal_tool_requirement` MCP tool to prevent infinite loops (see DW-REQ-012.9.2). + ### PLUG-REQ-001.14: Default Reviewer Subagent 1. The plugin MUST ship a default reviewer subagent at `plugins/claude/agents/reviewer.md`. diff --git a/tests/unit/test_tool_requirements_hook.py b/tests/unit/test_tool_requirements_hook.py index 7dcb4bad..01f9db76 100644 --- a/tests/unit/test_tool_requirements_hook.py +++ b/tests/unit/test_tool_requirements_hook.py @@ -1,4 +1,4 @@ -"""Tests for the tool requirements PreToolUse hook.""" +"""Tests for the tool requirements PreToolUse hook (DW-REQ-012.9, PLUG-REQ-001.15).""" from unittest.mock import MagicMock, patch @@ -39,6 +39,8 @@ def test_skips_non_before_tool_events(self) -> None: result = tool_requirements_hook(hook_input) assert result.decision == "" + # THIS TEST VALIDATES A HARD REQUIREMENT (DW-REQ-012.9.2). + # YOU MUST NOT MODIFY THIS TEST UNLESS THE REQUIREMENT CHANGES def test_loop_prevention_skips_all_appeal_prefixes(self) -> None: from deepwork.hooks.tool_requirements import tool_requirements_hook @@ -51,6 +53,8 @@ def test_loop_prevention_skips_all_appeal_prefixes(self) -> None: result = tool_requirements_hook(hook_input) assert result.decision == "", f"Failed for {prefix}" + # THIS TEST VALIDATES A HARD REQUIREMENT (DW-REQ-012.9.3). 
+ # YOU MUST NOT MODIFY THIS TEST UNLESS THE REQUIREMENT CHANGES @patch("deepwork.hooks.tool_requirements.discover_sidecar") def test_fail_closed_when_no_sidecar(self, mock_discover: MagicMock) -> None: from deepwork.hooks.tool_requirements import tool_requirements_hook diff --git a/tests/unit/tool_requirements/test_cache.py b/tests/unit/tool_requirements/test_cache.py index 57c4d64b..dd4086a6 100644 --- a/tests/unit/tool_requirements/test_cache.py +++ b/tests/unit/tool_requirements/test_cache.py @@ -1,4 +1,4 @@ -"""Tests for tool requirements TTL cache.""" +"""Tests for tool requirements TTL cache (DW-REQ-012.8).""" import time from unittest.mock import patch diff --git a/tests/unit/tool_requirements/test_config.py b/tests/unit/tool_requirements/test_config.py index 97d18df4..1725d4c3 100644 --- a/tests/unit/tool_requirements/test_config.py +++ b/tests/unit/tool_requirements/test_config.py @@ -1,4 +1,4 @@ -"""Tests for tool requirements config parsing.""" +"""Tests for tool requirements config parsing (DW-REQ-012.1).""" from pathlib import Path diff --git a/tests/unit/tool_requirements/test_discovery.py b/tests/unit/tool_requirements/test_discovery.py index 052ddf8e..f68bf1aa 100644 --- a/tests/unit/tool_requirements/test_discovery.py +++ b/tests/unit/tool_requirements/test_discovery.py @@ -1,4 +1,4 @@ -"""Tests for tool requirements discovery and inheritance.""" +"""Tests for tool requirements discovery (DW-REQ-012.2) and inheritance (DW-REQ-012.3).""" from pathlib import Path diff --git a/tests/unit/tool_requirements/test_engine.py b/tests/unit/tool_requirements/test_engine.py index 0bf85551..fe0a41a6 100644 --- a/tests/unit/tool_requirements/test_engine.py +++ b/tests/unit/tool_requirements/test_engine.py @@ -1,4 +1,4 @@ -"""Tests for tool requirements engine.""" +"""Tests for tool requirements engine (DW-REQ-012.6, DW-REQ-012.7).""" from pathlib import Path from typing import Any @@ -61,6 +61,8 @@ async def test_no_matching_policies_allows(self, tmp_path: 
Path) -> None: result = await engine.check("shell", {"command": "ls"}) assert result.allowed is True + # THIS TEST VALIDATES A HARD REQUIREMENT (DW-REQ-012.6.3, DW-REQ-012.8.1). + # YOU MUST NOT MODIFY THIS TEST UNLESS THE REQUIREMENT CHANGES async def test_all_pass_allows_and_caches(self, tmp_path: Path) -> None: project = _setup_project( tmp_path, @@ -83,6 +85,8 @@ async def test_all_pass_allows_and_caches(self, tmp_path: Path) -> None: assert result2.allowed is True assert evaluator.call_count == 1 # Not called again + # THIS TEST VALIDATES A HARD REQUIREMENT (DW-REQ-012.6.4). + # YOU MUST NOT MODIFY THIS TEST UNLESS THE REQUIREMENT CHANGES async def test_failure_denies_with_all_errors(self, tmp_path: Path) -> None: project = _setup_project( tmp_path, @@ -163,6 +167,8 @@ async def test_successful_appeal_caches(self, tmp_path: Path) -> None: check = await engine.check("shell", {"command": "rm file"}) assert check.allowed is True + # THIS TEST VALIDATES A HARD REQUIREMENT (DW-REQ-012.7.3). 
+ # YOU MUST NOT MODIFY THIS TEST UNLESS THE REQUIREMENT CHANGES async def test_no_exception_blocks_appeal(self, tmp_path: Path) -> None: project = _setup_project( tmp_path, diff --git a/tests/unit/tool_requirements/test_evaluator.py b/tests/unit/tool_requirements/test_evaluator.py index a0f6c397..8d133063 100644 --- a/tests/unit/tool_requirements/test_evaluator.py +++ b/tests/unit/tool_requirements/test_evaluator.py @@ -1,4 +1,4 @@ -"""Tests for tool requirements LLM evaluator.""" +"""Tests for tool requirements LLM evaluator (DW-REQ-012.5).""" import json from unittest.mock import AsyncMock, patch diff --git a/tests/unit/tool_requirements/test_matcher.py b/tests/unit/tool_requirements/test_matcher.py index cb95246c..ffb17d94 100644 --- a/tests/unit/tool_requirements/test_matcher.py +++ b/tests/unit/tool_requirements/test_matcher.py @@ -1,4 +1,4 @@ -"""Tests for tool requirements policy matching.""" +"""Tests for tool requirements policy matching (DW-REQ-012.4).""" from pathlib import Path From 9358aee835ba0063ed6ef972e6ea77e0faf40d71 Mon Sep 17 00:00:00 2001 From: Noah Horton Date: Wed, 15 Apr 2026 16:36:06 -0600 Subject: [PATCH 6/7] fix: address round-3 review findings - DW-REQ-012.5.3: make SHOULD violation criterion concrete and testable - PLUG-REQ-001: fix section ordering (001.14 before 001.15) - test_engine.py: use two-level REQ ID format (DW-REQ-012.6 not 012.6.3) - test_hook.py: use two-level REQ ID format, fix traceability comment placement - test_evaluator.py: move tests to correct class, remove redundant decorators Co-Authored-By: Claude Opus 4.6 (1M context) --- .../deepwork/DW-REQ-012-tool-requirements.md | 2 +- .../PLUG-REQ-001-claude-code-plugin.md | 12 +++++------ tests/unit/test_tool_requirements_hook.py | 6 +++--- tests/unit/tool_requirements/test_engine.py | 6 +++--- .../unit/tool_requirements/test_evaluator.py | 20 ++++++++----------- 5 files changed, 21 insertions(+), 25 deletions(-) diff --git 
a/doc/specs/deepwork/DW-REQ-012-tool-requirements.md b/doc/specs/deepwork/DW-REQ-012-tool-requirements.md index 2421af95..3df696c0 100644 --- a/doc/specs/deepwork/DW-REQ-012-tool-requirements.md +++ b/doc/specs/deepwork/DW-REQ-012-tool-requirements.md @@ -41,7 +41,7 @@ The tool requirements system enforces RFC 2119-style policies on AI agent tool c 1. Requirements MUST be evaluated by an LLM evaluator (Haiku by default) that considers RFC 2119 keywords. 2. `MUST`/`MUST NOT` violations MUST always result in failure. -3. `SHOULD`/`SHOULD NOT` violations MUST result in failure only when the violation is clear and easily avoidable. +3. `SHOULD`/`SHOULD NOT` violations MUST result in failure only when the tool call could be trivially modified to comply (e.g., adding a flag, choosing a different command) — the evaluator prompt MUST instruct the LLM to apply this criterion. 4. `MAY` requirements MUST always pass. 5. The evaluator MUST return a verdict for every requirement — requirements not evaluated MUST fail closed. 6. The evaluator MUST be encapsulated behind an abstract interface to allow implementation swapping. diff --git a/doc/specs/deepwork/cli_plugins/PLUG-REQ-001-claude-code-plugin.md b/doc/specs/deepwork/cli_plugins/PLUG-REQ-001-claude-code-plugin.md index a2a434df..f8e3efec 100644 --- a/doc/specs/deepwork/cli_plugins/PLUG-REQ-001-claude-code-plugin.md +++ b/doc/specs/deepwork/cli_plugins/PLUG-REQ-001-claude-code-plugin.md @@ -99,12 +99,6 @@ The Claude Code plugin is the primary distribution mechanism for DeepWork on the 3. The skill MUST explain how DeepSchemas automatically generate synthetic review rules. 4. The skill MUST describe workflow quality gates and how `finished_step` triggers reviews on step outputs. -### PLUG-REQ-001.15: Tool Requirements PreToolUse Hook - -1. The plugin MUST register a PreToolUse hook in `plugins/claude/hooks/hooks.json` with an empty matcher (matches all tool calls). -2. 
The hook MUST delegate to `deepwork hook tool_requirements` via a shell wrapper at `plugins/claude/hooks/tool_requirements.sh`. -3. The hook MUST skip the `appeal_tool_requirement` MCP tool to prevent infinite loops (see DW-REQ-012.9.2). - ### PLUG-REQ-001.14: Default Reviewer Subagent 1. The plugin MUST ship a default reviewer subagent at `plugins/claude/agents/reviewer.md`. @@ -113,3 +107,9 @@ The Claude Code plugin is the primary distribution mechanism for DeepWork on the 4. The agent body MUST instruct the subagent to read the instruction file from the user prompt, perform the review against the criteria in that file, and call `mark_review_as_passed` to report results. 5. The agent body MUST instruct the subagent not to edit files and not to explore beyond what the review instructions direct. 6. When the review formatter renders tasks with no per-rule agent persona specified (`agent_name` is `None`), it MUST default to `"reviewer"` as the `subagent_type` (see REVIEW-REQ-006.3.3c). + +### PLUG-REQ-001.15: Tool Requirements PreToolUse Hook + +1. The plugin MUST register a PreToolUse hook in `plugins/claude/hooks/hooks.json` with an empty matcher (matches all tool calls). +2. The hook MUST delegate to `deepwork hook tool_requirements` via a shell wrapper at `plugins/claude/hooks/tool_requirements.sh`. +3. The hook MUST skip the `appeal_tool_requirement` MCP tool to prevent infinite loops (see DW-REQ-012.9). diff --git a/tests/unit/test_tool_requirements_hook.py b/tests/unit/test_tool_requirements_hook.py index 01f9db76..680333b1 100644 --- a/tests/unit/test_tool_requirements_hook.py +++ b/tests/unit/test_tool_requirements_hook.py @@ -39,7 +39,7 @@ def test_skips_non_before_tool_events(self) -> None: result = tool_requirements_hook(hook_input) assert result.decision == "" - # THIS TEST VALIDATES A HARD REQUIREMENT (DW-REQ-012.9.2). + # THIS TEST VALIDATES A HARD REQUIREMENT (DW-REQ-012.9). 
# YOU MUST NOT MODIFY THIS TEST UNLESS THE REQUIREMENT CHANGES def test_loop_prevention_skips_all_appeal_prefixes(self) -> None: from deepwork.hooks.tool_requirements import tool_requirements_hook @@ -53,9 +53,9 @@ def test_loop_prevention_skips_all_appeal_prefixes(self) -> None: result = tool_requirements_hook(hook_input) assert result.decision == "", f"Failed for {prefix}" - # THIS TEST VALIDATES A HARD REQUIREMENT (DW-REQ-012.9.3). - # YOU MUST NOT MODIFY THIS TEST UNLESS THE REQUIREMENT CHANGES @patch("deepwork.hooks.tool_requirements.discover_sidecar") + # THIS TEST VALIDATES A HARD REQUIREMENT (DW-REQ-012.9). + # YOU MUST NOT MODIFY THIS TEST UNLESS THE REQUIREMENT CHANGES def test_fail_closed_when_no_sidecar(self, mock_discover: MagicMock) -> None: from deepwork.hooks.tool_requirements import tool_requirements_hook diff --git a/tests/unit/tool_requirements/test_engine.py b/tests/unit/tool_requirements/test_engine.py index fe0a41a6..66217bc1 100644 --- a/tests/unit/tool_requirements/test_engine.py +++ b/tests/unit/tool_requirements/test_engine.py @@ -61,7 +61,7 @@ async def test_no_matching_policies_allows(self, tmp_path: Path) -> None: result = await engine.check("shell", {"command": "ls"}) assert result.allowed is True - # THIS TEST VALIDATES A HARD REQUIREMENT (DW-REQ-012.6.3, DW-REQ-012.8.1). + # THIS TEST VALIDATES A HARD REQUIREMENT (DW-REQ-012.6, DW-REQ-012.8). # YOU MUST NOT MODIFY THIS TEST UNLESS THE REQUIREMENT CHANGES async def test_all_pass_allows_and_caches(self, tmp_path: Path) -> None: project = _setup_project( @@ -85,7 +85,7 @@ async def test_all_pass_allows_and_caches(self, tmp_path: Path) -> None: assert result2.allowed is True assert evaluator.call_count == 1 # Not called again - # THIS TEST VALIDATES A HARD REQUIREMENT (DW-REQ-012.6.4). + # THIS TEST VALIDATES A HARD REQUIREMENT (DW-REQ-012.6). 
# YOU MUST NOT MODIFY THIS TEST UNLESS THE REQUIREMENT CHANGES async def test_failure_denies_with_all_errors(self, tmp_path: Path) -> None: project = _setup_project( @@ -167,7 +167,7 @@ async def test_successful_appeal_caches(self, tmp_path: Path) -> None: check = await engine.check("shell", {"command": "rm file"}) assert check.allowed is True - # THIS TEST VALIDATES A HARD REQUIREMENT (DW-REQ-012.7.3). + # THIS TEST VALIDATES A HARD REQUIREMENT (DW-REQ-012.7). # YOU MUST NOT MODIFY THIS TEST UNLESS THE REQUIREMENT CHANGES async def test_no_exception_blocks_appeal(self, tmp_path: Path) -> None: project = _setup_project( diff --git a/tests/unit/tool_requirements/test_evaluator.py b/tests/unit/tool_requirements/test_evaluator.py index 8d133063..cdf50144 100644 --- a/tests/unit/tool_requirements/test_evaluator.py +++ b/tests/unit/tool_requirements/test_evaluator.py @@ -56,6 +56,14 @@ def test_no_array(self) -> None: def test_empty_array(self) -> None: assert _extract_json_array("[]") == [] + def test_non_dict_items_filtered(self) -> None: + result = _extract_json_array('[1, {"a": 1}, "str"]') + assert result == [{"a": 1}] + + def test_bracket_search_invalid_json(self) -> None: + result = _extract_json_array("prefix [not valid json] suffix") + assert result is None + class TestParseResult: def test_parses_stream_json_result(self) -> None: @@ -115,17 +123,7 @@ def test_duplicate_requirement_ids_deduplicated(self) -> None: assert len(result) == 1 assert result[0].passed is True # First occurrence wins - def test_non_dict_items_filtered(self) -> None: - result = _extract_json_array('[1, {"a": 1}, "str"]') - assert result == [{"a": 1}] - - def test_bracket_search_invalid_json(self) -> None: - result = _extract_json_array("prefix [not valid json] suffix") - assert result is None - - class TestHaikuSubprocessEvaluator: - @pytest.mark.asyncio() @patch("asyncio.create_subprocess_exec") async def test_call_haiku_success(self, mock_exec: AsyncMock) -> None: verdicts = 
[{"requirement_id": "r1", "passed": True, "explanation": "OK"}] @@ -142,7 +140,6 @@ async def test_call_haiku_success(self, mock_exec: AsyncMock) -> None: assert len(result) == 1 assert result[0].passed is True - @pytest.mark.asyncio() @patch("asyncio.create_subprocess_exec") async def test_call_haiku_failure_raises(self, mock_exec: AsyncMock) -> None: mock_proc = AsyncMock() @@ -155,7 +152,6 @@ async def test_call_haiku_failure_raises(self, mock_exec: AsyncMock) -> None: with pytest.raises(RuntimeError, match="Haiku subprocess failed"): await evaluator.evaluate(reqs, "shell", {"command": "ls"}) - @pytest.mark.asyncio() async def test_empty_requirements_returns_empty(self) -> None: evaluator = HaikuSubprocessEvaluator() result = await evaluator.evaluate({}, "shell", {"command": "ls"}) From 012f0afd8c2315facba7af912819568e312350a1 Mon Sep 17 00:00:00 2001 From: Noah Horton Date: Wed, 15 Apr 2026 16:49:39 -0600 Subject: [PATCH 7/7] fix: address round-4 review findings - test_tool_requirements_hook.py: move import to module level (DRY) - test_evaluator.py: add missing blank line between classes (E302) Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/unit/test_tool_requirements_hook.py | 15 ++------------- tests/unit/tool_requirements/test_evaluator.py | 1 + 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/tests/unit/test_tool_requirements_hook.py b/tests/unit/test_tool_requirements_hook.py index 680333b1..02eafae9 100644 --- a/tests/unit/test_tool_requirements_hook.py +++ b/tests/unit/test_tool_requirements_hook.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch +from deepwork.hooks.tool_requirements import tool_requirements_hook from deepwork.hooks.wrapper import HookInput, NormalizedEvent, Platform @@ -29,8 +30,6 @@ def _make_input( ) def test_skips_non_before_tool_events(self) -> None: - from deepwork.hooks.tool_requirements import tool_requirements_hook - hook_input = HookInput( platform=Platform.CLAUDE, 
event=NormalizedEvent.AFTER_TOOL, @@ -42,8 +41,6 @@ def test_skips_non_before_tool_events(self) -> None: # THIS TEST VALIDATES A HARD REQUIREMENT (DW-REQ-012.9). # YOU MUST NOT MODIFY THIS TEST UNLESS THE REQUIREMENT CHANGES def test_loop_prevention_skips_all_appeal_prefixes(self) -> None: - from deepwork.hooks.tool_requirements import tool_requirements_hook - for prefix in [ "mcp__deepwork__appeal_tool_requirement", "mcp__deepwork-dev__appeal_tool_requirement", @@ -53,12 +50,10 @@ def test_loop_prevention_skips_all_appeal_prefixes(self) -> None: result = tool_requirements_hook(hook_input) assert result.decision == "", f"Failed for {prefix}" - @patch("deepwork.hooks.tool_requirements.discover_sidecar") # THIS TEST VALIDATES A HARD REQUIREMENT (DW-REQ-012.9). # YOU MUST NOT MODIFY THIS TEST UNLESS THE REQUIREMENT CHANGES + @patch("deepwork.hooks.tool_requirements.discover_sidecar") def test_fail_closed_when_no_sidecar(self, mock_discover: MagicMock) -> None: - from deepwork.hooks.tool_requirements import tool_requirements_hook - mock_discover.return_value = None hook_input = self._make_input() result = tool_requirements_hook(hook_input) @@ -70,8 +65,6 @@ def test_fail_closed_when_no_sidecar(self, mock_discover: MagicMock) -> None: @patch("deepwork.hooks.tool_requirements._http_post") @patch("deepwork.hooks.tool_requirements.discover_sidecar") def test_allow_on_sidecar_allow(self, mock_discover: MagicMock, mock_post: MagicMock) -> None: - from deepwork.hooks.tool_requirements import tool_requirements_hook - mock_discover.return_value = {"pid": 123, "port": 9999} mock_post.return_value = {"decision": "allow", "reason": "OK"} @@ -83,8 +76,6 @@ def test_allow_on_sidecar_allow(self, mock_discover: MagicMock, mock_post: Magic @patch("deepwork.hooks.tool_requirements._http_post") @patch("deepwork.hooks.tool_requirements.discover_sidecar") def test_deny_on_sidecar_deny(self, mock_discover: MagicMock, mock_post: MagicMock) -> None: - from deepwork.hooks.tool_requirements 
import tool_requirements_hook - mock_discover.return_value = {"pid": 123, "port": 9999} mock_post.return_value = { "decision": "deny", @@ -105,8 +96,6 @@ def test_deny_on_sidecar_deny(self, mock_discover: MagicMock, mock_post: MagicMo def test_fail_closed_on_connection_error( self, mock_discover: MagicMock, mock_post: MagicMock ) -> None: - from deepwork.hooks.tool_requirements import tool_requirements_hook - mock_discover.return_value = {"pid": 123, "port": 9999} mock_post.side_effect = ConnectionRefusedError("Connection refused") diff --git a/tests/unit/tool_requirements/test_evaluator.py b/tests/unit/tool_requirements/test_evaluator.py index cdf50144..084e1fba 100644 --- a/tests/unit/tool_requirements/test_evaluator.py +++ b/tests/unit/tool_requirements/test_evaluator.py @@ -123,6 +123,7 @@ def test_duplicate_requirement_ids_deduplicated(self) -> None: assert len(result) == 1 assert result[0].passed is True # First occurrence wins + class TestHaikuSubprocessEvaluator: @patch("asyncio.create_subprocess_exec") async def test_call_haiku_success(self, mock_exec: AsyncMock) -> None: