diff --git a/README.md b/README.md index 052f783..37a67ee 100644 --- a/README.md +++ b/README.md @@ -1,237 +1,159 @@

- Dhee - the context firewall for AI coding agents + Dhee - context compiler for AI coding agents

-

- Dhee decides what your agent should see, remember, forget, compress, and expand each turn. -

+

Dhee

+ +

+ Local-first context compiler, supervisor, and proof layer for AI coding agents. +

PyPI - Production ready Python 3.9+ MIT License LongMemEval R@1 94.8%

- Why | - Dhee UI | - Install | - How it works | - Integrations | - Benchmarks | - FAQ | - Security + Why | + Quick Start | + How It Works | + Protected Mode | + Update Capsules

--- -## Why Dhee +## Why -Coding agents do not usually fail because the model is too weak. They fail -because context gets messy: +AI coding agents do not fail only because the model is weak. They fail because +the working context gets messy: stale decisions, huge logs, repeated file reads, +lost handoffs, private memory mixed with repo context, and unverified edits. -- They reread the same files and logs. -- They carry stale decisions after task pivots. -- They dump huge test output into the model. -- They forget state after compaction or handoff. -- Teams cannot reuse what one agent learned without copying prompt sludge. - -Dhee runs locally beside your coding agent and governs context before it becomes -a token problem. +Dhee sits beside the agent and turns that mess into a compact, auditable +working state. | Without Dhee | With Dhee | | --- | --- | -| Raw logs, diffs, files, and subagent output flood context. | Large outputs become compact digests with expandable evidence pointers. | -| The agent guesses what still matters after compaction. | Dhee keeps a current state card: goal, facts, decisions, files, tests, next step. | -| Team knowledge lives in random transcripts and markdown files. | Promoted learnings and repo context are reusable across agents with provenance. | -| Memory grows forever. | Dhee scores, decays, tombstones, and gates what gets injected. | -| Switching agents means re-explaining the project. | Claude Code, Codex, Cursor, Gemini CLI, Aider, Cline, Hermes, and MCP clients share one local context layer. | - -The promise is simple: - -> Your agent should not see everything. It should see the right thing, with proof. - ---- - -## Dhee UI - -Run the local Dhee workspace UI. It needs no API key and no connected agent: - -```bash -dhee ui -``` - -

- -

- -

- Watch the 13-second UI demo -

- -The UI opens on a command center, then lets you inspect: - -- Context Firewall: token savings, digests, evidence pointers, expansions, and session history -- Repo Brain: an infinite folders canvas for linked repos, projects, active sessions, tasks, and shared context -- Handoff Hub: resumable task state without replaying the transcript -- Proof Replay: what Dhee injected, hid, digested, expanded, promoted, or rejected -- Learning Inbox: evidence-backed candidate learnings with promote/reject actions -- Portability & Trust: signed `.dheemem` export/import readiness and dry-run inspection - -The raw evidence still stays behind `dhee_expand_result(ptr="...")`; the UI -makes the routing and expansion decisions inspectable. +| Raw files, logs, diffs, chats, and screenshots flood the prompt. | Dhee compiles ranked context cards with evidence pointers. | +| The agent makes free-form plans and hopes they are safe. | Dhee emits task contracts, allowed actions, verifier cards, and proof bundles. | +| Memory grows until it becomes noise. | Dhee admits, scores, decays, summarizes, and promotes only what survives. | +| Team updates require copying code or long prompts. | Dhee packages reproducible change stories as update capsules. | ---- +The product promise is simple: -## Install +> No untracked context. No unproven edit. No repeated preventable failure. -One command: +## Quick Start ```bash curl -fsSL https://raw.githubusercontent.com/Sankhya-AI/Dhee/main/install.sh | sh +cd /path/to/repo +dhee init +dhee status ``` -The installer creates the managed local runtime, verifies the handoff bus, wires -detected harnesses, and ends with the two commands most people need first: -`cd /path/to/repo && dhee init` and `dhee ui`. - -Or via pip: +Or install from PyPI: ```bash pip install dhee dhee install ``` -Then open your coding agent in a project. Dhee auto-wires supported local -harnesses when detected and keeps its personal state under `~/.dhee`. - -Useful first commands: +Open the local workspace: ```bash -dhee status -dhee doctor dhee ui -dhee handoff -dhee context state --card -dhee runtime status ``` -Clean uninstall is part of the trust contract: +Core Dhee supports Python 3.9+. MCP server dependencies require Python 3.10+: ```bash -dhee uninstall --yes +python3.12 -m pip install "dhee[mcp]" ``` -It stops the daemon, removes Dhee-owned harness wiring and shell PATH blocks, -and deletes the managed local runtime/data directory. - ---- - -## What You Get - -**1. Current state, not transcript replay** - -Dhee keeps a compact state card for the active task: goal, facts, decisions, -files, tests, evidence pointers, and next step. - -```bash -dhee context provision "fix expired-token KeyError" -dhee context state --card -dhee context checkpoint --reason "before compaction" -``` +## How It Works -**2. Source-side routing** +

+ Dhee flow chart +

-Heavy `Read`, `Bash`, `Grep`, and agent results are digested before they flood -the model. +Dhee compiles context like a software build: -```text -10 MB pytest log -> failing test, first error, summary, head/tail, pointer -large git diff -> files changed, hunks, additions/deletions, pointer -source file -> symbols, imports, focus lines, pointer -``` +1. Read the current task, repo, branch state, tests, memories, agents, and tool output. +2. Produce a deterministic task contract: goal, files, allowed writes, forbidden paths, tests, budget, rollback plan. +3. Supervise every action against the contract. +4. Verify the result with tests, diffs, proof bundles, and contamination checks. +5. Store only compact lessons and scene cards. Raw evidence stays behind pointers. -**3. Evidence on demand** +## Protected Mode -The model can expand raw data only when the digest is not enough: +Use protected mode when the agent is allowed to modify code: -```text -dhee_expand_result(ptr="B-demo-pytest") +```bash +dhee context task create "Fix failing context firewall tests" --repo . +dhee context task enforce deny --repo . +dhee context task activate --repo . +dhee doctor contract-runtime --repo . ``` -Expansion reasons are logged, so Dhee learns which digests need more depth. +In `deny` mode, Dhee fails closed: -**4. Git-shared repo context** +- no active contract, no coding action +- supervisor unavailable, action blocked +- corrupt runtime state, diagnostic surfaced +- proof bundle required before submit -Teams can share decisions and conventions through the repository itself: +Release gate: ```bash -dhee link /path/to/repo -dhee context check --repo /path/to/repo +dhee release check --repo . ``` -Dhee stores shared context under `/.dhee/context`, with append-only -entries and conflict detection. No hosted server or org account is required. +This refuses release tagging unless `git status` is clean. Release intent can +document scope, but it does not bypass the clean-tree rule. -**5. Portable local memory** +## Update Capsules -`.dheemem` packs move Dhee state between machines and harnesses: +Dhee can turn a completed repo change into a portable update recipe: ```bash -dhee export --format dheemem --output backup.dheemem -dhee import backup.dheemem --format dheemem --strategy dry-run +dhee context capsule create --repo . --since HEAD~1 +dhee context capsule list --repo . +dhee context capsule show --repo . ``` -Packs are signed and validated before import. - ---- - -## How It Works +Each capsule stores: -```text -Agent asks for context - | - v -Dhee reads current task state, repo context, memories, and tool output - | - v -Context firewall decides: - state -> compact current truth - proof -> pointer-backed evidence - source -> exact raw expansion only when needed - | - v -Agent sees a small, relevant, auditable packet -``` +- `capsule.md`: before/after story, behavior, tests, reproduction guide +- `capsule.json`: changed paths, hashes, compact hunks, commands, evidence refs -The core interfaces stay small: +Capsules are not raw memory dumps. Personal context is private by default and +only sanitized lessons can be promoted into shareable repo context. -```python -from dhee import Dhee +## Memory Layer -d = Dhee() -d.remember("User prefers FastAPI over Flask") -d.recall("what framework does this project use?") -d.context("fixing the auth bug") -d.checkpoint("Fixed auth bug", what_worked="checked logs", outcome_score=1.0) -``` +Dhee memory is designed for long-lived developer work: -Every surface uses the same primitives: CLI, Python SDK, Claude Code hooks, -Codex session sync, and MCP tools. +- temporal scenes from noisy evidence +- hot, warm, and cold tiers +- pointer-backed artifacts, transcripts, screenshots, media, and future wearable streams +- provenance fields for user, agent, app, event, run, memory type, and privacy scope +- context packs that fit a hard token budget ---- +For passive capture, Dhee rejects low-quality UI noise and stores searchable +derivatives instead of raw prompt-heavy media. ## Integrations -| Surface | Dhee support | +| Surface | Support | | --- | --- | -| Claude Code | Deepest integration: hooks, MCP, handoff, shared tasks, router enforcement. | -| Codex | MCP config, global `AGENTS.md`, server instructions, and session-stream sync. | -| Cursor / Gemini CLI / Cline / Goose | MCP-first integration through `dhee-mcp`. | -| Hermes | Native MemoryProvider, learning import, promotion, and playbook exchange. | -| Aider / other CLIs | CLI, MCP, repo context, and portable `.dheemem` flows. | +| Claude Code | hooks, MCP, handoff, shared tasks, router enforcement | +| Codex | MCP config, `AGENTS.md`, server instructions, session-log sync | +| Cursor, Cline, Gemini CLI, Goose | MCP-first integration | +| Hermes | MemoryProvider, learning import, promotion, playbook exchange | +| Git | repo context, update capsules, conflict checks | MCP config: @@ -243,11 +165,17 @@ MCP config: } ``` -Codex note: Codex does not expose Claude-style pre-tool hooks. Dhee uses the -strongest truthful Codex surfaces available: MCP, `AGENTS.md`, config, server -instructions, and session-log sync. +## Useful Commands ---- +```bash +dhee handoff +dhee context state --card +dhee context checkpoint --reason "before compaction" +dhee context check --repo . +dhee doctor +dhee export --format dheemem --output backup.dheemem +dhee import backup.dheemem --format dheemem --strategy dry-run +``` ## Benchmarks @@ -256,89 +184,29 @@ Dhee reports LongMemEval retrieval results on the full 500-question set: | System | R@1 | R@5 | R@10 | | --- | ---: | ---: | ---: | | Dhee | 94.8% | 99.4% | 99.8% | -| MemPalace raw | - | 96.6% | - | -| MemPalace hybrid v4 | - | 98.4% | - | | agentmemory | - | 95.2% | 98.6% | +| MemPalace hybrid v4 | - | 98.4% | - | -Stack: NVIDIA `llama-nemotron-embed-vl-1b-v2` embedder plus -`llama-3.2-nv-rerankqa-1b-v2` reranker, top-k 10. - -The proof is committed under [`benchmarks/longmemeval/`](benchmarks/longmemeval/): -commands, metrics, and per-question output. - -Retrieval is only one piece. Dhee's stronger claim is context governance: -controlling what reaches the model before memory retrieval becomes prompt -pollution. - ---- - -## Public Core and Paid Layer - -Public Dhee is MIT and complete for local developer use: memory, router, -handoff, DheeFS, MCP, repo context, `.dheemem`, runtime, security checks, and -replay/report data. - -A paid team layer can sit on top for company needs: org dashboards, policy, -audit, SSO/RBAC, fleet health, billing, and governance workflows. The local -developer brain stays useful without it. - ---- - -## FAQ - -**Is Dhee another memory database?** - -No. Memory is part of Dhee, but the wedge is context governance: deciding what -the model sees now, what stays hidden behind proof pointers, and what should be -forgotten or tombstoned. - -**Does it require a server?** - -No. Dhee is local-first and uses SQLite by default. Repo-shared context uses git. - -**Does it store secrets in the repo?** - -It should not. Repo-shared context is meant for decisions and conventions, not -secrets or bulk private data. See [`SECURITY.md`](SECURITY.md). - -**Can I inspect or export my data?** - -Yes. Dhee exposes local shell/MCP surfaces and signed `.dheemem` export/import. -Clean uninstall is supported. - -**Which agent should I use it with first?** - -Claude Code gets the deepest routing integration. Codex gets the best available -MCP/session-sync integration. Cursor, Gemini CLI, Cline, Goose, and others work -through MCP. +Proof and commands live in [`benchmarks/longmemeval/`](benchmarks/longmemeval/). ---- - -## Contributing +## Develop ```bash git clone https://github.com/Sankhya-AI/Dhee.git cd Dhee -./scripts/bootstrap_dev_env.sh -source .venv-dhee/bin/activate +pip install -e ".[dev]" pytest ``` -Full verification: +Full release check: ```bash -./scripts/verify_full_suite.sh +python3 -m compileall -q dhee tests +python3 -m pytest -q +python3 -m build +dhee release check --repo . ``` ---- - -

- Your agent stops drowning in context. -

- GitHub | - PyPI | - Issues | - Sankhya AI -

+## License -

MIT License - built by Sankhya AI Labs.

+MIT. Built by Sankhya AI Labs. diff --git a/dhee/__init__.py b/dhee/__init__.py index 7531633..dc3d957 100644 --- a/dhee/__init__.py +++ b/dhee/__init__.py @@ -30,6 +30,12 @@ from dhee.core.category import CategoryProcessor, Category, CategoryType, CategoryMatch from dhee.core.echo import EchoProcessor, EchoDepth, EchoResult from dhee.configs.base import MemoryConfig, FadeMemConfig, EchoMemConfig, CategoryMemConfig, ScopeConfig +from dhee.memory.admission import ( + MemoryAdmissionDecision, + evaluate_memory_candidate, + forget_reason_for_memory, + sanitize_admitted_content, +) # Default: CoreMemory (lightest, zero-config) Memory = CoreMemory @@ -65,4 +71,9 @@ "EchoMemConfig", "CategoryMemConfig", "ScopeConfig", + # Memory admission + "MemoryAdmissionDecision", + "evaluate_memory_candidate", + "forget_reason_for_memory", + "sanitize_admitted_content", ] diff --git a/dhee/cli.py b/dhee/cli.py index 7d788a6..b06e421 100644 --- a/dhee/cli.py +++ b/dhee/cli.py @@ -13,6 +13,7 @@ dhee why Explain why a memory or artifact exists dhee handoff Emit structured resume JSON for a new harness/agent dhee harness status Show native Claude Code / Codex integration state + dhee release check Block release tagging unless repo state is clean dhee demo token-router Show how Dhee keeps raw tool output behind pointers dhee ui Open the local Dhee dashboard dhee benchmark Run performance benchmarks @@ -497,6 +498,407 @@ def _store() -> ContextStateStore: print(f"Rollover required: {'yes' if data['rollover_required'] else 'no'}") return + if action == "task": + from dhee import task_contracts + + subaction = args.entry_id or "list" + extra = list(getattr(args, "context_args", []) or []) + task_actions = { + "compile", + "create", + "list", + "show", + "import", + "interpret", + "supervise", + "observe", + "proof", + "proof-bundle", + "activate", + "deactivate", + "enforce", + "status", + "runtime", + } + if subaction not in task_actions: + extra = [subaction, *extra] + subaction = "compile" + if subaction == "compile": + goal = " ".join(str(item) for item in extra).strip() + if not goal: + print("Pass a task goal: dhee context task compile \"Fix failing context firewall tests\"") + sys.exit(1) + try: + result = task_contracts.compile_task_contract( + goal, + repo=args.repo or os.getcwd(), + mode=getattr(args, "mode", None) or "patch", + risk=getattr(args, "risk", None), + allowed_write_paths=getattr(args, "allowed_write_paths", None), + forbidden_paths=getattr(args, "forbidden_paths", None), + must_run=getattr(args, "must_run", None), + ) + except Exception as exc: + if args.json: + _json_out({"error": str(exc)}) + else: + print(f"task contract compile failed: {exc}") + sys.exit(1) + if args.json: + _json_out(result) + return + print(task_contracts.render_task_contract(result)) + return + if subaction == "create": + goal = " ".join(str(item) for item in extra).strip() + if not goal: + print("Pass a task goal: dhee context task create \"Fix failing context firewall tests\"") + sys.exit(1) + try: + result = task_contracts.create_task_contract( + goal, + repo=args.repo or os.getcwd(), + out=getattr(args, "out", None), + mode=getattr(args, "mode", None) or "patch", + risk=getattr(args, "risk", None), + allowed_write_paths=getattr(args, "allowed_write_paths", None), + forbidden_paths=getattr(args, "forbidden_paths", None), + must_run=getattr(args, "must_run", None), + ) + except Exception as exc: + if args.json: + _json_out({"error": str(exc)}) + else: + print(f"task contract create failed: {exc}") + sys.exit(1) + if args.json: + _json_out(result) + return + contract = result.get("contract") or {} + paths = result.get("paths") or {} + print(f"Created task contract {contract.get('task_id')}") + print(f" markdown {paths.get('markdown')}") + print(f" json {paths.get('json')}") + return + if subaction == "list": + contracts = task_contracts.list_task_contracts(repo=args.repo or os.getcwd()) + if args.json: + _json_out(contracts) + return + if not contracts: + print("No task contracts found.") + return + for item in contracts: + print(f" [{str(item.get('task_id') or '')[:22]}] {item.get('risk') or '?':>6} {item.get('goal') or '(untitled)'}") + return + if subaction == "show": + if not extra: + print("Pass a task id: dhee context task show ") + sys.exit(1) + try: + result = task_contracts.get_task_contract(extra[0], repo=args.repo or os.getcwd()) + except Exception as exc: + if args.json: + _json_out({"error": str(exc)}) + else: + print(f"task contract show failed: {exc}") + sys.exit(1) + if args.json: + _json_out(result) + return + print(result.get("markdown") or json.dumps(result.get("compiled") or {}, indent=2, default=str)) + return + if subaction == "import": + if not extra: + print("Pass a task contract path: dhee context task import ") + sys.exit(1) + try: + result = task_contracts.import_task_contract(extra[0], repo=args.repo or os.getcwd()) + except Exception as exc: + if args.json: + _json_out({"error": str(exc)}) + else: + print(f"task contract import failed: {exc}") + sys.exit(1) + if args.json: + _json_out(result) + return + contract = result.get("contract") or {} + paths = result.get("paths") or {} + print(f"Imported task contract {contract.get('task_id')}") + print(f" dir {paths.get('dir')}") + return + if subaction == "interpret": + if not extra: + print("Pass a task id or path: dhee context task interpret ") + sys.exit(1) + result = task_contracts.interpret_task_contract( + extra[0], + repo=args.repo or os.getcwd(), + strict=bool(getattr(args, "strict", False)), + ) + if args.json: + _json_out(result) + return + print(f"Readiness: {result.get('readiness')}") + for step in result.get("execution_plan") or []: + print(f" {step.get('step')}. {step.get('state')} {step.get('type')} {step.get('target') or ''}") + diagnostics = result.get("diagnostics") or [] + if diagnostics: + print("\nDiagnostics:") + for diag in diagnostics: + print(f" [{diag.get('level')}] {diag.get('code')}: {diag.get('message')}") + return + if subaction == "supervise": + if not extra: + print("Pass a task id or path: dhee context task supervise --action-json '{...}'") + sys.exit(1) + if not getattr(args, "action_json", None): + print("Pass --action-json for the proposed ChotuAction") + sys.exit(1) + from dhee.contract_supervisor import supervise_action + + result = supervise_action( + extra[0], + json.loads(args.action_json), + repo=args.repo or os.getcwd(), + strict=bool(getattr(args, "strict", False)), + ) + if args.json: + _json_out(result) + return + print(f"Decision: {result.get('decision')}") + for violation in result.get("violations") or []: + print(f" {violation.get('code')}: {violation.get('message')}") + return + if subaction == "observe": + if not extra: + print("Pass a task id or path: dhee context task observe --action-json '{...}' --observation '...'") + sys.exit(1) + if not getattr(args, "action_json", None): + print("Pass --action-json for the observed ChotuAction") + sys.exit(1) + from dhee.contract_supervisor import record_observation_transition + + next_action = json.loads(args.next_action_json) if getattr(args, "next_action_json", None) else None + result = record_observation_transition( + extra[0], + json.loads(args.action_json), + getattr(args, "observation", None) or "", + repo=args.repo or os.getcwd(), + outcome=getattr(args, "outcome", None) or "observed", + next_action=next_action, + strict=bool(getattr(args, "strict", False)), + ) + if args.json: + _json_out(result) + return + print(f"Recorded observation {result.get('event', {}).get('event_id')}") + print(f" decision {result.get('decision', {}).get('decision')}") + print(f" events {result.get('paths', {}).get('events')}") + return + if subaction in {"proof", "proof-bundle"}: + if not extra: + print("Pass a task id or path: dhee context task proof ") + sys.exit(1) + from dhee.contract_supervisor import build_proof_bundle + + result = build_proof_bundle( + extra[0], + repo=args.repo or os.getcwd(), + strict=bool(getattr(args, "strict", False)), + persist=not bool(getattr(args, "dry_run", False)), + ) + if args.json: + _json_out(result) + return + bundle = result.get("proof_bundle") or {} + verifier = bundle.get("verifier_result") or {} + print(f"Proof bundle {bundle.get('contract_id')}") + print(f" verifier {verifier.get('status')}") + print(f" tests {len(verifier.get('passed_tests') or [])}/{len(verifier.get('required_tests') or [])} passed") + if result.get("paths", {}).get("proof_bundle"): + print(f" path {result['paths']['proof_bundle']}") + for path in verifier.get("out_of_contract_changed_paths") or []: + print(f" out-of-contract change: {path}") + for path in verifier.get("forbidden_changed_paths") or []: + print(f" forbidden change: {path}") + return + if subaction == "enforce": + from dhee.contract_runtime import set_contract_enforcement + + selected_mode = (extra[0] if extra else getattr(args, "mode", None) or "").strip().lower() + if selected_mode not in {"off", "warn", "deny"}: + print("Pass an enforcement mode: dhee context task enforce off|warn|deny") + sys.exit(1) + result = set_contract_enforcement( + selected_mode, + repo=args.repo or os.getcwd(), + agent_id="cli", + reason=getattr(args, "reason", None) or "manual", + ) + if args.json: + _json_out(result) + return + effective = result.get("effective") or {} + print(f"Contract enforcement set to {selected_mode}") + print(f" effective {effective.get('mode')}") + print(f" policy {(result.get('paths') or {}).get('policy')}") + return + if subaction == "activate": + if not extra: + print("Pass a task id or path: dhee context task activate ") + sys.exit(1) + from dhee.contract_runtime import activate_contract_runtime + + result = activate_contract_runtime( + extra[0], + repo=args.repo or os.getcwd(), + strict=bool(getattr(args, "strict", False)), + force=bool(getattr(args, "force", False)), + agent_id="cli", + harness="cli", + ) + if args.json: + _json_out(result) + return + if not result.get("active"): + print(f"Task contract activation rejected: {result.get('reason') or result.get('status')}") + for diag in ((result.get("interpretation") or {}).get("diagnostics") or []): + print(f" [{diag.get('level')}] {diag.get('code')}: {diag.get('message')}") + sys.exit(1) + print(f"Activated task contract {result.get('task_id')}") + print(f" active {(result.get('paths') or {}).get('active')}") + print(f" events {(result.get('paths') or {}).get('events')}") + print(" router enforcing dhee_read/dhee_grep/dhee_bash") + return + if subaction in {"status", "runtime"}: + from dhee.contract_runtime import contract_runtime_status + + result = contract_runtime_status(repo=args.repo or os.getcwd()) + if args.json: + _json_out(result) + return + if not result.get("active"): + print("No active task contract runtime.") + return + print(f"Active task contract {result.get('task_id')}") + print(f" status {result.get('status')}") + print(f" strict {result.get('strict')}") + print(f" active {(result.get('paths') or {}).get('active')}") + print(f" events {(result.get('paths') or {}).get('events')}") + print(f" readiness {(result.get('interpretation') or {}).get('readiness')}") + return + if subaction == "deactivate": + from dhee.contract_runtime import deactivate_contract_runtime + + result = deactivate_contract_runtime( + repo=args.repo or os.getcwd(), + agent_id="cli", + reason=getattr(args, "reason", None) or "manual", + ) + if args.json: + _json_out(result) + return + if result.get("task_id"): + print(f"Deactivated task contract {result.get('task_id')}") + else: + print("No active task contract runtime.") + return + print("Use: dhee context task compile|create|list|show|import|interpret|supervise|observe|proof|enforce|activate|status|deactivate") + sys.exit(1) + + if action == "capsule": + from dhee import update_capsules + + subaction = args.entry_id or "list" + extra = list(getattr(args, "context_args", []) or []) + repo = args.repo or os.getcwd() + try: + if subaction == "create": + result = update_capsules.create_update_capsule( + repo=repo, + since=getattr(args, "since", None), + task_id=getattr(args, "task_id", None), + out=getattr(args, "out", None), + ) + if args.json: + _json_out(result) + return + capsule = result.get("capsule") or {} + paths = result.get("paths") or {} + entry = result.get("entry") or {} + print(f"Created update capsule {capsule.get('id')}") + print(f" markdown {paths.get('markdown')}") + print(f" json {paths.get('json')}") + print(f" entry {entry.get('id')}") + return + if subaction == "list": + capsules = update_capsules.list_update_capsules(repo=repo) + if args.json: + _json_out(capsules) + return + if not capsules: + print("No update capsules found.") + return + for capsule in capsules: + changed = len(capsule.get("changed_paths") or []) + print(f" [{str(capsule.get('id') or '')[:18]}] {changed:>2} paths {capsule.get('title') or '(untitled)'}") + return + if subaction == "show": + if not extra: + print("Pass a capsule id: dhee context capsule show ") + sys.exit(1) + result = update_capsules.get_update_capsule(extra[0], repo=repo) + if args.json: + _json_out(result) + return + print(result.get("markdown") or json.dumps(result.get("capsule") or {}, indent=2, default=str)) + return + if subaction == "import": + if not extra: + print("Pass a capsule path: dhee context capsule import ") + sys.exit(1) + result = update_capsules.import_update_capsule(extra[0], repo=repo) + if args.json: + _json_out(result) + return + capsule = result.get("capsule") or {} + paths = result.get("paths") or {} + print(f"Imported update capsule {capsule.get('id')}") + print(f" dir {paths.get('dir')}") + return + if subaction == "interpret": + if not extra: + print("Pass a capsule id or path: dhee context capsule interpret ") + sys.exit(1) + result = update_capsules.interpret_update_capsule( + extra[0], + repo=repo, + strict=bool(getattr(args, "strict", False)), + ) + if args.json: + _json_out(result) + return + print(f"Readiness: {result.get('readiness')}") + for step in result.get("execution_plan") or []: + print(f" {step.get('step')}. {step.get('state')} {step.get('action')} {step.get('path')}") + print(f" {step.get('instruction')}") + diagnostics = result.get("diagnostics") or [] + if diagnostics: + print("\nDiagnostics:") + for diag in diagnostics: + print(f" [{diag.get('level')}] {diag.get('code')}: {diag.get('message')}") + return + except Exception as exc: + if args.json: + _json_out({"error": str(exc)}) + else: + print(f"capsule {subaction} failed: {exc}") + sys.exit(1) + print("Use: dhee context capsule create|list|show|import|interpret") + sys.exit(1) + from dhee import repo_link repo_root = repo_link._resolve_repo(args.repo) if args.repo else repo_link.repo_for_path(os.getcwd()) @@ -1396,11 +1798,91 @@ def cmd_status(args: argparse.Namespace) -> None: def cmd_doctor(args: argparse.Namespace) -> None: """Composite observability report. No controls, just truth.""" + topic = getattr(args, "doctor_topic", None) + if topic == "contract-runtime": + from dhee.contract_runtime import contract_runtime_doctor + + result = contract_runtime_doctor(repo=getattr(args, "repo", None) or os.getcwd()) + if getattr(args, "json", False): + _json_out(result) + return + print(f"Dhee contract runtime: {result.get('status')}") + active = result.get("active_contract") or {} + enforcement = result.get("enforcement") or {} + router = result.get("hook_router_health") or {} + print(f" enforcement {enforcement.get('mode')} (configured={enforcement.get('configured_mode')})") + print(f" active {active.get('active')} {active.get('task_id') or ''}") + print(f" router {'enabled' if router.get('enabled') else 'not enabled'}") + risks = result.get("bypass_risks") or [] + if risks: + print(" risks " + ", ".join(str(item) for item in risks)) + corrupt = result.get("corrupt_files") or [] + if corrupt: + print(" corrupt " + ", ".join(str(item.get("path") or item.get("code")) for item in corrupt)) + return from dhee.doctor import run sys.stdout.write(run(as_json=bool(getattr(args, "json", False)))) +def cmd_release(args: argparse.Namespace) -> None: + """Release hygiene gates for premium builds.""" + from dhee.release_hygiene import ( + format_release_check, + release_check, + write_release_intent, + ) + + action = getattr(args, "release_action", None) or "check" + repo = getattr(args, "repo", None) or os.getcwd() + if action == "intent": + paths = list(getattr(args, "paths", None) or []) + if not paths: + print("Pass at least one --path for the intended release scope.", file=sys.stderr) + sys.exit(2) + result = write_release_intent( + repo, + paths, + reason=getattr(args, "reason", None) or "", + agent_id="cli", + ) + if getattr(args, "json", False): + _json_out(result) + if not result.get("ok"): + sys.exit(1) + return + if not result.get("ok"): + diagnostics = result.get("diagnostics") or [] + print("Release intent write failed.", file=sys.stderr) + for diag in diagnostics: + print(f" {diag.get('code')}: {diag.get('message')}", file=sys.stderr) + sys.exit(1) + intent = result.get("intent") or {} + print(f"Release intent written: {result.get('path')}") + print(f" paths {', '.join(intent.get('intended_paths') or [])}") + if intent.get("reason"): + print(f" reason {intent.get('reason')}") + print(" note intent documents scope; release still requires a clean git tree") + return + + if action == "check": + report = release_check( + repo, + intended_paths=getattr(args, "intended_paths", None), + require_clean=True, + ) + if getattr(args, "json", False): + _json_out(report) + else: + print(format_release_check(report)) + if not report.get("release_allowed") and not getattr(args, "no_fail", False): + sys.exit(1) + return + + print("Use: dhee release check|intent") + sys.exit(1) + + def cmd_runtime(args: argparse.Namespace) -> None: """Inspect or manage the local Dhee runtime daemon.""" from dhee import runtime @@ -2731,12 +3213,28 @@ def build_parser() -> argparse.ArgumentParser: p_context.add_argument( "context_action", nargs="?", - choices=["list", "show", "delete", "refresh", "check", "status", "state", "checkpoint", "rollover", "provision", "debt"], + choices=["list", "show", "delete", "refresh", "check", "status", "state", "checkpoint", "rollover", "provision", "debt", "capsule", "task"], default="list", help="Subcommand (default: list)", ) - p_context.add_argument("entry_id", nargs="?", help="Entry id for show/delete, or task text for provision") + p_context.add_argument("entry_id", nargs="?", help="Entry id for show/delete, task text for provision, or capsule/task subcommand") + p_context.add_argument("context_args", nargs="*", help=argparse.SUPPRESS) p_context.add_argument("--repo", help="Repo path (default: linked repo containing cwd)") + p_context.add_argument("--since", help="For `context capsule create`, base git ref") + p_context.add_argument("--task-id", help="For `context capsule create`, attach a task id") + p_context.add_argument("--out", help="For `context capsule create`, write to a specific capsule directory") + p_context.add_argument("--strict", action="store_true", help="For `context capsule interpret`, treat precondition mismatch as blocking") + p_context.add_argument("--force", action="store_true", help="For `context task activate`, activate even when interpretation is blocked") + p_context.add_argument("--dry-run", action="store_true", help="For `context task proof`, build the proof bundle without writing it") + p_context.add_argument("--mode", default="patch", help="For `context task compile`, task mode; for `context task enforce`, off|warn|deny") + p_context.add_argument("--risk", help="For `context task compile`, override inferred risk") + p_context.add_argument("--allowed-write-path", action="append", dest="allowed_write_paths", help="For `context task compile`, allowed write path") + p_context.add_argument("--forbidden-path", action="append", dest="forbidden_paths", help="For `context task compile`, forbidden path") + p_context.add_argument("--must-run", action="append", dest="must_run", help="For `context task compile`, required test/check command") + p_context.add_argument("--action-json", help="For `context task supervise|observe`, proposed/observed ChotuAction JSON") + p_context.add_argument("--next-action-json", help="For `context task observe`, optional next ChotuAction JSON") + p_context.add_argument("--observation", help="For `context task observe`, compact observation text") + p_context.add_argument("--outcome", default="observed", help="For `context task observe`, outcome label") p_context.add_argument("--user-id", default="default", help="User ID for compiled-state commands") p_context.add_argument("--top", action="store_true", help="Show top context-debt sources") p_context.add_argument("--card", action="store_true", help="For `context state`, print state card XML") @@ -2794,8 +3292,48 @@ def build_parser() -> argparse.ArgumentParser: "doctor", help="Composite health + honesty report (router, cognition, memory, movement plan)", ) + p_doctor.add_argument( + "doctor_topic", + nargs="?", + choices=["contract-runtime"], + help="Optional focused doctor report", + ) + p_doctor.add_argument("--repo", help="Repo path for focused doctor reports") p_doctor.add_argument("--json", action="store_true", help="JSON output") + # release — hard gate before tagging or customer release + p_release = sub.add_parser( + "release", + help="Release hygiene gate (clean tree, intended scope, honest blockers)", + ) + p_release.add_argument( + "release_action", + nargs="?", + choices=["check", "intent"], + default="check", + help="Subcommand", + ) + p_release.add_argument("--repo", help="Repo path (default: cwd)") + p_release.add_argument( + "--intended-path", + action="append", + dest="intended_paths", + help="For `release check`, expected dirty path scope for review reporting", + ) + p_release.add_argument( + "--path", + action="append", + dest="paths", + help="For `release intent`, intended release path scope", + ) + p_release.add_argument("--reason", help="For `release intent`, compact release-scope reason") + p_release.add_argument( + "--no-fail", + action="store_true", + help="For `release check`, print/report blockers but exit 0", + ) + p_release.add_argument("--json", action="store_true", help="JSON output") + # runtime — managed venv + local daemon clarity p_runtime = sub.add_parser( "runtime", @@ -3132,6 +3670,7 @@ def build_parser() -> argparse.ArgumentParser: "shared-task": cmd_shared_task, "status": cmd_status, "doctor": cmd_doctor, + "release": cmd_release, "runtime": cmd_runtime, "task": cmd_task, "ingest": cmd_ingest, diff --git a/dhee/context_ir.py b/dhee/context_ir.py new file mode 100644 index 0000000..a759b28 --- /dev/null +++ b/dhee/context_ir.py @@ -0,0 +1,825 @@ +"""Portable Context IR compiler and interpreter. + +This module turns Dhee update-capsule evidence into a small intermediate +representation (IR), then interprets that IR on another machine/repo. It +does not auto-apply patches; it validates preconditions, maps symbols to the +target checkout, and emits a reproduction plan with diagnostics. +""" + +from __future__ import annotations + +import hashlib +import json +import os +import re +import subprocess +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Tuple + + +CONTEXT_IR_SCHEMA = "dhee.context_ir.v1" +INTERPRETER_SCHEMA = "dhee.context_interpretation.v1" + +_LANG_BY_SUFFIX = { + ".py": "python", + ".js": "javascript", + ".jsx": "javascript-react", + ".ts": "typescript", + ".tsx": "typescript-react", + ".rs": "rust", + ".go": "go", + ".java": "java", + ".kt": "kotlin", + ".swift": "swift", + ".md": "markdown", + ".json": "json", + ".toml": "toml", + ".yaml": "yaml", + ".yml": "yaml", +} + +_ANCHOR_RE = re.compile( + r"^[+-]\s*(?:async\s+)?(?:def|class|function|const|let|var|export\s+function|pub\s+fn|fn)\s+([A-Za-z_][A-Za-z0-9_]*)", +) +_HEX_64_RE = re.compile(r"^[a-f0-9]{64}$") +_EXCLUDED_WALK_DIRS = { + ".git", + ".dhee", + ".hg", + ".svn", + ".mypy_cache", + ".pytest_cache", + ".ruff_cache", + ".tox", + ".venv", + "venv", + "env", + "__pycache__", + "node_modules", + "dist", + "build", + ".next", +} +_MAX_CANDIDATE_SCAN_FILES = 5_000 +_MAX_CANDIDATE_MATCHES = 64 +_MAX_ANCHOR_SCAN_BYTES = 2 * 1024 * 1024 + + +def stable_hash(data: Any, length: int = 18) -> str: + raw = json.dumps(data, sort_keys=True, default=str, separators=(",", ":")) + return hashlib.sha256(raw.encode("utf-8")).hexdigest()[:length] + + +def file_sha256(path: Path) -> Optional[str]: + try: + if not path.exists() or not path.is_file(): + return None + h = hashlib.sha256() + with path.open("rb") as fh: + for chunk in iter(lambda: fh.read(1024 * 1024), b""): + h.update(chunk) + return h.hexdigest() + except OSError: + return None + + +def _language_for(path: str) -> str: + return _LANG_BY_SUFFIX.get(Path(path).suffix.lower(), "text") + + +def _op_for_status(status: str) -> str: + status = str(status or "").lower() + if status in {"added", "untracked"}: + return "create_file" + if status == "deleted": + return "delete_file" + if status == "renamed": + return "rename_or_modify_file" + return "modify_file" + + +def _anchors_from_diff(diff: str) -> List[str]: + anchors: List[str] = [] + for line in (diff or "").splitlines(): + if line.startswith(("+++", "---")): + continue + if line.startswith("@@"): + tail = line.split("@@", 2)[-1].strip() + if tail and tail not in anchors: + anchors.append(tail[:120]) + continue + match = _ANCHOR_RE.match(line) + if match: + name = match.group(1) + if name not in anchors: + anchors.append(name) + if len(anchors) >= 12: + break + return anchors + + +def _hunk_by_path(compact_hunks: Iterable[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: + out: Dict[str, Dict[str, Any]] = {} + for hunk in compact_hunks or []: + if isinstance(hunk, dict) and hunk.get("path"): + out[str(hunk["path"])] = hunk + return out + + +def build_context_ir( + *, + capsule_id: str, + title: str, + summary: str, + repo_id: str, + base_ref: str, + base_commit: str, + head_commit: str, + changed_paths: List[Dict[str, Any]], + compact_hunks: List[Dict[str, Any]], + commands: List[str], + evidence_pointers: List[Dict[str, Any]], + base_file_hashes: Dict[str, str], + file_hashes: Dict[str, str], + privacy: Dict[str, Any], + diagnostics: Optional[List[Dict[str, Any]]] = None, +) -> Dict[str, Any]: + """Compile capsule metadata into portable Context IR.""" + + hunks = _hunk_by_path(compact_hunks) + file_symbols: List[Dict[str, Any]] = [] + operations: List[Dict[str, Any]] = [] + ir_diagnostics = list(diagnostics or []) + + for index, item in enumerate(changed_paths or []): + rel_path = str(item.get("path") or "") + if not rel_path: + continue + status = str(item.get("status") or "changed") + symbol = f"file:{rel_path}" + hunk = hunks.get(rel_path, {}) + diff = str(hunk.get("diff") or "") + anchors = _anchors_from_diff(diff) + before_hash = base_file_hashes.get(rel_path) + after_hash = file_hashes.get(rel_path) or item.get("sha256") + file_symbol = { + "symbol": symbol, + "path": rel_path, + "status": status, + "language": _language_for(rel_path), + "before_sha256": before_hash, + "after_sha256": after_hash, + "size": item.get("size"), + "anchors": anchors, + "hunk_ref": f"hunk:{index}", + } + file_symbols.append(file_symbol) + + preconditions: List[Dict[str, Any]] = [] + if before_hash: + preconditions.append({"kind": "sha256", "path": rel_path, "equals": before_hash}) + elif status not in {"added", "untracked"}: + ir_diagnostics.append({ + "level": "warning", + "code": "MISSING_BASE_HASH", + "path": rel_path, + "message": "Compiler could not capture a base-file hash; interpreter will rely on anchors and path existence.", + }) + postconditions: List[Dict[str, Any]] = [] + if after_hash: + postconditions.append({"kind": "sha256", "path": rel_path, "equals": after_hash}) + else: + postconditions.append({"kind": "path_absent", "path": rel_path}) + + if status == "untracked" and not diff: + ir_diagnostics.append({ + "level": "warning", + "code": "UNTRACKED_NO_DIFF", + "path": rel_path, + "message": "Untracked file has a hash but no git diff body; receiving agent must reconstruct from intent or source context.", + }) + if "" in diff: + ir_diagnostics.append({ + "level": "warning", + "code": "SECRET_REDACTED_IN_DIFF", + "path": rel_path, + "message": "A secret-like value was redacted from the compact diff; receiving agent must supply its own local secret/config.", + }) + + operations.append({ + "op": _op_for_status(status), + "target": symbol, + "path": rel_path, + "hunk_ref": f"hunk:{index}", + "anchor_hints": anchors, + "preconditions": preconditions, + "postconditions": postconditions, + "recipe": { + "kind": "compact_diff_recipe", + "diff_sha256": hashlib.sha256(diff.encode("utf-8")).hexdigest() if diff else None, + "has_diff_body": bool(diff), + "truncated": bool(hunk.get("truncated")), + "status": status, + }, + }) + + command_symbols = [ + { + "symbol": f"cmd:{index}", + "command": command, + "kind": "verification" if any(word in command for word in ("test", "pytest", "npm run build", "cargo test")) else "context", + } + for index, command in enumerate(commands or []) + ] + evidence_symbols = [ + { + "symbol": f"evidence:{index}", + "kind": pointer.get("kind"), + "ref": pointer.get("ref"), + "label": pointer.get("label"), + "source_app": pointer.get("source_app"), + "agent_id": pointer.get("agent_id"), + "confidentiality_scope": pointer.get("confidentiality_scope"), + } + for index, pointer in enumerate(evidence_pointers or []) + ] + + ir = { + "schema_version": CONTEXT_IR_SCHEMA, + "compiler": { + "name": "dhee-context-compiler", + "version": 1, + "phases": [ + "source_collection", + "privacy_gate", + "symbol_table", + "operation_ir", + "diagnostics", + "verification_plan", + ], + }, + "module": { + "id": capsule_id, + "title": title, + "intent": summary, + "source_repo": { + "repo_id": repo_id, + "base_ref": base_ref, + "base_commit": base_commit, + "head_commit": head_commit, + }, + }, + "symbol_table": { + "files": file_symbols, + "commands": command_symbols, + "evidence": evidence_symbols, + }, + "operations": operations, + "verification": { + "commands": command_symbols, + "assertions": [ + {"kind": "postcondition", "path": op["path"], "checks": op["postconditions"]} + for op in operations + ], + }, + "diagnostics": ir_diagnostics, + "privacy": dict(privacy or {}), + "semantics": { + "execution_model": "interpret_plan_not_auto_apply", + "raw_personal_memory_included": bool((privacy or {}).get("raw_personal_memory_included")), + "whole_file_snapshots_included": False, + }, + "fingerprint": stable_hash({ + "capsule_id": capsule_id, + "files": file_symbols, + "operations": operations, + "commands": commands, + }, 32), + } + return ir + + +def _resolve_repo_root(repo: str | os.PathLike[str] | None) -> Path: + base = Path(repo or os.getcwd()).expanduser().resolve() + proc = subprocess.run( + ["git", "-C", str(base), "rev-parse", "--show-toplevel"], + text=True, + capture_output=True, + check=False, + ) + if proc.returncode == 0 and proc.stdout.strip(): + return Path(proc.stdout.strip()).resolve() + return base + + +def _load_ir(capsule_or_ir: Dict[str, Any]) -> Dict[str, Any]: + if capsule_or_ir.get("schema_version") == CONTEXT_IR_SCHEMA: + return capsule_or_ir + ir = capsule_or_ir.get("context_ir") + if isinstance(ir, dict): + return ir + # Legacy fallback: build an interpreter-readable shell from capsule fields. + return build_context_ir( + capsule_id=str(capsule_or_ir.get("id") or stable_hash(capsule_or_ir)), + title=str(capsule_or_ir.get("title") or ""), + summary=str(capsule_or_ir.get("summary") or ""), + repo_id=str((capsule_or_ir.get("repo") or {}).get("repo_id") or capsule_or_ir.get("repo_id") or ""), + base_ref=str(capsule_or_ir.get("base_ref") or (capsule_or_ir.get("repo") or {}).get("base_ref") or ""), + base_commit=str(capsule_or_ir.get("base_commit") or (capsule_or_ir.get("repo") or {}).get("base_commit") or ""), + head_commit=str(capsule_or_ir.get("head_commit") or (capsule_or_ir.get("repo") or {}).get("head_commit") or ""), + changed_paths=list(capsule_or_ir.get("changed_paths") or []), + compact_hunks=list(capsule_or_ir.get("compact_hunks") or []), + commands=list(capsule_or_ir.get("commands") or []), + evidence_pointers=list(capsule_or_ir.get("evidence_pointers") or []), + base_file_hashes=dict(capsule_or_ir.get("base_file_hashes") or {}), + file_hashes=dict(capsule_or_ir.get("file_hashes") or {}), + privacy=dict(capsule_or_ir.get("privacy") or {}), + diagnostics=[{ + "level": "warning", + "code": "LEGACY_CAPSULE_COMPILED_ON_INTERPRET", + "message": "Capsule did not include context_ir; interpreter built a best-effort IR shell.", + }], + ) + + +def validate_context_ir(ir: Dict[str, Any], *, strict: bool = False) -> Dict[str, Any]: + """Validate Context IR before import or interpretation. + + The compiler emits a compact program. Validation checks that the program + has a known schema, a module identity, resolvable symbols, operation + contracts, and no raw private payloads. + """ + + diagnostics: List[Dict[str, Any]] = [] + + def add(level: str, code: str, message: str, **extra: Any) -> None: + diagnostics.append({"level": level, "code": code, "message": message, **extra}) + + if not isinstance(ir, dict): + add("error", "IR_NOT_OBJECT", "Context IR must be a JSON object.") + return { + "ok": False, + "schema_version": None, + "operation_count": 0, + "file_symbol_count": 0, + "diagnostics": diagnostics, + } + + schema_version = ir.get("schema_version") + if schema_version != CONTEXT_IR_SCHEMA: + add("error", "UNKNOWN_IR_SCHEMA", f"Unsupported context IR schema: {schema_version!r}") + + module = ir.get("module") + if not isinstance(module, dict): + add("error", "MISSING_MODULE", "Context IR is missing a module object.") + module = {} + if not module.get("id"): + add("error", "MISSING_MODULE_ID", "Context IR module is missing an id.") + + symbol_table = ir.get("symbol_table") + if not isinstance(symbol_table, dict): + add("error", "MISSING_SYMBOL_TABLE", "Context IR is missing a symbol_table object.") + symbol_table = {} + + file_symbols = symbol_table.get("files") or [] + if not isinstance(file_symbols, list): + add("error", "INVALID_FILE_SYMBOLS", "symbol_table.files must be a list.") + file_symbols = [] + + file_symbol_ids = set() + file_symbol_paths = set() + for index, symbol in enumerate(file_symbols): + if not isinstance(symbol, dict): + add("error", "INVALID_FILE_SYMBOL", "File symbol must be an object.", index=index) + continue + symbol_id = symbol.get("symbol") + path = symbol.get("path") + if not symbol_id: + add("error", "MISSING_FILE_SYMBOL_ID", "File symbol is missing symbol id.", index=index) + else: + file_symbol_ids.add(str(symbol_id)) + if not path: + add("error", "MISSING_FILE_SYMBOL_PATH", "File symbol is missing path.", index=index) + else: + path_text = str(path) + file_symbol_paths.add(path_text) + if Path(path_text).is_absolute() or ".." in Path(path_text).parts: + add("error", "UNSAFE_FILE_SYMBOL_PATH", "File symbol path must be repo-relative.", path=path_text) + for key in ("before_sha256", "after_sha256"): + digest = symbol.get(key) + if digest and not _HEX_64_RE.match(str(digest)): + add("error" if strict else "warning", "INVALID_FILE_HASH", f"{key} is not a sha256 hex digest.", path=path) + + operations = ir.get("operations") or [] + if not isinstance(operations, list): + add("error", "INVALID_OPERATIONS", "operations must be a list.") + operations = [] + if strict and not operations: + add("error", "NO_OPERATIONS", "Strict Context IR requires at least one operation.") + elif not operations: + add("warning", "NO_OPERATIONS", "Context IR contains no operations.") + + for index, op in enumerate(operations): + if not isinstance(op, dict): + add("error", "INVALID_OPERATION", "Operation must be an object.", index=index) + continue + action = op.get("op") + path = op.get("path") + target = op.get("target") + if action not in {"create_file", "modify_file", "delete_file", "rename_or_modify_file"}: + add("error", "UNKNOWN_OPERATION", f"Unsupported operation: {action!r}", index=index, path=path) + if not path: + add("error", "MISSING_OPERATION_PATH", "Operation is missing path.", index=index) + else: + path_text = str(path) + if Path(path_text).is_absolute() or ".." in Path(path_text).parts: + add("error", "UNSAFE_OPERATION_PATH", "Operation path must be repo-relative.", index=index, path=path_text) + if path_text not in file_symbol_paths: + add("warning", "OPERATION_PATH_NOT_IN_SYMBOL_TABLE", "Operation path has no matching file symbol.", index=index, path=path_text) + if target and file_symbol_ids and str(target) not in file_symbol_ids: + add("warning", "OPERATION_TARGET_UNRESOLVED", "Operation target is not present in the file symbol table.", index=index, target=target) + if not op.get("hunk_ref"): + add("warning", "MISSING_HUNK_REF", "Operation does not point to a compact hunk.", index=index, path=path) + for check_group in ("preconditions", "postconditions"): + checks = op.get(check_group) or [] + if not isinstance(checks, list): + add("error", "INVALID_CONDITION_GROUP", f"{check_group} must be a list.", index=index, path=path) + continue + for check_index, check in enumerate(checks): + if not isinstance(check, dict): + add("error", "INVALID_CONDITION", "Condition must be an object.", index=index, check_index=check_index) + continue + kind = check.get("kind") + if kind not in {"sha256", "path_absent"}: + add("error", "UNKNOWN_CONDITION", f"Unsupported condition kind: {kind!r}", index=index, path=path) + digest = check.get("equals") + if kind == "sha256" and (not digest or not _HEX_64_RE.match(str(digest))): + add("error", "INVALID_CONDITION_HASH", "sha256 condition must include a valid equals hash.", index=index, path=path) + + privacy = ir.get("privacy") or {} + if not isinstance(privacy, dict): + add("error", "INVALID_PRIVACY", "privacy must be an object.") + privacy = {} + if privacy.get("raw_personal_memory_included") or (ir.get("semantics") or {}).get("raw_personal_memory_included"): + add("error", "PRIVATE_BODY_PRESENT", "Context IR cannot contain raw personal-memory payloads.") + + return { + "ok": not any(item.get("level") == "error" for item in diagnostics), + "schema_version": schema_version, + "operation_count": len(operations), + "file_symbol_count": len(file_symbols), + "diagnostics": diagnostics, + } + + +def render_context_ir(ir_or_capsule: Dict[str, Any]) -> str: + """Render compact, agent-readable IR for humans and receiving agents.""" + + ir = _load_ir(ir_or_capsule) + validation = validate_context_ir(ir) + module = ir.get("module") or {} + lines = [ + f"- Schema: `{ir.get('schema_version') or '(none)'}`", + f"- Fingerprint: `{ir.get('fingerprint') or '(none)'}`", + f"- Module: `{module.get('id') or '(unknown)'}`", + f"- Validation: `{'ok' if validation['ok'] else 'failed'}` with `{len(validation['diagnostics'])}` diagnostic(s)", + f"- Operations: `{validation['operation_count']}`", + ] + for index, op in enumerate(ir.get("operations") or [], start=1): + anchors = op.get("anchor_hints") or [] + anchor_text = f"; anchors: {', '.join(str(item) for item in anchors[:3])}" if anchors else "" + lines.append( + f" {index}. `{op.get('op')}` `{op.get('path')}` via `{op.get('hunk_ref') or '(no hunk)'}`{anchor_text}" + ) + if not (ir.get("operations") or []): + lines.append(" No operations compiled.") + return "\n".join(lines) + + +def _repo_relative(repo_root: Path, path: Path) -> str: + try: + return os.path.relpath(path, repo_root).replace(os.sep, "/") + except ValueError: + return str(path) + + +def _safe_repo_path(repo_root: Path, rel_path: str) -> Optional[Path]: + if not rel_path: + return None + raw = Path(rel_path) + if raw.is_absolute() or ".." in raw.parts: + return None + try: + root = repo_root.resolve() + path = (root / raw).resolve() + if os.path.commonpath([str(root), str(path)]) != str(root): + return None + return path + except (OSError, ValueError): + return None + + +def _condition_hashes(op: Dict[str, Any], group: str) -> List[str]: + return [ + str(check.get("equals")) + for check in op.get(group) or [] + if isinstance(check, dict) and check.get("kind") == "sha256" and check.get("equals") + ] + + +def _candidate_files(repo_root: Path, rel_path: str, language: str) -> Tuple[List[Path], Dict[str, Any]]: + del language # Reserved for future grammar-aware symbol lookup. + basename = Path(rel_path).name + if not basename: + return [], {"scanned_files": 0, "truncated": False} + + candidates: List[Path] = [] + scanned = 0 + truncated = False + for root, dirnames, filenames in os.walk(repo_root): + dirnames[:] = [ + name + for name in dirnames + if name not in _EXCLUDED_WALK_DIRS and not name.endswith(".egg-info") + ] + for filename in filenames: + scanned += 1 + if scanned > _MAX_CANDIDATE_SCAN_FILES: + truncated = True + return candidates, {"scanned_files": scanned, "truncated": truncated} + if filename != basename: + continue + path = (Path(root) / filename).resolve() + if path.is_file(): + candidates.append(path) + if len(candidates) >= _MAX_CANDIDATE_MATCHES: + truncated = True + return candidates, {"scanned_files": scanned, "truncated": truncated} + return candidates, {"scanned_files": scanned, "truncated": truncated} + + +def _anchor_match_count(path: Path, anchors: Iterable[Any]) -> int: + cleaned = [str(anchor).strip() for anchor in anchors or [] if str(anchor or "").strip()] + if not cleaned: + return 0 + try: + if path.stat().st_size > _MAX_ANCHOR_SCAN_BYTES: + return 0 + text = path.read_text(encoding="utf-8", errors="ignore") + except OSError: + return 0 + return sum(1 for anchor in cleaned if anchor in text) + + +def _resolve_operation_target( + repo_root: Path, + op: Dict[str, Any], + file_symbol: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + declared_rel = str(op.get("path") or "") + declared_path = _safe_repo_path(repo_root, declared_rel) + if declared_path is None: + return { + "declared_path": declared_rel, + "resolved_path": declared_rel, + "absolute_path": None, + "exists": False, + "resolution": "unsafe_path", + "resolution_confidence": 0.0, + "diagnostics": [{ + "level": "error", + "code": "UNSAFE_OPERATION_PATH", + "path": declared_rel, + "message": "Operation path is absolute or escapes the target repo.", + }], + } + + if declared_path.exists(): + return { + "declared_path": declared_rel, + "resolved_path": declared_rel, + "absolute_path": declared_path, + "exists": True, + "resolution": "exact_path", + "resolution_confidence": 1.0, + "diagnostics": [], + } + + pre_hashes = _condition_hashes(op, "preconditions") + post_hashes = _condition_hashes(op, "postconditions") + language = str((file_symbol or {}).get("language") or _language_for(declared_rel)) + candidates, scan = _candidate_files(repo_root, declared_rel, language) + diagnostics: List[Dict[str, Any]] = [] + if scan.get("truncated"): + diagnostics.append({ + "level": "warning", + "code": "TARGET_SEARCH_TRUNCATED", + "path": declared_rel, + "message": "Target search hit the scan limit before inspecting the full repo.", + "scanned_files": scan.get("scanned_files"), + }) + + for candidate in candidates: + digest = file_sha256(candidate) + if digest and digest in post_hashes: + return { + "declared_path": declared_rel, + "resolved_path": _repo_relative(repo_root, candidate), + "absolute_path": candidate, + "exists": True, + "resolution": "moved_after_hash_match", + "resolution_confidence": 0.98, + "diagnostics": diagnostics, + } + if digest and digest in pre_hashes: + return { + "declared_path": declared_rel, + "resolved_path": _repo_relative(repo_root, candidate), + "absolute_path": candidate, + "exists": True, + "resolution": "moved_before_hash_match", + "resolution_confidence": 0.95, + "diagnostics": diagnostics, + } + + anchors = list(op.get("anchor_hints") or []) + list((file_symbol or {}).get("anchors") or []) + best_candidate: Optional[Path] = None + best_count = 0 + for candidate in candidates: + count = _anchor_match_count(candidate, anchors) + if count > best_count: + best_candidate = candidate + best_count = count + if best_candidate is not None and best_count > 0: + diagnostics.append({ + "level": "warning", + "code": "TARGET_RESOLVED_BY_ANCHOR", + "path": declared_rel, + "resolved_path": _repo_relative(repo_root, best_candidate), + "message": "Exact path was missing; interpreter resolved a same-name candidate using anchor hints.", + "anchor_matches": best_count, + }) + return { + "declared_path": declared_rel, + "resolved_path": _repo_relative(repo_root, best_candidate), + "absolute_path": best_candidate, + "exists": True, + "resolution": "anchor_match", + "resolution_confidence": min(0.85, 0.45 + (best_count * 0.1)), + "diagnostics": diagnostics, + } + + return { + "declared_path": declared_rel, + "resolved_path": declared_rel, + "absolute_path": declared_path, + "exists": False, + "resolution": "missing", + "resolution_confidence": 0.0, + "diagnostics": diagnostics, + } + + +def _operation_state( + repo_root: Path, + op: Dict[str, Any], + file_symbol: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + rel_path = str(op.get("path") or "") + resolution = _resolve_operation_target(repo_root, op, file_symbol) + path = resolution.get("absolute_path") + target_hash = file_sha256(path) if isinstance(path, Path) else None + pre_hashes = _condition_hashes(op, "preconditions") + post_hashes = _condition_hashes(op, "postconditions") + exists = bool(resolution.get("exists")) + if post_hashes and target_hash in post_hashes: + state = "already_applied" + elif pre_hashes and target_hash in pre_hashes: + state = "ready" + elif op.get("op") == "create_file" and not exists: + state = "ready" + elif op.get("op") == "delete_file" and not exists: + state = "already_applied" + elif not exists: + state = "blocked" + else: + state = "conflict" + return { + "path": rel_path, + "declared_path": resolution.get("declared_path") or rel_path, + "resolved_path": resolution.get("resolved_path") or rel_path, + "resolution": resolution.get("resolution"), + "resolution_confidence": resolution.get("resolution_confidence"), + "exists": exists, + "current_sha256": target_hash, + "state": state, + "expected_before": pre_hashes[0] if pre_hashes else None, + "expected_after": post_hashes[0] if post_hashes else None, + "diagnostics": resolution.get("diagnostics") or [], + } + + +def interpret_context_ir( + *, + repo: str | os.PathLike[str] | None, + capsule_or_ir: Dict[str, Any], + strict: bool = False, +) -> Dict[str, Any]: + """Interpret compiled context IR on a target checkout.""" + + repo_root = _resolve_repo_root(repo) + ir = _load_ir(capsule_or_ir) + diagnostics: List[Dict[str, Any]] = [] + validation = validate_context_ir(ir, strict=strict) + diagnostics.extend(validation["diagnostics"]) + file_symbols = { + str(symbol.get("symbol")): symbol + for symbol in ((ir.get("symbol_table") or {}).get("files") or []) + if isinstance(symbol, dict) and symbol.get("symbol") + } + operation_states = [ + _operation_state(repo_root, op, file_symbols.get(str(op.get("target")))) + for op in ir.get("operations") or [] + ] + for state in operation_states: + diagnostics.extend(state.get("diagnostics") or []) + if state["state"] == "conflict": + diagnostics.append({ + "level": "warning" if not strict else "error", + "code": "PRECONDITION_MISMATCH", + "path": state["path"], + "resolved_path": state.get("resolved_path"), + "message": "Target file hash matches neither the compiled before nor after state.", + }) + elif state["state"] == "blocked": + diagnostics.append({ + "level": "error", + "code": "TARGET_PATH_MISSING", + "path": state["path"], + "resolved_path": state.get("resolved_path"), + "message": "Target path is missing and cannot satisfy the operation precondition.", + }) + + states = {state["state"] for state in operation_states} + blocking_codes = {"UNKNOWN_IR_SCHEMA", "PRIVATE_BODY_PRESENT", "UNSAFE_OPERATION_PATH"} + if not validation["ok"] or any(diag.get("level") == "error" and diag.get("code") in blocking_codes for diag in diagnostics): + readiness = "blocked" + elif states and states <= {"already_applied"}: + readiness = "already_applied" + elif "conflict" in states: + readiness = "conflict" + elif "blocked" in states: + readiness = "blocked" + else: + readiness = "ready" + + steps: List[Dict[str, Any]] = [] + for index, op in enumerate(ir.get("operations") or [], start=1): + state = next((item for item in operation_states if item["path"] == op.get("path")), {}) + steps.append({ + "step": index, + "action": op.get("op"), + "path": op.get("path"), + "resolved_path": state.get("resolved_path"), + "resolution": state.get("resolution"), + "state": state.get("state"), + "instruction": _instruction_for_operation(op, state), + "anchor_hints": op.get("anchor_hints") or [], + "preconditions": op.get("preconditions") or [], + "postconditions": op.get("postconditions") or [], + }) + + return { + "format": INTERPRETER_SCHEMA, + "repo": str(repo_root), + "module": ir.get("module") or {}, + "readiness": readiness, + "validation": validation, + "operation_states": operation_states, + "execution_plan": steps, + "verification_plan": ir.get("verification") or {}, + "diagnostics": list(ir.get("diagnostics") or []) + diagnostics, + "policy": { + "auto_apply": False, + "requires_agent_editing": True, + "strict": bool(strict), + }, + } + + +def _instruction_for_operation(op: Dict[str, Any], state: Dict[str, Any]) -> str: + path = op.get("path") or "" + resolved_path = state.get("resolved_path") or path + resolved_note = f" (resolved on this repo as `{resolved_path}`)" if resolved_path != path else "" + if state.get("state") == "already_applied": + return f"`{path}`{resolved_note} already satisfies the compiled postcondition." + if state.get("state") == "conflict": + return f"Inspect `{path}`{resolved_note} manually; current content differs from both compiled before and after hashes." + if state.get("state") == "blocked": + return f"Resolve missing target path `{path}`{resolved_note} before replaying this operation." + action = op.get("op") + if action == "create_file": + return f"Create `{path}`{resolved_note} using the compact diff recipe and intent from the capsule." + if action == "delete_file": + return f"Delete `{path}`{resolved_note} if the target still matches the compiled precondition." + return f"Modify `{path}`{resolved_note} using compact hunk `{op.get('hunk_ref')}` and anchor hints, then verify postconditions." diff --git a/dhee/contract_runtime.py b/dhee/contract_runtime.py new file mode 100644 index 0000000..260a0c5 --- /dev/null +++ b/dhee/contract_runtime.py @@ -0,0 +1,1034 @@ +"""Active contract runtime for Dhee router tool enforcement. + +Task contracts are useful only if the execution boundary respects them. This +module binds one active contract to a repo and lets router tools ask a simple +question before doing work: is this read/search/test inside the contract? +""" + +from __future__ import annotations + +import os +import shlex +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Tuple + +from dhee import repo_link +from dhee.runtime_io import append_jsonl_locked, read_json_checked, read_jsonl_checked, write_json_atomic +from dhee.task_contracts import ( + _load_task_contract, + _resolve_repo_root, + _sanitize_obj, + _stable_hash, + interpret_task_contract, +) + + +ACTIVE_CONTRACT_SCHEMA = "dhee.active_contract_runtime.v1" +ENFORCEMENT_POLICY_SCHEMA = "dhee.contract_enforcement_policy.v1" +CONTRACT_TOOL_REFUSAL_SCHEMA = "dhee.contract_tool_refusal.v1" +CONTRACT_TOOL_GUARD_SCHEMA = "dhee.contract_tool_guard.v1" +CONTRACT_RUNTIME_DOCTOR_SCHEMA = "dhee.contract_runtime_doctor.v1" +CONTRACT_SUPERVISOR_UNAVAILABLE = "CONTRACT_SUPERVISOR_UNAVAILABLE" + +_CONTRACT_REF_KEYS = ("contract_task_id", "task_contract_id", "contract_id", "task_id", "contract_path") +_ENFORCEMENT_MODES = {"off", "warn", "deny"} +_READ_TOOL_NAMES = {"read", "dhee_read", "Read"} +_GREP_TOOL_NAMES = {"grep", "dhee_grep", "Grep"} +_BASH_TOOL_NAMES = {"bash", "dhee_bash", "Bash"} +_EDIT_TOOL_NAMES = {"edit", "write", "multi_edit", "notebook_edit", "Edit", "Write", "MultiEdit", "NotebookEdit"} + + +def _now_iso() -> str: + return datetime.now(timezone.utc).replace(microsecond=0).isoformat() + + +def _runtime_dir(repo_root: Path) -> Path: + return repo_link.repo_context_dir(repo_root) / "task_runs" + + +def _active_path(repo_root: Path) -> Path: + return _runtime_dir(repo_root) / "active_contract.json" + + +def _enforcement_path(repo_root: Path) -> Path: + return _runtime_dir(repo_root) / "enforcement.json" + + +def _task_runtime_events_path(repo_root: Path, task_id: str) -> Path: + return _runtime_dir(repo_root) / str(task_id or "unknown") / "runtime_events.jsonl" + + +def _write_json(path: Path, data: Dict[str, Any]) -> None: + result = write_json_atomic(path, data, sanitize=_sanitize_obj) + if not result.get("ok"): + diagnostic = result.get("diagnostic") or {} + raise RuntimeError(diagnostic.get("message") or f"failed to write {path}") + + +def _read_json_checked(path: Path, *, expected_schema: str | None = None, quarantine: bool = False) -> Dict[str, Any]: + return read_json_checked(path, expected_schema=expected_schema, quarantine=quarantine) + + +def _read_json(path: Path) -> Optional[Dict[str, Any]]: + result = _read_json_checked(path) + data = result.get("data") + return data if isinstance(data, dict) and result.get("ok") else None + + +def _append_runtime_event(repo_root: Path, task_id: str, event: Dict[str, Any]) -> Dict[str, Any]: + path = _task_runtime_events_path(repo_root, task_id) + return append_jsonl_locked(path, event, sanitize=_sanitize_obj) + + +def _contract_ref_from_args(arguments: Dict[str, Any]) -> Optional[str]: + for key in _CONTRACT_REF_KEYS: + value = arguments.get(key) + if value: + return str(value) + return None + + +def _looks_like_file(path: Path, raw: str) -> bool: + if path.exists(): + return path.is_file() + suffix = Path(raw).suffix + return bool(suffix) + + +def _candidate_repo_roots(arguments: Dict[str, Any]) -> Iterable[Path]: + keys = ("repo", "cwd", "file_path", "path") + seen: set[str] = set() + for key in keys: + raw = str(arguments.get(key) or "").strip() + if not raw: + continue + p = Path(raw).expanduser() + if not p.is_absolute(): + p = Path(os.getcwd()) / p + if _looks_like_file(p, raw): + p = p.parent + try: + root = _resolve_repo_root(p) + except Exception: + continue + marker = str(root) + if marker in seen: + continue + seen.add(marker) + yield root + try: + root = _resolve_repo_root(os.getcwd()) + if str(root) not in seen: + yield root + except Exception: + return + + +def _repo_root_for_policy(repo: str | os.PathLike[str] | Path | None) -> Path: + return _resolve_repo_root(repo or os.getcwd()) + + +def _mode_from_policy_data(data: Dict[str, Any] | None) -> Optional[str]: + mode = str((data or {}).get("mode") or "").strip().lower() + return mode if mode in _ENFORCEMENT_MODES else None + + +def _env_forces_deny() -> bool: + return _truthy(os.environ.get("DHEE_REQUIRE_ACTIVE_CONTRACT")) + + +def contract_enforcement_status(*, repo: str | os.PathLike[str] | None = None) -> Dict[str, Any]: + """Return the effective contract enforcement policy for a repo. + + Public installs default to ``off``. Premium/strict harnesses can set the + repo policy to ``deny`` or force it process-wide with + ``DHEE_REQUIRE_ACTIVE_CONTRACT=1``. + """ + + repo_root = _repo_root_for_policy(repo) + path = _enforcement_path(repo_root) + checked = _read_json_checked(path, expected_schema=ENFORCEMENT_POLICY_SCHEMA) + diagnostics = list(checked.get("diagnostics") or []) + data = checked.get("data") if isinstance(checked.get("data"), dict) else None + configured_mode = "off" + policy_corrupt = False + if checked.get("exists"): + mode = _mode_from_policy_data(data) + if mode: + configured_mode = mode + else: + configured_mode = "deny" + policy_corrupt = True + diagnostics.append( + { + "code": "ENFORCEMENT_POLICY_INVALID", + "message": "Enforcement policy exists but does not contain a valid off/warn/deny mode.", + "path": str(path), + "observed_mode": (data or {}).get("mode") if isinstance(data, dict) else None, + } + ) + forced_by_env = _env_forces_deny() + effective_mode = "deny" if forced_by_env else configured_mode + return { + "format": ENFORCEMENT_POLICY_SCHEMA, + "repo": str(repo_root), + "mode": effective_mode, + "configured_mode": configured_mode, + "forced_by_env": forced_by_env, + "policy_corrupt": policy_corrupt, + "diagnostics": diagnostics, + "paths": {"policy": str(path)}, + } + + +def set_contract_enforcement( + mode: str, + *, + repo: str | os.PathLike[str] | None = None, + agent_id: str | None = None, + reason: str | None = None, +) -> Dict[str, Any]: + """Persist the repo's contract enforcement policy.""" + + normalized = str(mode or "").strip().lower() + if normalized not in _ENFORCEMENT_MODES: + raise ValueError("mode must be one of: off, warn, deny") + repo_root = _repo_root_for_policy(repo) + repo_link._ensure_repo_skeleton(repo_root) + now = _now_iso() + policy = { + "format": ENFORCEMENT_POLICY_SCHEMA, + "schema_version": ENFORCEMENT_POLICY_SCHEMA, + "mode": normalized, + "repo": str(repo_root), + "updated_at": now, + "updated_by": agent_id or os.environ.get("DHEE_AGENT_ID") or "unknown", + "reason": reason or "manual", + } + _write_json(_enforcement_path(repo_root), policy) + _append_runtime_event( + repo_root, + "enforcement", + { + "event": "enforcement_policy_set", + "created_at": now, + "mode": normalized, + "agent_id": policy["updated_by"], + "reason": policy["reason"], + }, + ) + return { + **policy, + "effective": contract_enforcement_status(repo=repo_root), + "paths": {"policy": str(_enforcement_path(repo_root))}, + } + + +def _record_enforcement_warning( + repo_root: Path, + *, + tool_name: str, + code: str, + message: str, + diagnostics: Optional[List[Dict[str, Any]]] = None, +) -> None: + _append_runtime_event( + repo_root, + "enforcement", + { + "event": "enforcement_warning", + "created_at": _now_iso(), + "tool_name": tool_name, + "code": code, + "message": message, + "diagnostics": diagnostics or [], + }, + ) + + +def _last_event(path: Path) -> Optional[Dict[str, Any]]: + checked = read_jsonl_checked(path) + records = checked.get("records") or [] + return records[-1] if records else None + + +def _active_corrupt_codes(diagnostics: Iterable[Dict[str, Any]]) -> bool: + return any( + str(diag.get("code") or "") in {"RUNTIME_JSON_CORRUPT", "RUNTIME_JSON_NOT_OBJECT", "RUNTIME_SCHEMA_MISMATCH"} + for diag in diagnostics or [] + ) + + +def _repo_relative(repo_root: Path, path: str, *, cwd: str | os.PathLike[str] | None = None) -> str: + raw = Path(str(path or "")).expanduser() + if not raw.is_absolute(): + base = Path(cwd).expanduser() if cwd else repo_root + if not base.is_absolute(): + base = Path(os.getcwd()) / base + raw = base / raw + try: + resolved = raw.resolve() + root = repo_root.resolve() + if os.path.commonpath([str(root), str(resolved)]) == str(root): + return os.path.relpath(resolved, root).replace(os.sep, "/") + except (OSError, ValueError): + pass + return str(path or "") + + +def _scope_relative(repo_root: Path, path: str, *, cwd: str | os.PathLike[str] | None = None) -> str: + rel = _repo_relative(repo_root, path or ".", cwd=cwd) + return "." if rel in {"", "."} else rel + + +def _contract_hash(compiled: Dict[str, Any]) -> str: + return _stable_hash( + { + "contract": compiled.get("contract") or {}, + "compiler": compiled.get("compiler") or {}, + "actions": [ + { + "action_id": action.get("action_id"), + "type": action.get("type"), + "operands": action.get("operands") or {}, + "requires": action.get("requires") or [], + } + for action in compiled.get("actions") or [] + if isinstance(action, dict) + ], + }, + 24, + ) + + +def activate_contract_runtime( + task_contract: str | os.PathLike[str] | Dict[str, Any], + *, + repo: str | os.PathLike[str] | None = None, + strict: bool = False, + force: bool = False, + agent_id: str | None = None, + harness: str | None = None, +) -> Dict[str, Any]: + """Select one task contract as the repo's active router runtime.""" + + repo_root = _resolve_repo_root(repo) + compiled = _sanitize_obj(_load_task_contract(task_contract, repo=repo_root)) + interpretation = interpret_task_contract(compiled, repo=repo_root, strict=strict) + readiness = str(interpretation.get("readiness") or "") + if readiness == "blocked" and not force: + return { + "format": ACTIVE_CONTRACT_SCHEMA, + "active": False, + "status": "rejected", + "reason": "contract_not_ready", + "repo": str(repo_root), + "task_id": (compiled.get("contract") or {}).get("task_id"), + "interpretation": interpretation, + } + + contract = compiled.get("contract") or {} + task_id = str(contract.get("task_id") or "unknown") + if isinstance(task_contract, dict): + contract_ref = task_id + else: + source = Path(str(task_contract)).expanduser() + contract_ref = str(source.resolve()) if source.exists() else str(task_contract) + runtime = { + "format": ACTIVE_CONTRACT_SCHEMA, + "schema_version": ACTIVE_CONTRACT_SCHEMA, + "active": True, + "status": "active", + "task_id": task_id, + "contract_ref": contract_ref, + "repo": str(repo_root), + "strict": bool(strict), + "force": bool(force), + "contract_hash": _contract_hash(compiled), + "activated_at": _now_iso(), + "activated_by": agent_id or os.environ.get("DHEE_AGENT_ID") or "unknown", + "harness": harness or os.environ.get("DHEE_HARNESS") or os.environ.get("DHEE_AGENT_ID") or "unknown", + "policy": { + "enforce_router_tools": True, + "auto_record_observations": True, + "auto_execute": False, + "allowed_router_tools": ["dhee_read", "dhee_grep", "dhee_bash"], + }, + "interpretation": { + "readiness": interpretation.get("readiness"), + "diagnostic_count": len(interpretation.get("diagnostics") or []), + }, + } + repo_link._ensure_repo_skeleton(repo_root) + _write_json(_active_path(repo_root), runtime) + _append_runtime_event( + repo_root, + task_id, + { + "event": "activate", + "created_at": runtime["activated_at"], + "task_id": task_id, + "strict": bool(strict), + "force": bool(force), + "contract_hash": runtime["contract_hash"], + "agent_id": runtime["activated_by"], + "harness": runtime["harness"], + }, + ) + return { + **runtime, + "paths": { + "active": str(_active_path(repo_root)), + "events": str(_task_runtime_events_path(repo_root, task_id)), + }, + "interpretation": interpretation, + } + + +def deactivate_contract_runtime( + *, + repo: str | os.PathLike[str] | None = None, + agent_id: str | None = None, + reason: str = "manual", +) -> Dict[str, Any]: + """Deactivate the repo's selected contract without deleting history.""" + + repo_root = _resolve_repo_root(repo) + path = _active_path(repo_root) + checked = _read_json_checked(path, expected_schema=ACTIVE_CONTRACT_SCHEMA, quarantine=True) + diagnostics = list(checked.get("diagnostics") or []) + if checked.get("exists") and not checked.get("ok"): + return { + "format": ACTIVE_CONTRACT_SCHEMA, + "active": False, + "status": "corrupt", + "error": "ACTIVE_CONTRACT_CORRUPT", + "repo": str(repo_root), + "diagnostics": diagnostics, + "quarantine": checked.get("quarantine"), + "paths": {"active": str(path)}, + } + runtime = checked.get("data") if isinstance(checked.get("data"), dict) else None + if not runtime or not runtime.get("active"): + return { + "format": ACTIVE_CONTRACT_SCHEMA, + "active": False, + "status": "inactive", + "repo": str(repo_root), + "diagnostics": diagnostics, + "paths": {"active": str(path)}, + } + task_id = str(runtime.get("task_id") or "unknown") + now = _now_iso() + runtime.update({ + "active": False, + "status": "inactive", + "deactivated_at": now, + "deactivated_by": agent_id or os.environ.get("DHEE_AGENT_ID") or "unknown", + "deactivation_reason": reason, + }) + _write_json(path, runtime) + _append_runtime_event( + repo_root, + task_id, + { + "event": "deactivate", + "created_at": now, + "task_id": task_id, + "reason": reason, + "agent_id": runtime["deactivated_by"], + }, + ) + return { + "format": ACTIVE_CONTRACT_SCHEMA, + "active": False, + "status": "inactive", + "repo": str(repo_root), + "task_id": task_id, + "paths": { + "active": str(path), + "events": str(_task_runtime_events_path(repo_root, task_id)), + }, + } + + +def contract_runtime_status(*, repo: str | os.PathLike[str] | None = None) -> Dict[str, Any]: + repo_root = _resolve_repo_root(repo) + path = _active_path(repo_root) + checked = _read_json_checked(path, expected_schema=ACTIVE_CONTRACT_SCHEMA, quarantine=True) + diagnostics = list(checked.get("diagnostics") or []) + enforcement = contract_enforcement_status(repo=repo_root) + if checked.get("exists") and not checked.get("ok"): + return { + "format": ACTIVE_CONTRACT_SCHEMA, + "active": False, + "status": "corrupt", + "error": "ACTIVE_CONTRACT_CORRUPT", + "repo": str(repo_root), + "diagnostics": diagnostics, + "enforcement": enforcement, + "quarantine": checked.get("quarantine"), + "paths": {"active": str(path)}, + } + runtime = checked.get("data") if isinstance(checked.get("data"), dict) else None + if not runtime or not runtime.get("active"): + return { + "format": ACTIVE_CONTRACT_SCHEMA, + "active": False, + "status": "inactive", + "repo": str(repo_root), + "diagnostics": diagnostics, + "enforcement": enforcement, + "paths": {"active": str(path)}, + } + task_id = str(runtime.get("task_id") or "unknown") + try: + interpretation = interpret_task_contract( + str(runtime.get("contract_ref") or task_id), + repo=repo_root, + strict=bool(runtime.get("strict")), + ) + except Exception as exc: + interpretation = { + "readiness": "blocked", + "diagnostics": [ + { + "level": "error", + "code": "ACTIVE_CONTRACT_LOAD_FAILED", + "message": f"{type(exc).__name__}: {exc}", + } + ], + } + return { + **runtime, + "repo": str(repo_root), + "paths": { + "active": str(path), + "events": str(_task_runtime_events_path(repo_root, task_id)), + }, + "diagnostics": diagnostics, + "enforcement": enforcement, + "interpretation": interpretation, + } + + +def _active_runtime_for_call(arguments: Dict[str, Any]) -> Tuple[Optional[Path], Optional[Dict[str, Any]], List[Dict[str, Any]]]: + contract_ref = _contract_ref_from_args(arguments) + collected_diagnostics: List[Dict[str, Any]] = [] + for repo_root in _candidate_repo_roots(arguments): + enforcement = contract_enforcement_status(repo=repo_root) + if contract_ref and enforcement.get("mode") != "deny": + return repo_root, { + "format": ACTIVE_CONTRACT_SCHEMA, + "active": True, + "status": "ephemeral", + "task_id": contract_ref, + "repo": str(repo_root), + "strict": bool(arguments.get("contract_strict") or arguments.get("strict") or False), + "policy": { + "enforce_router_tools": True, + "auto_record_observations": True, + "auto_execute": False, + }, + }, collected_diagnostics + checked = _read_json_checked(_active_path(repo_root), expected_schema=ACTIVE_CONTRACT_SCHEMA, quarantine=True) + diagnostics = [ + diag + for diag in checked.get("diagnostics") or [] + if diag.get("code") != "RUNTIME_FILE_MISSING" + ] + if diagnostics: + collected_diagnostics.extend(diagnostics) + if checked.get("exists") and not checked.get("ok"): + return repo_root, None, collected_diagnostics + runtime = checked.get("data") if isinstance(checked.get("data"), dict) else None + if runtime and runtime.get("active"): + return repo_root, runtime, collected_diagnostics + return None, None, collected_diagnostics + + +def _bash_action_from_command(command: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + timeout = arguments.get("timeout_sec", arguments.get("timeout", 120)) + try: + timeout_sec = int(float(timeout)) + except (TypeError, ValueError): + timeout_sec = 120 + return { + "type": "RUN_TEST", + "command": str(command or "").strip(), + "timeout_sec": timeout_sec, + "reason": "Execute a compiled must_run command under the active contract.", + } + + +def _router_action(tool_name: str, arguments: Dict[str, Any], repo_root: Path) -> Dict[str, Any]: + normalized = str(tool_name or "").strip() + if normalized in _READ_TOOL_NAMES: + file_path = str(arguments.get("file_path") or "") + return { + "type": "READ_FILE", + "path": _repo_relative(repo_root, file_path, cwd=arguments.get("cwd")), + "reason": "Read through dhee_read under the active contract.", + } + if normalized in _GREP_TOOL_NAMES: + path = str(arguments.get("path") or ".") + return { + "type": "SEARCH_CODE", + "query": str(arguments.get("pattern") or arguments.get("query") or ""), + "scope": _scope_relative(repo_root, path), + "reason": "Search through dhee_grep under the active contract.", + } + if normalized in _BASH_TOOL_NAMES: + return _bash_action_from_command(str(arguments.get("command") or ""), arguments) + if normalized in _EDIT_TOOL_NAMES: + path = str( + arguments.get("file_path") + or arguments.get("path") + or arguments.get("notebook_path") + or "" + ) + patch_payload = { + "old_string": arguments.get("old_string"), + "new_string": arguments.get("new_string"), + "edits": arguments.get("edits"), + "content_hash": _stable_hash(arguments.get("content") or arguments.get("new_string") or arguments.get("edits") or "", 12), + } + proof = arguments.get("proof") if isinstance(arguments.get("proof"), dict) else arguments.get("dhee_proof") + return { + "type": "EDIT_FILE", + "path": _repo_relative(repo_root, path, cwd=arguments.get("cwd")), + "patch": arguments.get("patch") or f"native_edit:{_stable_hash(patch_payload, 16)}", + "proof": proof if isinstance(proof, dict) else {}, + "reason": "Native edit tool call under the active contract.", + } + return {"type": str(tool_name or ""), "reason": "Unknown router tool."} + + +def _contract_ref_for_runtime(runtime: Dict[str, Any]) -> str: + return str(runtime.get("contract_ref") or runtime.get("contract_path") or runtime.get("task_id") or runtime.get("contract_id") or "") + + +def guard_router_call(tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + """Return allow/deny for a router call under the active contract, if any.""" + + repo_root, runtime, runtime_diagnostics = _active_runtime_for_call(arguments) + if repo_root is None: + try: + repo_root = _repo_root_for_policy(arguments.get("repo") or arguments.get("cwd") or os.getcwd()) + except Exception: + repo_root = None + enforcement = contract_enforcement_status(repo=repo_root) if repo_root else { + "mode": "deny" if _env_forces_deny() else "off", + "configured_mode": "off", + "forced_by_env": _env_forces_deny(), + "diagnostics": [], + } + mode = str(enforcement.get("mode") or "off") + diagnostics = [*runtime_diagnostics, *(enforcement.get("diagnostics") or [])] + if not repo_root or not runtime: + required = _truthy(arguments.get("require_active_contract")) or mode == "deny" + corrupt_active = _active_corrupt_codes(runtime_diagnostics) + error = None + if corrupt_active: + error = "ACTIVE_CONTRACT_CORRUPT" + elif required: + error = "ACTIVE_CONTRACT_REQUIRED" + warning = "" + if not required and mode == "warn": + warning = "No active task contract is bound to this repo; warn mode allowed the tool call." + if repo_root: + _record_enforcement_warning( + repo_root, + tool_name=tool_name, + code="ACTIVE_CONTRACT_MISSING_WARN", + message=warning, + diagnostics=diagnostics, + ) + return { + "format": CONTRACT_TOOL_GUARD_SCHEMA, + "active": False, + "allowed": not required and not (corrupt_active and mode == "deny"), + "tool_name": tool_name, + "repo": str(repo_root) if repo_root else None, + "error": error if required or corrupt_active else None, + "message": ( + "Active task contract runtime is corrupt and was quarantined." + if corrupt_active + else ("No active task contract is bound to this repo." if required else warning) + ), + "diagnostics": diagnostics, + "enforcement": enforcement, + } + action = _router_action(tool_name, arguments, repo_root) + task_ref = _contract_ref_for_runtime(runtime) + if not task_ref: + return { + "format": CONTRACT_TOOL_GUARD_SCHEMA, + "active": True, + "allowed": False, + "repo": str(repo_root), + "tool_name": tool_name, + "proposed_action": action, + "error": "active_runtime_missing_task_ref", + "runtime": runtime, + "diagnostics": diagnostics, + "enforcement": enforcement, + } + try: + from dhee.contract_supervisor import supervise_action + + decision = supervise_action( + task_ref, + action, + repo=repo_root, + strict=bool(runtime.get("strict")), + ) + except Exception as exc: + message = f"{type(exc).__name__}: {exc}" + diagnostics.append({ + "code": CONTRACT_SUPERVISOR_UNAVAILABLE, + "message": message, + "tool_name": tool_name, + "repo": str(repo_root), + }) + if mode == "deny": + return { + "format": CONTRACT_TOOL_GUARD_SCHEMA, + "active": True, + "allowed": False, + "repo": str(repo_root), + "task_id": str(runtime.get("task_id") or task_ref), + "tool_name": tool_name, + "proposed_action": action, + "error": CONTRACT_SUPERVISOR_UNAVAILABLE, + "message": "Contract supervisor could not load or execute; deny mode blocks the tool call.", + "diagnostics": diagnostics, + "enforcement": enforcement, + "runtime": { + "status": runtime.get("status"), + "task_id": runtime.get("task_id"), + "strict": bool(runtime.get("strict")), + "contract_hash": runtime.get("contract_hash"), + }, + } + if mode == "warn": + warning = "Contract supervisor could not load or execute; warn mode allowed the tool call." + _record_enforcement_warning( + repo_root, + tool_name=tool_name, + code=CONTRACT_SUPERVISOR_UNAVAILABLE, + message=warning, + diagnostics=diagnostics, + ) + return { + "format": CONTRACT_TOOL_GUARD_SCHEMA, + "active": True, + "allowed": True, + "repo": str(repo_root), + "task_id": str(runtime.get("task_id") or task_ref), + "tool_name": tool_name, + "proposed_action": action, + "warning": warning, + "diagnostics": diagnostics, + "enforcement": enforcement, + "runtime": { + "status": runtime.get("status"), + "task_id": runtime.get("task_id"), + "strict": bool(runtime.get("strict")), + "contract_hash": runtime.get("contract_hash"), + }, + } + return { + "format": CONTRACT_TOOL_GUARD_SCHEMA, + "active": True, + "allowed": True, + "repo": str(repo_root), + "task_id": str(runtime.get("task_id") or task_ref), + "tool_name": tool_name, + "proposed_action": action, + "warning": "Contract supervisor unavailable; enforcement mode off preserved compatibility behavior.", + "diagnostics": diagnostics, + "enforcement": enforcement, + "runtime": { + "status": runtime.get("status"), + "task_id": runtime.get("task_id"), + "strict": bool(runtime.get("strict")), + "contract_hash": runtime.get("contract_hash"), + }, + } + allowed = decision.get("decision") == "allow" + return { + "format": CONTRACT_TOOL_GUARD_SCHEMA, + "active": True, + "allowed": allowed, + "repo": str(repo_root), + "task_id": decision.get("task_id") or task_ref, + "tool_name": tool_name, + "proposed_action": action, + "decision": decision, + "diagnostics": diagnostics, + "enforcement": enforcement, + "runtime": { + "status": runtime.get("status"), + "task_id": runtime.get("task_id"), + "strict": bool(runtime.get("strict")), + "contract_hash": runtime.get("contract_hash"), + }, + } + + +def _truthy(value: Any) -> bool: + return str(value or "").strip().lower() in {"1", "true", "yes", "on"} + + +def router_refusal(guard: Dict[str, Any]) -> Dict[str, Any]: + decision = guard.get("decision") or {} + violations = decision.get("violations") or [] + codes = [str(item.get("code")) for item in violations if isinstance(item, dict) and item.get("code")] + if guard.get("error") and not codes: + codes = [str(guard.get("error"))] + return { + "format": CONTRACT_TOOL_REFUSAL_SCHEMA, + "error": "CONTRACT_TOOL_CALL_DENIED", + "message": guard.get("message") or "Active task contract refused this router tool call.", + "will_execute": False, + "tool_name": guard.get("tool_name"), + "repo": guard.get("repo"), + "task_id": guard.get("task_id"), + "proposed_action": guard.get("proposed_action"), + "decision": decision, + "violation_codes": codes, + "runtime": guard.get("runtime"), + "diagnostics": guard.get("diagnostics") or [], + "enforcement": guard.get("enforcement") or {}, + } + + +def router_result_runtime(guard: Dict[str, Any], observation: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]: + """Compact metadata attached to allowed router tool results.""" + + if not guard.get("active"): + return None + decision = guard.get("decision") or {} + matched = decision.get("matched_contract_action") or {} + out = { + "format": "dhee.contract_router_runtime.v1", + "task_id": guard.get("task_id"), + "decision": decision.get("decision"), + "action_id": (guard.get("proposed_action") or {}).get("action_id") or matched.get("action_id"), + "action_type": (guard.get("proposed_action") or {}).get("type"), + "matched_action": { + "action_id": matched.get("action_id"), + "type": matched.get("type"), + "phase": matched.get("phase"), + "target": matched.get("target"), + } if matched else None, + "observation": observation, + } + return out + + +def _observation_for_result(tool_name: str, result: Dict[str, Any]) -> Dict[str, Any]: + observation = { + "tool": tool_name, + "ptr": result.get("ptr"), + } + for key in ( + "line_count", + "char_count", + "match_count", + "file_count", + "total_bytes", + "exit_code", + "duration_ms", + "class", + "stdout_bytes", + "stderr_bytes", + "timed_out", + "inlined", + ): + if key in result: + observation[key] = result.get(key) + return observation + + +def _outcome_for_result(tool_name: str, result: Dict[str, Any]) -> str: + if result.get("error"): + return "failed" + if str(tool_name) in _BASH_TOOL_NAMES: + if result.get("timed_out"): + return "timed_out" + try: + exit_code = int(result.get("exit_code")) + except (TypeError, ValueError): + exit_code = 1 + return "passed" if exit_code == 0 else "failed" + return "observed" + + +def record_router_observation(guard: Dict[str, Any], result: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Persist successful router tool observations for the active interpreter.""" + + if not guard.get("active") or not guard.get("allowed") or result.get("error"): + return None + if not guard.get("task_id") or not guard.get("repo"): + return None + try: + from dhee.contract_supervisor import record_observation_transition + + record = record_observation_transition( + str(guard["task_id"]), + dict(guard.get("proposed_action") or {}), + _observation_for_result(str(guard.get("tool_name") or ""), result), + repo=str(guard["repo"]), + outcome=_outcome_for_result(str(guard.get("tool_name") or ""), result), + strict=bool((guard.get("runtime") or {}).get("strict")), + ) + except Exception as exc: + return { + "format": "dhee.contract_router_observation_error.v1", + "error": f"{type(exc).__name__}: {exc}", + } + return { + "format": "dhee.contract_router_observation.v1", + "event_id": (record.get("event") or {}).get("event_id"), + "outcome": (record.get("event") or {}).get("outcome"), + "events_path": (record.get("paths") or {}).get("events"), + } + + +def _path_writable(path: Path) -> Dict[str, Any]: + existing = path if path.exists() else path.parent + while not existing.exists() and existing.parent != existing: + existing = existing.parent + writable = os.access(str(existing), os.W_OK) if existing.exists() else False + return { + "path": str(path), + "exists": path.exists(), + "nearest_existing_parent": str(existing), + "writable": bool(writable), + } + + +def contract_runtime_doctor(*, repo: str | os.PathLike[str] | None = None) -> Dict[str, Any]: + """Report whether the contract runtime is actually protecting this repo.""" + + repo_root = _repo_root_for_policy(repo) + task_runs = _runtime_dir(repo_root) + active_path = _active_path(repo_root) + enforcement_path = _enforcement_path(repo_root) + enforcement = contract_enforcement_status(repo=repo_root) + active = contract_runtime_status(repo=repo_root) + corrupt_files: List[Dict[str, Any]] = [] + diagnostics: List[Dict[str, Any]] = [] + for source in (active, enforcement): + for diag in source.get("diagnostics") or []: + diagnostics.append(diag) + if str(diag.get("code") or "") in { + "RUNTIME_JSON_CORRUPT", + "RUNTIME_JSON_NOT_OBJECT", + "RUNTIME_JSONL_LINE_CORRUPT", + "RUNTIME_JSONL_LINE_NOT_OBJECT", + "ENFORCEMENT_POLICY_INVALID", + "RUNTIME_SCHEMA_MISMATCH", + }: + corrupt_files.append({ + "path": diag.get("path"), + "code": diag.get("code"), + "message": diag.get("message"), + }) + + router_health: Dict[str, Any] + try: + from dhee.router import install as router_install + + state = router_install.status() + router_health = { + "available": True, + "enabled": bool(state.enabled), + "managed": bool(state.managed), + "env_flag": bool(state.env_flag), + "allowed_tools": list(state.allowed_tools or []), + "settings_path": str(state.settings_path), + } + except Exception as exc: + router_health = { + "available": False, + "enabled": False, + "error": f"{type(exc).__name__}: {exc}", + } + + task_id = str(active.get("task_id") or "") + last_decision = None + if task_id: + supervisor_last = _last_event(task_runs / task_id / "events.jsonl") + runtime_last = _last_event(_task_runtime_events_path(repo_root, task_id)) + candidates = [item for item in [supervisor_last, runtime_last] if isinstance(item, dict)] + candidates.sort(key=lambda item: str(item.get("created_at") or "")) + last_decision = candidates[-1] if candidates else None + + bypass_risks: List[str] = [] + mode = str(enforcement.get("mode") or "off") + if mode == "off": + bypass_risks.append("enforcement_off") + if mode == "warn": + bypass_risks.append("warn_mode_allows_actions") + if not active.get("active"): + bypass_risks.append("no_active_contract") + if not router_health.get("enabled"): + bypass_risks.append("native_hook_or_router_not_enabled") + if corrupt_files: + bypass_risks.append("corrupt_runtime_state") + if not all(item.get("writable") for item in [ + _path_writable(task_runs), + _path_writable(active_path), + _path_writable(enforcement_path), + ]): + bypass_risks.append("runtime_path_not_writable") + + if corrupt_files or mode == "off": + protection = "unprotected" + elif mode == "deny" and active.get("active") and router_health.get("enabled"): + protection = "protected" + else: + protection = "partially_protected" + + writable_paths = { + "task_runs": _path_writable(task_runs), + "active_contract": _path_writable(active_path), + "enforcement_policy": _path_writable(enforcement_path), + } + return { + "format": CONTRACT_RUNTIME_DOCTOR_SCHEMA, + "repo": str(repo_root), + "status": protection, + "protected": protection == "protected", + "active_contract": { + "active": bool(active.get("active")), + "status": active.get("status"), + "task_id": active.get("task_id"), + "path": (active.get("paths") or {}).get("active"), + "readiness": (active.get("interpretation") or {}).get("readiness"), + }, + "enforcement": enforcement, + "hook_router_health": router_health, + "writable_runtime_paths": writable_paths, + "corrupt_files": corrupt_files, + "last_decision": last_decision, + "bypass_risks": bypass_risks, + "diagnostics": diagnostics, + } + + +def command_preview(command: str) -> Dict[str, Any]: + """Small helper for user-facing diagnostics around bash denials.""" + + try: + argv = shlex.split(command) + except ValueError: + argv = [] + return { + "argv0": argv[0] if argv else "", + "is_test_like": bool(argv and (argv[0] in {"pytest", "tox", "nox", "npm", "pnpm", "uv", "python", "python3"})), + } diff --git a/dhee/contract_supervisor.py b/dhee/contract_supervisor.py new file mode 100644 index 0000000..473a844 --- /dev/null +++ b/dhee/contract_supervisor.py @@ -0,0 +1,930 @@ +"""Contract supervisor for deterministic agent action enforcement. + +The task compiler says what should be done. The interpreter says whether a +target checkout can run it. The supervisor is the runtime gate: it decides +whether a proposed tool action is inside the interpreted contract and records +observation-to-next-action transitions. +""" + +from __future__ import annotations + +import json +import os +import subprocess +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional + +from dhee import repo_link +from dhee.runtime_io import append_jsonl_locked, read_jsonl_checked, write_json_atomic +from dhee.task_contracts import ( + ACTION_TYPES, + _SECRET_PATTERNS, + _action_operands, + _is_forbidden_path, + _path_under_allowed, + _safe_repo_path, + _sanitize_obj, + _stable_hash, + _tokens, + interpret_task_contract, + _load_task_contract, + _resolve_repo_root, +) + + +SUPERVISOR_DECISION_SCHEMA = "dhee.contract_supervisor_decision.v1" +OBSERVATION_EVENT_SCHEMA = "dhee.contract_observation_event.v1" +PROOF_BUNDLE_SCHEMA = "dhee.proof_bundle.v1" + +_RECOVERY_ACTIONS = {"SEARCH_CODE", "ASK_USER"} +_FORBIDDEN_SUBAGENT_PERMISSIONS = {"write:any", "shell:unsafe", "secrets:read", "network:unbounded"} +_SUCCESS_OUTCOMES = {"pass", "passed", "success", "succeeded", "ok"} +_BLOCKED_OUTCOMES = {"blocked", "denied", "rejected"} +_RUNTIME_ARTIFACT_PREFIXES = ( + ".dhee/context/task_runs/", + ".dhee/context/task_contracts/", + ".dhee/context/repo_brain/", +) + + +def _now_iso() -> str: + return datetime.now(timezone.utc).replace(microsecond=0).isoformat() + + +def _task_run_dir(repo_root: Path, task_id: str) -> Path: + return repo_link.repo_context_dir(repo_root) / "task_runs" / str(task_id or "unknown") + + +def _events_path(repo_root: Path, task_id: str) -> Path: + return _task_run_dir(repo_root, task_id) / "events.jsonl" + + +def _load_events(repo_root: Path, task_id: str) -> List[Dict[str, Any]]: + return list(_load_events_checked(repo_root, task_id).get("records") or []) + + +def _load_events_checked(repo_root: Path, task_id: str) -> Dict[str, Any]: + return read_jsonl_checked(_events_path(repo_root, task_id)) + + +def _outcome_is_success(outcome: Any) -> bool: + return str(outcome or "").strip().lower() in _SUCCESS_OUTCOMES + + +def _outcome_is_blocked(outcome: Any) -> bool: + return str(outcome or "").strip().lower() in _BLOCKED_OUTCOMES + + +def _decision_allows(event: Dict[str, Any]) -> bool: + decision = event.get("decision") or {} + return str(decision.get("decision") or "") in {"allow", "needs_input"} + + +def _action_key(action: Dict[str, Any]) -> Dict[str, Any]: + return { + "type": action.get("type"), + "operands": _action_operands(action), + } + + +def _match_planned_action(planned_actions: Iterable[Dict[str, Any]], proposed_action: Dict[str, Any]) -> Optional[Dict[str, Any]]: + action_id = proposed_action.get("action_id") + if action_id: + for planned in planned_actions or []: + if planned.get("action_id") == action_id: + return planned + + proposed_type = proposed_action.get("type") + proposed_operands = _action_operands(proposed_action) + exact: List[Dict[str, Any]] = [] + same_type: List[Dict[str, Any]] = [] + for planned in planned_actions or []: + if planned.get("type") != proposed_type: + continue + same_type.append(planned) + planned_operands = _action_operands(planned) + if proposed_operands and all(planned_operands.get(key) == value for key, value in proposed_operands.items() if key != "timeout_sec"): + exact.append(planned) + if exact: + return exact[0] + if len(same_type) == 1 and proposed_type in {"SUBMIT_PATCH", "WRITE_MEMORY_NOTE", "ASK_USER"}: + return same_type[0] + return None + + +def _event_action(event: Dict[str, Any]) -> Dict[str, Any]: + action = event.get("action") or {} + return action if isinstance(action, dict) else {} + + +def _event_action_id(event: Dict[str, Any], planned_actions: Iterable[Dict[str, Any]]) -> Optional[str]: + action = _event_action(event) + if action.get("action_id"): + return str(action.get("action_id")) + decision = event.get("decision") or {} + matched = decision.get("matched_contract_action") or {} + if matched.get("action_id"): + return str(matched.get("action_id")) + planned = _match_planned_action(planned_actions, action) + if planned and planned.get("action_id"): + return str(planned.get("action_id")) + return None + + +def _observed_action_ids(events: Iterable[Dict[str, Any]], planned_actions: Iterable[Dict[str, Any]]) -> List[str]: + out: List[str] = [] + for event in events or []: + if not _decision_allows(event) or _outcome_is_blocked(event.get("outcome")): + continue + action_id = _event_action_id(event, planned_actions) + if action_id and action_id not in out: + out.append(action_id) + return out + + +def _passed_test_commands(events: Iterable[Dict[str, Any]]) -> List[str]: + out: List[str] = [] + for event in events or []: + action = _event_action(event) + if action.get("type") != "RUN_TEST": + continue + if not _decision_allows(event) or not _outcome_is_success(event.get("outcome")): + continue + command = str(action.get("command") or "").strip() + if command and command not in out: + out.append(command) + return out + + +def _observed_read_paths(events: Iterable[Dict[str, Any]]) -> List[str]: + out: List[str] = [] + for event in events or []: + action = _event_action(event) + if action.get("type") != "READ_FILE": + continue + if not _decision_allows(event) or _outcome_is_blocked(event.get("outcome")): + continue + path = str(action.get("path") or "").strip() + if path and path not in out: + out.append(path) + return out + + +def _accepted_events(events: Iterable[Dict[str, Any]]) -> List[Dict[str, Any]]: + return [ + event + for event in events or [] + if event.get("accepted") is not False + and _decision_allows(event) + and not _outcome_is_blocked(event.get("outcome")) + ] + + +def _edit_proof(action: Dict[str, Any]) -> Dict[str, Any]: + proof = action.get("proof") if isinstance(action.get("proof"), dict) else {} + return { + "edit_span": action.get("edit_span") or action.get("span") or proof.get("edit_span") or proof.get("span"), + "invariant": action.get("invariant") or proof.get("invariant"), + "related_tests": action.get("related_tests") or action.get("related_test") or proof.get("related_tests") or proof.get("related_test"), + "rollback_point": action.get("rollback_point") or proof.get("rollback_point"), + } + + +def _as_list(value: Any) -> List[str]: + if value is None: + return [] + if isinstance(value, (list, tuple, set)): + return [str(item) for item in value if str(item).strip()] + return [str(value)] if str(value).strip() else [] + + +def _edit_span_is_valid(span: Any, path: str) -> bool: + if isinstance(span, dict): + span_path = str(span.get("path") or path) + if span_path != path: + return False + try: + start = int(span.get("start_line")) + end = int(span.get("end_line")) + except (TypeError, ValueError): + return False + return start > 0 and end >= start + text = str(span or "").strip() + return bool(text) + + +def _next_allowed_actions(planned_actions: Iterable[Dict[str, Any]], observed_ids: Iterable[str]) -> List[Dict[str, Any]]: + observed = set(observed_ids) + out: List[Dict[str, Any]] = [] + for action in planned_actions or []: + action_id = str(action.get("action_id") or "") + if action_id and action_id in observed: + continue + missing = [dep for dep in action.get("requires") or [] if dep not in observed] + if missing: + continue + out.append({ + "action_id": action_id, + "type": action.get("type"), + "phase": action.get("phase"), + "target": action.get("path") or action.get("command") or action.get("query") or action.get("summary") or action.get("category"), + "requires": action.get("requires") or [], + }) + return out[:8] + + +def _git_out(repo_root: Path, args: List[str]) -> str: + proc = subprocess.run( + ["git", "-C", str(repo_root), *args], + text=True, + capture_output=True, + check=False, + ) + return proc.stdout.strip() if proc.returncode == 0 else "" + + +def _normalize_status_path(raw: str) -> str: + value = str(raw or "").strip() + if " -> " in value: + value = value.split(" -> ", 1)[1].strip() + if len(value) >= 2 and value[0] == value[-1] == '"': + try: + return json.loads(value) + except Exception: + return value.strip('"') + return value + + +def _worktree_changed_paths(repo_root: Path) -> List[str]: + out: List[str] = [] + status = _git_out(repo_root, ["status", "--porcelain=v1", "--untracked-files=all"]) + for line in status.splitlines(): + if not line: + continue + path = _normalize_status_path(line[3:] if len(line) > 3 else line) + if path and path not in out: + out.append(path) + return out + + +def _is_runtime_artifact_path(path: str) -> bool: + normalized = str(path or "").replace("\\", "/").lstrip("./") + return any(normalized.startswith(prefix) for prefix in _RUNTIME_ARTIFACT_PREFIXES) + + +def _code_changed_paths(repo_root: Path) -> List[str]: + return [path for path in _worktree_changed_paths(repo_root) if not _is_runtime_artifact_path(path)] + + +def _secret_findings_in_diff(repo_root: Path) -> List[Dict[str, Any]]: + diff = _git_out(repo_root, ["diff", "--", "."]) + findings: List[Dict[str, Any]] = [] + current_file = "" + for line in diff.splitlines(): + if line.startswith("+++ b/"): + current_file = line[6:] + continue + if not line.startswith("+") or line.startswith("+++"): + continue + if _is_runtime_artifact_path(current_file): + continue + for pattern in _SECRET_PATTERNS: + if pattern.search(line): + findings.append({ + "path": current_file, + "pattern": pattern.pattern[:80], + "line_hash": _stable_hash(line, 16), + }) + break + return findings[:20] + + +def _submit_diff_violations(repo_root: Path, contract: Dict[str, Any], proposed_action: Dict[str, Any]) -> List[Dict[str, Any]]: + violations: List[Dict[str, Any]] = [] + allowed_paths = contract.get("allowed_write_paths") or [] + forbidden_paths = contract.get("forbidden_paths") or [] + changed_paths = _code_changed_paths(repo_root) + forbidden_changed = [path for path in changed_paths if _is_forbidden_path(path, forbidden_paths)] + outside_allowed = [ + path + for path in changed_paths + if not _is_forbidden_path(path, forbidden_paths) and not _path_under_allowed(path, allowed_paths) + ] + if forbidden_changed: + violations.append({ + "code": "SUBMIT_CHANGED_FORBIDDEN_PATH", + "message": "SUBMIT_PATCH cannot proceed while forbidden paths are changed.", + "paths": forbidden_changed, + }) + if outside_allowed: + violations.append({ + "code": "SUBMIT_CHANGED_PATH_OUT_OF_CONTRACT", + "message": "SUBMIT_PATCH cannot include changed paths outside allowed_write_paths.", + "paths": outside_allowed, + }) + secret_findings = _secret_findings_in_diff(repo_root) + if secret_findings: + violations.append({ + "code": "SUBMIT_SECRET_PATTERN_IN_DIFF", + "message": "SUBMIT_PATCH cannot proceed while the diff contains secret-like additions.", + "findings": secret_findings, + }) + contamination = contract.get("contamination_status") or {} + contamination_status = str(contamination.get("status") or "clean") + if contamination_status not in {"clean", "none"} and not bool(proposed_action.get("contamination_quarantine_ack")): + violations.append({ + "code": "SUBMIT_CONTAMINATION_NOT_CLEAN", + "message": "SUBMIT_PATCH requires clean contamination status or explicit quarantine acknowledgement.", + "status": contamination_status, + "quarantined_refs": contamination.get("quarantined_refs") or [], + }) + return violations + + +def _checkpoint_stage(action: Dict[str, Any], outcome: str) -> Optional[str]: + action_type = str(action.get("type") or "") + outcome_text = str(outcome or "").lower() + if action_type in {"SEARCH_CODE", "READ_FILE", "LSP_SYMBOL"}: + return "after_localization" + if action_type == "RUN_TEST" and outcome_text in {"failed", "fail", "timed_out", "error"}: + return "after_failing_test" + if action_type == "EDIT_FILE": + return "before_edit" + if action_type == "SUBMIT_PATCH": + return "before_submit" + return None + + +def record_replay_checkpoint( + task_contract: str | os.PathLike[str] | Dict[str, Any], + stage: str, + *, + repo: str | os.PathLike[str] | None = None, + action: Optional[Dict[str, Any]] = None, + observation: Any = None, + outcome: str = "observed", +) -> Dict[str, Any]: + """Create a branchable proof checkpoint for the contract runtime.""" + + repo_root = _resolve_repo_root(repo) + compiled = _sanitize_obj(_load_task_contract(task_contract, repo=repo_root)) + contract = compiled.get("contract") or {} + task_id = str(contract.get("task_id") or "unknown") + events_checked = _load_events_checked(repo_root, task_id) + events = list(events_checked.get("records") or []) + now = _now_iso() + checkpoint = _sanitize_obj({ + "format": "dhee.replay_checkpoint.v1", + "checkpoint_id": "chk_" + _stable_hash({ + "task_id": task_id, + "stage": stage, + "created_at": now, + "event_count": len(events), + }, 18), + "created_at": now, + "task_id": task_id, + "stage": stage, + "repo": str(repo_root), + "branch": _git_out(repo_root, ["branch", "--show-current"]), + "head_commit": _git_out(repo_root, ["rev-parse", "--short", "HEAD"]), + "status_porcelain": _git_out(repo_root, ["status", "--porcelain=v1", "--untracked-files=all"]), + "diff_stat": _git_out(repo_root, ["diff", "--stat"]), + "event_count": len(events), + "action": action or {}, + "outcome": outcome, + "observation": observation, + "rollback_point": _git_out(repo_root, ["rev-parse", "HEAD"]), + "proof": { + "contract_hash": _stable_hash(compiled, 20), + "verification_card": (contract.get("verification_card") or {}).get("schema_version"), + "contamination_status": (contract.get("contamination_status") or {}).get("status"), + }, + }) + root = _task_run_dir(repo_root, task_id) / "checkpoints" + path = root / f"{stage}_{checkpoint['checkpoint_id']}.json" + write_result = write_json_atomic(path, checkpoint, sanitize=_sanitize_obj) + if not write_result.get("ok"): + diagnostic = write_result.get("diagnostic") or {} + raise RuntimeError(diagnostic.get("message") or f"failed to write replay checkpoint {path}") + return { + "format": "dhee_replay_checkpoint_record.v1", + "checkpoint": checkpoint, + "paths": {"checkpoint": str(path), "dir": str(root)}, + } + + +def build_proof_bundle( + task_contract: str | os.PathLike[str] | Dict[str, Any], + *, + repo: str | os.PathLike[str] | None = None, + strict: bool = False, + persist: bool = True, +) -> Dict[str, Any]: + """Build the auditable proof bundle for a contract run. + + The bundle contains observations and pointers only. It does not expose + hidden reasoning or raw memory bodies. + """ + + repo_root = _resolve_repo_root(repo) + compiled = _sanitize_obj(_load_task_contract(task_contract, repo=repo_root)) + contract = compiled.get("contract") or {} + task_id = str(contract.get("task_id") or "unknown") + events_checked = _load_events_checked(repo_root, task_id) + events = list(events_checked.get("records") or []) + accepted_events = _accepted_events(events) + passed_tests = _passed_test_commands(events) + required_tests = [str(command) for command in contract.get("must_run") or []] + missing_tests = [ + command + for command in required_tests + if not any(_command_allowed(passed_command, [command]) for passed_command in passed_tests) + ] + failed_tests = [ + { + "command": str((_event_action(event) or {}).get("command") or ""), + "outcome": event.get("outcome"), + "event_id": event.get("event_id"), + } + for event in events + if (_event_action(event) or {}).get("type") == "RUN_TEST" + and not _outcome_is_success(event.get("outcome")) + ] + changed_paths = _code_changed_paths(repo_root) + forbidden_changed = [path for path in changed_paths if _is_forbidden_path(path, contract.get("forbidden_paths") or [])] + outside_allowed = [ + path + for path in changed_paths + if not _is_forbidden_path(path, contract.get("forbidden_paths") or []) + and not _path_under_allowed(path, contract.get("allowed_write_paths") or []) + ] + secret_findings = _secret_findings_in_diff(repo_root) + contamination = contract.get("contamination_status") or {} + contamination_clean = str(contamination.get("status") or "clean") in {"clean", "none"} + verifier_passed = not missing_tests and not forbidden_changed and not outside_allowed and not secret_findings and contamination_clean + action_trace = [] + for event in events: + action = _event_action(event) + decision = event.get("decision") or {} + matched = decision.get("matched_contract_action") or {} + action_trace.append({ + "event_id": event.get("event_id"), + "created_at": event.get("created_at"), + "action_id": action.get("action_id") or matched.get("action_id"), + "type": action.get("type"), + "target": action.get("path") or action.get("command") or action.get("query") or action.get("summary") or action.get("category"), + "decision": decision.get("decision"), + "accepted": event.get("accepted"), + "outcome": event.get("outcome"), + }) + context_items = [ + { + "kind": item.get("kind"), + "title": item.get("title"), + "evidence_pointer": item.get("evidence_pointer"), + "why_included": item.get("why_included"), + "token_cost": item.get("token_cost"), + "confidence": item.get("confidence"), + "expected_utility": item.get("expected_utility"), + } + for item in (contract.get("compiled_context") or {}).get("items") or [] + ] + skills_used = sorted({ + str(value) + for event in events + for value in [ + (_event_action(event) or {}).get("skill_id"), + ((event.get("observation") or {}) if isinstance(event.get("observation"), dict) else {}).get("skill_id"), + ] + if value + }) + bundle = _sanitize_obj({ + "schema_version": PROOF_BUNDLE_SCHEMA, + "generated_at": _now_iso(), + "contract_id": task_id, + "contract_hash": _stable_hash(compiled, 20), + "repo": str(repo_root), + "branch_state": { + "branch": _git_out(repo_root, ["branch", "--show-current"]), + "head_commit": _git_out(repo_root, ["rev-parse", "--short", "HEAD"]), + "dirty": bool(changed_paths), + }, + "action_trace": action_trace, + "files_changed": changed_paths, + "tests_run": [ + { + "command": str((_event_action(event) or {}).get("command") or ""), + "outcome": event.get("outcome"), + "event_id": event.get("event_id"), + } + for event in accepted_events + if (_event_action(event) or {}).get("type") == "RUN_TEST" + ], + "verifier_result": { + "status": "passed" if verifier_passed else "blocked", + "required_tests": required_tests, + "passed_tests": passed_tests, + "missing_tests": missing_tests, + "failed_tests": failed_tests, + "forbidden_changed_paths": forbidden_changed, + "out_of_contract_changed_paths": outside_allowed, + "secret_findings": secret_findings, + "verification_card": contract.get("verification_card") or {}, + }, + "context_used": context_items, + "memories_used": [ + { + "kind": pointer.get("kind"), + "evidence_pointer": pointer.get("evidence_pointer") or pointer.get("ref") or pointer.get("id"), + "why_included": pointer.get("why_included"), + "confidence": pointer.get("confidence"), + "content_hash": pointer.get("content_hash"), + } + for pointer in contract.get("memory_pointers") or [] + ], + "skills_used": skills_used, + "contamination_status": contamination, + "policy": { + "raw_evidence_bodies_excluded": True, + "hidden_reasoning_excluded": True, + "strict": bool(strict), + }, + "runtime_state_diagnostics": events_checked.get("diagnostics") or [], + }) + paths: Dict[str, str] = {} + if persist: + proof_path = _task_run_dir(repo_root, task_id) / "proof_bundle.json" + write_result = write_json_atomic(proof_path, bundle, sanitize=_sanitize_obj) + if not write_result.get("ok"): + diagnostic = write_result.get("diagnostic") or {} + raise RuntimeError(diagnostic.get("message") or f"failed to write proof bundle {proof_path}") + paths["proof_bundle"] = str(proof_path) + return { + "format": "dhee_contract_proof_bundle.v1", + "proof_bundle": bundle, + "paths": paths, + } + + +def _goal_token_overlap(goal: str, query: str) -> bool: + goal_tokens = set(_tokens(goal)) + query_tokens = set(_tokens(query)) + if not goal_tokens or not query_tokens: + return False + return bool(goal_tokens & query_tokens) + + +def _planned_targets(actions: Iterable[Dict[str, Any]], action_type: str, field: str) -> List[str]: + out: List[str] = [] + for action in actions or []: + if action.get("type") == action_type and action.get(field): + value = str(action.get(field)) + if value not in out: + out.append(value) + return out + + +def _command_allowed(command: str, must_run: Iterable[str]) -> bool: + text = str(command or "").strip() + if not text: + return False + for expected in must_run or []: + expected_text = str(expected or "").strip() + if text == expected_text or text.startswith(expected_text + " "): + return True + return False + + +def _supervise_by_type( + *, + repo_root: Path, + contract: Dict[str, Any], + planned_actions: List[Dict[str, Any]], + proposed_action: Dict[str, Any], + matched_action: Optional[Dict[str, Any]], + events: List[Dict[str, Any]], + interpreted_readiness: str, +) -> List[Dict[str, Any]]: + violations: List[Dict[str, Any]] = [] + + def deny(code: str, message: str, **extra: Any) -> None: + violations.append({"code": code, "message": message, **extra}) + + action_type = str(proposed_action.get("type") or "") + if action_type not in ACTION_TYPES: + deny("UNKNOWN_ACTION_TYPE", f"Unknown action type {action_type!r}.") + return violations + + if interpreted_readiness == "blocked" and action_type not in _RECOVERY_ACTIONS: + deny("CONTRACT_NOT_READY", "Interpreted contract is blocked; only recovery search or user clarification is allowed.") + + observed_ids = _observed_action_ids(events, planned_actions) + if matched_action: + missing = [dep for dep in matched_action.get("requires") or [] if dep not in observed_ids] + if missing: + deny( + "ACTION_DEPENDENCY_UNSATISFIED", + "Action has hard dependencies that have not been observed yet.", + action_id=matched_action.get("action_id"), + missing_action_ids=missing, + ) + + allowed_paths = contract.get("allowed_write_paths") or [] + forbidden_paths = contract.get("forbidden_paths") or [] + relevant_files = set(str(path) for path in contract.get("relevant_files") or []) + + if action_type == "READ_FILE": + path = str(proposed_action.get("path") or "") + resolved = _safe_repo_path(repo_root, path) + if resolved is None: + deny("UNSAFE_READ_PATH", "READ_FILE path is absolute or escapes the repo.", path=path) + elif _is_forbidden_path(path, forbidden_paths): + deny("READ_PATH_FORBIDDEN", "READ_FILE targets a forbidden path.", path=path) + elif path not in relevant_files and not _path_under_allowed(path, allowed_paths): + deny("READ_PATH_OUT_OF_CONTRACT", "READ_FILE target is outside relevant_files and allowed_write_paths.", path=path) + elif not resolved.exists(): + deny("READ_PATH_MISSING", "READ_FILE target does not exist in this checkout.", path=path) + + elif action_type == "SEARCH_CODE": + query = str(proposed_action.get("query") or "") + scope = str(proposed_action.get("scope") or ".") + planned_queries = _planned_targets(planned_actions, "SEARCH_CODE", "query") + if not query.strip(): + deny("EMPTY_SEARCH_QUERY", "SEARCH_CODE requires a query.") + elif query not in planned_queries and not _goal_token_overlap(str(contract.get("goal") or ""), query): + deny("SEARCH_QUERY_OUT_OF_CONTRACT", "SEARCH_CODE query does not overlap the compiled goal.", query=query) + resolved = repo_root if scope in {"", "."} else _safe_repo_path(repo_root, scope) + if resolved is None: + deny("UNSAFE_SEARCH_SCOPE", "SEARCH_CODE scope is absolute or escapes the repo.", scope=scope) + elif _is_forbidden_path(scope, forbidden_paths): + deny("SEARCH_SCOPE_FORBIDDEN", "SEARCH_CODE targets a forbidden path.", scope=scope) + + elif action_type == "RUN_TEST": + command = str(proposed_action.get("command") or "") + if not _command_allowed(command, contract.get("must_run") or []): + deny("TEST_COMMAND_OUT_OF_CONTRACT", "RUN_TEST command must match the compiled must_run list.", command=command) + + elif action_type == "EDIT_FILE": + path = str(proposed_action.get("path") or "") + resolved = _safe_repo_path(repo_root, path) + if resolved is None: + deny("UNSAFE_EDIT_PATH", "EDIT_FILE path is absolute or escapes the repo.", path=path) + elif _is_forbidden_path(path, forbidden_paths): + deny("EDIT_PATH_FORBIDDEN", "EDIT_FILE targets a forbidden path.", path=path) + elif not _path_under_allowed(path, allowed_paths): + deny("EDIT_PATH_OUTSIDE_ALLOWED", "EDIT_FILE target is outside allowed_write_paths.", path=path) + if not proposed_action.get("patch"): + deny("MISSING_EDIT_PATCH", "EDIT_FILE requires a unified diff patch.", path=path) + if resolved is not None and resolved.exists() and path not in _observed_read_paths(events): + deny("EDIT_REQUIRES_READ_OBSERVATION", "EDIT_FILE requires a prior observed READ_FILE for the same path.", path=path) + proof = _edit_proof(proposed_action) + missing_proof: List[str] = [] + if not str(proof.get("edit_span") or "").strip(): + missing_proof.append("edit_span") + if not str(proof.get("invariant") or "").strip(): + missing_proof.append("invariant") + if not _as_list(proof.get("related_tests")): + missing_proof.append("related_tests") + if not str(proof.get("rollback_point") or "").strip(): + missing_proof.append("rollback_point") + if missing_proof: + deny( + "EDIT_PROOF_OBLIGATION_MISSING", + "EDIT_FILE requires edit_span, invariant, related_tests, and rollback_point proof fields.", + missing=missing_proof, + ) + if proof.get("edit_span") and not _edit_span_is_valid(proof.get("edit_span"), path): + deny( + "EDIT_SPAN_INVALID", + "EDIT_FILE edit_span must identify the edited file and a valid line range.", + edit_span=proof.get("edit_span"), + ) + if proof.get("rollback_point") and not _git_out(repo_root, ["rev-parse", "--verify", str(proof.get("rollback_point"))]): + deny( + "EDIT_ROLLBACK_POINT_INVALID", + "EDIT_FILE rollback_point must resolve to a git object in this checkout.", + rollback_point=proof.get("rollback_point"), + ) + related_tests = _as_list(proof.get("related_tests")) + if related_tests and not all(_command_allowed(test, contract.get("must_run") or []) for test in related_tests): + deny( + "EDIT_RELATED_TEST_OUT_OF_CONTRACT", + "EDIT_FILE related_tests must be selected from the compiled must_run verifier list.", + related_tests=related_tests, + ) + + elif action_type == "ASK_USER": + if not str(proposed_action.get("question") or "").strip(): + deny("MISSING_USER_QUESTION", "ASK_USER requires a question.") + + elif action_type == "SPAWN_SUBAGENT": + permissions = {str(item) for item in proposed_action.get("permissions") or []} + if not str(proposed_action.get("role") or "").strip() or not str(proposed_action.get("task") or "").strip(): + deny("INVALID_SUBAGENT_REQUEST", "SPAWN_SUBAGENT requires role and task.") + forbidden = sorted(permissions & _FORBIDDEN_SUBAGENT_PERMISSIONS) + if forbidden: + deny("SUBAGENT_PERMISSION_FORBIDDEN", "SPAWN_SUBAGENT requests forbidden permissions.", permissions=forbidden) + + elif action_type == "WRITE_MEMORY_NOTE": + if not str(proposed_action.get("category") or "").strip() or not str(proposed_action.get("content") or "").strip(): + deny("INVALID_MEMORY_NOTE", "WRITE_MEMORY_NOTE requires category and content.") + + elif action_type == "SUBMIT_PATCH": + tests = [str(item) for item in proposed_action.get("tests") or []] + if not str(proposed_action.get("summary") or "").strip(): + deny("MISSING_PATCH_SUMMARY", "SUBMIT_PATCH requires a summary.") + if tests and not all(_command_allowed(test, contract.get("must_run") or []) for test in tests): + deny("SUBMIT_TESTS_OUT_OF_CONTRACT", "SUBMIT_PATCH tests must come from the compiled must_run list.", tests=tests) + passed = _passed_test_commands(events) + missing_tests = [ + str(command) + for command in contract.get("must_run") or [] + if not any(_command_allowed(passed_command, [str(command)]) for passed_command in passed) + ] + if missing_tests: + deny("SUBMIT_REQUIRES_PASSING_TESTS", "SUBMIT_PATCH requires every compiled must_run command to be observed as passed.", missing_tests=missing_tests) + for violation in _submit_diff_violations(repo_root, contract, proposed_action): + violations.append(violation) + + return violations + + +def supervise_action( + task_contract: str | os.PathLike[str] | Dict[str, Any], + proposed_action: Dict[str, Any], + *, + repo: str | os.PathLike[str] | None = None, + strict: bool = False, +) -> Dict[str, Any]: + """Decide whether a proposed action is allowed by a compiled contract.""" + + repo_root = _resolve_repo_root(repo) + compiled = _sanitize_obj(_load_task_contract(task_contract, repo=repo_root)) + interpretation = interpret_task_contract(compiled, repo=repo_root, strict=strict) + contract = compiled.get("contract") or {} + planned_actions = list(compiled.get("actions") or []) + action = _sanitize_obj(dict(proposed_action or {})) + task_id = str(contract.get("task_id") or "unknown") + events_checked = _load_events_checked(repo_root, task_id) + events = list(events_checked.get("records") or []) + matched_action = _match_planned_action(planned_actions, action) + if matched_action and matched_action.get("action_id") and not action.get("action_id"): + action["action_id"] = matched_action.get("action_id") + + violations = _supervise_by_type( + repo_root=repo_root, + contract=contract, + planned_actions=planned_actions, + proposed_action=action, + matched_action=matched_action, + events=events, + interpreted_readiness=str(interpretation.get("readiness") or ""), + ) + action_type = str(action.get("type") or "") + if violations: + decision = "deny" + elif action_type == "ASK_USER" and action.get("blocking"): + decision = "needs_input" + else: + decision = "allow" + + matched = None + if matched_action: + matched = { + "action_id": matched_action.get("action_id"), + "step": matched_action.get("step"), + "type": matched_action.get("type"), + "phase": matched_action.get("phase"), + "target": matched_action.get("path") or matched_action.get("command") or matched_action.get("query") or matched_action.get("summary") or matched_action.get("category"), + "reason": matched_action.get("reason"), + "requires": matched_action.get("requires") or [], + "soft_requires": matched_action.get("soft_requires") or [], + "capabilities": matched_action.get("capabilities") or [], + "effects": matched_action.get("effects") or [], + } + observed_ids = _observed_action_ids(events, planned_actions) + + response = { + "format": SUPERVISOR_DECISION_SCHEMA, + "decision": decision, + "task_id": contract.get("task_id"), + "goal": contract.get("goal"), + "repo": str(repo_root), + "interpreted_readiness": interpretation.get("readiness"), + "proposed_action": action, + "matched_contract_action": matched, + "violations": violations, + "runtime_state": { + "event_count": len(events), + "observed_action_ids": observed_ids, + "passed_tests": _passed_test_commands(events), + "observed_read_paths": _observed_read_paths(events), + "next_allowed_actions": _next_allowed_actions(planned_actions, observed_ids), + "diagnostics": events_checked.get("diagnostics") or [], + }, + "observation_template": { + "precondition": action.get("precondition"), + "execution": action.get("execution"), + "observation": "Record compact result plus pointer to full output.", + "postcondition": action.get("postcondition"), + "memory_update": action.get("memory_update"), + }, + "policy": { + "enforced": True, + "strict": bool(strict), + "auto_execute": False, + }, + } + if action_type == "SUBMIT_PATCH": + response["proof_bundle_preview"] = build_proof_bundle( + compiled, + repo=repo_root, + strict=strict, + persist=False, + ) + return response + + +def record_observation_transition( + task_contract: str | os.PathLike[str] | Dict[str, Any], + action: Dict[str, Any], + observation: str | Dict[str, Any], + *, + repo: str | os.PathLike[str] | None = None, + outcome: str = "observed", + next_action: Optional[Dict[str, Any]] = None, + strict: bool = False, +) -> Dict[str, Any]: + """Record a compact action observation and optional next-action decision.""" + + repo_root = _resolve_repo_root(repo) + decision = supervise_action(task_contract, action, repo=repo_root, strict=strict) + compiled = _sanitize_obj(_load_task_contract(task_contract, repo=repo_root)) + contract = compiled.get("contract") or {} + task_id = str(contract.get("task_id") or decision.get("task_id") or "unknown") + recorded_action = _sanitize_obj(dict(action or {})) + matched = decision.get("matched_contract_action") or {} + if matched.get("action_id") and not recorded_action.get("action_id"): + recorded_action["action_id"] = matched.get("action_id") + created_at = _now_iso() + + event = _sanitize_obj({ + "format": OBSERVATION_EVENT_SCHEMA, + "event_id": "evt_" + _stable_hash({ + "task_id": task_id, + "action": recorded_action, + "observation": observation, + "next_action": next_action, + "created_at": created_at, + }, 18), + "created_at": created_at, + "task_id": task_id, + "action": recorded_action, + "decision": decision, + "accepted": decision.get("decision") in {"allow", "needs_input"} and not _outcome_is_blocked(outcome), + "outcome": outcome, + "observation": observation, + "next_action": next_action, + }) + repo_link._ensure_repo_skeleton(repo_root) + run_dir = _task_run_dir(repo_root, task_id) + events_path = run_dir / "events.jsonl" + append_result = append_jsonl_locked(events_path, event, sanitize=_sanitize_obj) + if not append_result.get("ok"): + diagnostic = append_result.get("diagnostic") or {} + raise RuntimeError(diagnostic.get("message") or f"failed to append supervisor event {events_path}") + checkpoint = None + stage = _checkpoint_stage(recorded_action, outcome) + if stage: + checkpoint = record_replay_checkpoint( + compiled, + stage, + repo=repo_root, + action=recorded_action, + observation=observation, + outcome=outcome, + ) + proof_bundle = None + if recorded_action.get("type") == "SUBMIT_PATCH": + proof_bundle = build_proof_bundle( + compiled, + repo=repo_root, + strict=strict, + persist=True, + ) + next_decision = None + if next_action is not None: + next_decision = supervise_action(compiled, next_action, repo=repo_root, strict=strict) + event["next_decision"] = next_decision + # Keep the JSONL event immutable except for next-decision availability in + # the returned response; the stored event is the observation record. + return { + "format": "dhee_contract_observation_record.v1", + "event": event, + "paths": {"events": str(events_path), "dir": str(run_dir)}, + "decision": decision, + "next_decision": next_decision, + "checkpoint": checkpoint, + "proof_bundle": proof_bundle, + } diff --git a/dhee/hooks/claude_code/__main__.py b/dhee/hooks/claude_code/__main__.py index ea5fe84..7032bdf 100644 --- a/dhee/hooks/claude_code/__main__.py +++ b/dhee/hooks/claude_code/__main__.py @@ -388,6 +388,25 @@ def _repo_context_for(cwd: str, *, query: str, limit: int = 5) -> list[dict[str, return [] +def _scene_world_route(prompt: str, *, repo: str, harness: str = "claude-code") -> dict[str, Any] | None: + """Best-effort optional predictive route hint.""" + + if not str(prompt or "").strip(): + return None + try: + from dhee.hooks.scene_world import predict_scene_world_route + + return predict_scene_world_route( + str(prompt), + repo=repo, + user_id=os.environ.get("DHEE_USER_ID", "default"), + harness=harness, + top_k=4, + ) + except Exception: + return None + + def _discover_repo_config(start: str) -> dict[str, Any]: """Find public .dhee/config.json for repo-link context.""" try: @@ -650,6 +669,7 @@ def handle_session_start(payload: dict[str, Any]) -> dict[str, Any]: repo_entries = _repo_context_for(repo_root, query=task_desc, limit=5) state_card = state_store.render_state_card() + scene_world = _scene_world_route(task_desc, repo=repo_root) if ( not doc_matches @@ -659,6 +679,7 @@ def handle_session_start(payload: dict[str, Any]) -> dict[str, Any]: and not typed.get("last_session") and not assembled.has_cognition and not repo_entries + and not scene_world and not state_card ): return {} @@ -671,6 +692,7 @@ def handle_session_start(payload: dict[str, Any]) -> dict[str, Any]: shared_task_results=shared.get("results") or [], repo_entries=repo_entries, live_messages=live.get("messages") or [], + scene_world=scene_world, state_card=state_card, ) if not xml: @@ -768,6 +790,7 @@ def handle_user_prompt(payload: dict[str, Any]) -> dict[str, Any]: repo_entries = _repo_context_for(repo, query=prompt, limit=3) state_card = state_store.render_state_card() + scene_world = _scene_world_route(prompt, repo=repo) has_signal = ( bool(doc_matches) @@ -776,6 +799,7 @@ def handle_user_prompt(payload: dict[str, Any]) -> dict[str, Any]: or bool(shared.get("task")) or bool(live.get("messages")) or bool(repo_entries) + or bool(scene_world) or bool(state_card) ) if not has_signal: @@ -797,6 +821,7 @@ def handle_user_prompt(payload: dict[str, Any]) -> dict[str, Any]: shared_task_results=shared.get("results") or [], repo_entries=repo_entries, live_messages=live.get("messages") or [], + scene_world=scene_world, state_card=state_card, ) if not xml: diff --git a/dhee/hooks/claude_code/renderer.py b/dhee/hooks/claude_code/renderer.py index de97859..afd620e 100644 --- a/dhee/hooks/claude_code/renderer.py +++ b/dhee/hooks/claude_code/renderer.py @@ -40,6 +40,7 @@ def render_context( shared_task_results: list[dict[str, Any]] | None = None, repo_entries: list[dict[str, Any]] | None = None, live_messages: list[dict[str, Any]] | None = None, + scene_world: dict[str, Any] | None = None, state_card: str | None = None, ) -> str: """Render Dhee context dict as flat XML for Claude Code injection. @@ -51,6 +52,7 @@ def render_context( (118, _live_context_block(live_messages)), (115, _edits_section(edits_block)), (113, _repo_context_block(repo_entries)), + (112, _scene_world_block(scene_world)), (110, _docs_block(doc_matches)), (105, _shared_task_block(shared_task, shared_task_results)), (100, _session_block(ctx.get("last_session"))), @@ -141,6 +143,69 @@ def _router_block() -> list[str]: return [f"{_xml_escape(_ROUTER_NUDGE)}"] +def _scene_world_block(scene_world: dict[str, Any] | None) -> list[str]: + """Predictive route hint from the optional SceneWorld sidecar.""" + + if not isinstance(scene_world, dict): + return [] + best = scene_world.get("best_action") + if not isinstance(best, dict): + return [] + + route_id = str(scene_world.get("route_id") or "") + source = str(scene_world.get("source") or "") + era = str(scene_world.get("era") or "") + active_project = str(scene_world.get("active_project") or "") + meta = scene_world.get("_scene_world") or {} + if not isinstance(meta, dict): + meta = {} + + lines = [ + "" + ] + + action = str(best.get("action") or "") + predicted = str(best.get("predicted_next_scene") or "").strip() + reaction = str(best.get("likely_user_reaction") or "").strip() + risks = best.get("risks") if isinstance(best.get("risks"), list) else [] + best_attrs = _attrs( + action=action, + reward=_fmt(best.get("expected_reward")), + confidence=_fmt(best.get("confidence")), + risks="; ".join(str(r) for r in risks[:3]), + ) + lines.append(_tag("best", best_attrs, predicted[:520] or action)) + if reaction: + lines.append(_tag("reaction", "", reaction[:280])) + + ranked = scene_world.get("ranked_actions") + if isinstance(ranked, list) and ranked: + lines.append("") + for row in ranked[:4]: + if not isinstance(row, dict): + continue + row_risks = row.get("risks") if isinstance(row.get("risks"), list) else [] + row_attrs = _attrs( + action=str(row.get("action") or ""), + reward=_fmt(row.get("expected_reward")), + confidence=_fmt(row.get("confidence")), + ) + lines.append(_tag("candidate", row_attrs, "; ".join(str(r) for r in row_risks[:2])[:220])) + lines.append("") + + lines.append("") + return lines if len(lines) > 2 else [] + + def _state_card_block(state_card: str | None) -> list[str]: if not state_card: return [] diff --git a/dhee/hooks/scene_world.py b/dhee/hooks/scene_world.py new file mode 100644 index 0000000..92a0a90 --- /dev/null +++ b/dhee/hooks/scene_world.py @@ -0,0 +1,223 @@ +"""Optional SceneWorld routing bridge for Dhee hooks and MCP. + +SceneWorld lives outside Dhee so Dhee can remain the memory substrate. This +bridge deliberately imports SankhyaWM lazily and only when explicitly enabled. +""" + +from __future__ import annotations + +import os +import sys +from pathlib import Path +from typing import Any, Dict, Optional + + +_TRUE_VALUES = {"1", "true", "yes", "on", "enabled"} +_FALSE_VALUES = {"0", "false", "no", "off", "disabled"} + + +def _flag(name: str, *, default: bool = False) -> bool: + raw = os.environ.get(name) + if raw is None: + return default + value = raw.strip().lower() + if value in _TRUE_VALUES: + return True + if value in _FALSE_VALUES: + return False + return default + + +def scene_world_enabled() -> bool: + """Return whether Dhee should call the external SceneWorld router.""" + + raw = ( + os.environ.get("DHEE_SCENE_WORLD_ENABLED") + or os.environ.get("DHEE_SCENE_WORLD") + or "" + ).strip().lower() + if raw == "auto": + return _discover_project(None) is not None + return raw in _TRUE_VALUES + + +def _valid_project(path: Path) -> bool: + return (path / "sankhya_wm" / "dhee_scene_world_adapter.py").exists() + + +def _candidate_projects(repo: Optional[str]) -> list[Path]: + candidates: list[Path] = [] + explicit = os.environ.get("DHEE_SCENE_WORLD_PROJECT") + if explicit: + candidates.append(Path(explicit).expanduser()) + + model_path = os.environ.get("DHEE_SCENE_WORLD_MODEL") or os.environ.get("SCENE_WORLD_MODEL_PATH") + if model_path: + model = Path(model_path).expanduser() + candidates.extend([model.parent.parent, model.parent]) + + if repo: + root = Path(repo).expanduser() + candidates.extend([root, root.parent / "sankhyaWM"]) + + cwd = Path.cwd() + candidates.extend([cwd, cwd.parent / "sankhyaWM"]) + + # Local development layout: /Desktop/Dhee and /Desktop/sankhyaWM siblings. + here = Path(__file__).resolve() + if len(here.parents) > 3: + candidates.append(here.parents[3] / "sankhyaWM") + + seen: set[str] = set() + unique: list[Path] = [] + for candidate in candidates: + try: + resolved = candidate.resolve() + except Exception: + resolved = candidate + key = str(resolved) + if key not in seen: + seen.add(key) + unique.append(resolved) + return unique + + +def _discover_project(repo: Optional[str]) -> Optional[Path]: + for candidate in _candidate_projects(repo): + if _valid_project(candidate): + return candidate + return None + + +def _discover_model_path(repo: Optional[str], project: Optional[Path]) -> Optional[Path]: + explicit = os.environ.get("DHEE_SCENE_WORLD_MODEL") or os.environ.get("SCENE_WORLD_MODEL_PATH") + if explicit: + return Path(explicit).expanduser() + candidates: list[Path] = [] + if project: + candidates.append(project / "models" / "scene_world_reward_model.json") + if repo: + candidates.append(Path(repo).expanduser() / ".dhee" / "scene_world_reward_model.json") + for candidate in candidates: + if candidate.exists(): + return candidate + return None + + +def _ensure_import_path(project: Optional[Path]) -> None: + if not project: + return + path = str(project) + if path not in sys.path: + sys.path.insert(0, path) + + +def _int_arg(value: Any, *, default: int, lower: int, upper: int) -> int: + try: + parsed = int(value) + except (TypeError, ValueError): + parsed = default + return max(lower, min(upper, parsed)) + + +def route_task( + task: str, + *, + repo: Optional[str] = None, + user_id: Optional[str] = None, + harness: str = "agent", + top_k: int = 4, + record: Optional[bool] = None, +) -> Dict[str, Any]: + """Return a status-wrapped SceneWorld route for a task. + + The shape is stable for hooks and MCP. Errors are returned as data because + this code runs on hot agent paths and must never break Dhee itself. + """ + + task = str(task or "").strip() + if not task: + return {"enabled": scene_world_enabled(), "status": "empty_task"} + if not scene_world_enabled(): + return {"enabled": False, "status": "disabled"} + + project = _discover_project(repo) + _ensure_import_path(project) + model_path = _discover_model_path(repo, project) + debug = _flag("DHEE_SCENE_WORLD_DEBUG") + + try: + from sankhya_wm.dhee_scene_world_adapter import predict_next_action + except Exception as exc: + result: Dict[str, Any] = { + "enabled": True, + "status": "unavailable", + "reason": f"{type(exc).__name__}: {exc}", + "project": str(project) if project else None, + } + return result + + try: + route = predict_next_action( + task, + user_id=user_id or os.environ.get("DHEE_USER_ID", "default"), + model_path=model_path, + model_weight=float(os.environ.get("DHEE_SCENE_WORLD_MODEL_WEIGHT", "0.7")), + provider=os.environ.get("DHEE_PROVIDER"), + data_dir=os.environ.get("DHEE_DATA_DIR"), + top_k=_int_arg(top_k, default=4, lower=1, upper=8), + record=_flag("DHEE_SCENE_WORLD_RECORD") if record is None else bool(record), + route_log_path=os.environ.get("DHEE_SCENE_WORLD_ROUTE_LOG"), + ) + except Exception as exc: + error = f"{type(exc).__name__}: {exc}" + return { + "enabled": True, + "status": "error", + "reason": error if debug else type(exc).__name__, + "project": str(project) if project else None, + "model_path": str(model_path) if model_path else None, + "harness": harness, + } + + return { + "enabled": True, + "status": "ok", + "route": route, + "project": str(project) if project else None, + "model_path": str(model_path) if model_path else None, + "harness": harness, + } + + +def predict_scene_world_route( + task: str, + *, + repo: Optional[str] = None, + user_id: Optional[str] = None, + harness: str = "agent", + top_k: int = 4, + record: Optional[bool] = None, +) -> Optional[Dict[str, Any]]: + """Return only the route payload when SceneWorld is enabled and healthy.""" + + result = route_task( + task, + repo=repo, + user_id=user_id, + harness=harness, + top_k=top_k, + record=record, + ) + if result.get("status") == "ok" and isinstance(result.get("route"), dict): + route = dict(result["route"]) + route.setdefault("_scene_world", {}) + route["_scene_world"].update( + { + "project": result.get("project"), + "model_path": result.get("model_path"), + "harness": result.get("harness"), + } + ) + return route + return None diff --git a/dhee/mcp_registry.py b/dhee/mcp_registry.py new file mode 100644 index 0000000..1448510 --- /dev/null +++ b/dhee/mcp_registry.py @@ -0,0 +1,264 @@ +"""Shared MCP tool registry for Dhee compiler/runtime surfaces.""" + +from __future__ import annotations + +from copy import deepcopy +from typing import Any, Dict, Iterable, List, Sequence + + +TASK_CONTRACT_TOOL_NAMES = ( + "dhee_task_contract_compile", + "dhee_task_contract_create", + "dhee_task_contract_list", + "dhee_task_contract_get", + "dhee_task_contract_import", + "dhee_task_contract_interpret", +) + +CONTRACT_RUNTIME_TOOL_NAMES = ( + "dhee_contract_supervise_action", + "dhee_contract_record_observation", + "dhee_contract_proof_bundle", + "dhee_contract_runtime_activate", + "dhee_contract_runtime_status", + "dhee_contract_runtime_deactivate", + "dhee_contract_enforcement_set", + "dhee_contract_enforcement_status", + "dhee_contract_runtime_doctor", +) + +UPDATE_CAPSULE_TOOL_NAMES = ( + "dhee_update_capsule_create", + "dhee_update_capsule_list", + "dhee_update_capsule_get", + "dhee_update_capsule_import", + "dhee_update_capsule_interpret", +) + +CONTEXT_COMPILER_TOOL_NAMES = ( + *TASK_CONTRACT_TOOL_NAMES, + *CONTRACT_RUNTIME_TOOL_NAMES, + *UPDATE_CAPSULE_TOOL_NAMES, +) + + +_TASK_COMPILE_PROPERTIES = { + "goal": {"type": "string"}, + "task": {"type": "string"}, + "query": {"type": "string"}, + "repo": {"type": "string"}, + "mode": {"type": "string"}, + "risk": {"type": "string"}, + "allowed_write_paths": {"type": "array", "items": {"type": "string"}}, + "forbidden_paths": {"type": "array", "items": {"type": "string"}}, + "must_run": {"type": "array", "items": {"type": "string"}}, + "success_criteria": {"type": "array", "items": {"type": "string"}}, + "context_budget": {"type": "object"}, + "memory_pointers": {"type": "array", "items": {"type": "object"}}, + "recent_failures": {"type": "array", "items": {"type": "object"}}, +} + +_CONTRACT_REF_PROPERTIES = { + "repo": {"type": "string"}, + "task_id": {"type": "string"}, + "id": {"type": "string"}, + "path": {"type": "string"}, + "contract": {"type": "object"}, +} + +_CAPSULE_REF_PROPERTIES = { + "repo": {"type": "string"}, + "capsule_id": {"type": "string"}, + "path": {"type": "string"}, + "capsule": {"type": "object"}, +} + + +TOOL_SPECS: Dict[str, Dict[str, Any]] = { + "dhee_task_contract_compile": { + "name": "dhee_task_contract_compile", + "description": "Compile a messy user task plus repo state into a deterministic TaskContract and typed ChotuAction plan.", + "inputSchema": {"type": "object", "properties": deepcopy(_TASK_COMPILE_PROPERTIES)}, + }, + "dhee_task_contract_create": { + "name": "dhee_task_contract_create", + "description": "Compile and store a portable TaskContract under .dhee/context/task_contracts.", + "inputSchema": { + "type": "object", + "properties": {"out": {"type": "string"}, **deepcopy(_TASK_COMPILE_PROPERTIES)}, + }, + }, + "dhee_task_contract_list": { + "name": "dhee_task_contract_list", + "description": "List portable task contracts in a repo.", + "inputSchema": {"type": "object", "properties": {"repo": {"type": "string"}}}, + }, + "dhee_task_contract_get": { + "name": "dhee_task_contract_get", + "description": "Get one task contract's markdown and machine JSON.", + "inputSchema": { + "type": "object", + "properties": {"repo": {"type": "string"}, "task_id": {"type": "string"}, "id": {"type": "string"}}, + }, + }, + "dhee_task_contract_import": { + "name": "dhee_task_contract_import", + "description": "Import a portable task contract into a repo and index it.", + "inputSchema": { + "type": "object", + "properties": {"repo": {"type": "string"}, "path": {"type": "string"}}, + "required": ["path"], + }, + }, + "dhee_task_contract_interpret": { + "name": "dhee_task_contract_interpret", + "description": "Interpret a portable TaskContract on this repo and return executable ChotuAction readiness without running tools.", + "inputSchema": { + "type": "object", + "properties": {**deepcopy(_CONTRACT_REF_PROPERTIES), "strict": {"type": "boolean"}}, + }, + }, + "dhee_contract_supervise_action": { + "name": "dhee_contract_supervise_action", + "description": "Runtime gate: allow or deny a proposed ChotuAction against an interpreted task contract.", + "inputSchema": { + "type": "object", + "properties": { + **deepcopy(_CONTRACT_REF_PROPERTIES), + "action": {"type": "object"}, + "proposed_action": {"type": "object"}, + "strict": {"type": "boolean"}, + }, + }, + }, + "dhee_contract_record_observation": { + "name": "dhee_contract_record_observation", + "description": "Record a compact observation-to-next-action transition for a supervised task contract.", + "inputSchema": { + "type": "object", + "properties": { + **deepcopy(_CONTRACT_REF_PROPERTIES), + "action": {"type": "object"}, + "observation": {}, + "outcome": {"type": "string"}, + "next_action": {"type": "object"}, + "strict": {"type": "boolean"}, + }, + }, + }, + "dhee_contract_proof_bundle": { + "name": "dhee_contract_proof_bundle", + "description": "Build the proof bundle for a task contract run: action trace, changed files, tests, verifier result, context pointers, memory pointers, skills, and contamination status.", + "inputSchema": { + "type": "object", + "properties": {**deepcopy(_CONTRACT_REF_PROPERTIES), "strict": {"type": "boolean"}, "persist": {"type": "boolean"}}, + }, + }, + "dhee_contract_runtime_activate": { + "name": "dhee_contract_runtime_activate", + "description": "Bind a TaskContract as the active repo runtime so router/native actions are supervised.", + "inputSchema": { + "type": "object", + "properties": { + **deepcopy(_CONTRACT_REF_PROPERTIES), + "strict": {"type": "boolean"}, + "force": {"type": "boolean"}, + "agent_id": {"type": "string"}, + "harness": {"type": "string"}, + }, + }, + }, + "dhee_contract_runtime_status": { + "name": "dhee_contract_runtime_status", + "description": "Show the active task-contract runtime bound to a repo, including readiness, enforcement, diagnostics, and event paths.", + "inputSchema": {"type": "object", "properties": {"repo": {"type": "string"}}}, + }, + "dhee_contract_runtime_deactivate": { + "name": "dhee_contract_runtime_deactivate", + "description": "Deactivate the active task-contract runtime for a repo without deleting observation history.", + "inputSchema": { + "type": "object", + "properties": {"repo": {"type": "string"}, "agent_id": {"type": "string"}, "reason": {"type": "string"}}, + }, + }, + "dhee_contract_enforcement_set": { + "name": "dhee_contract_enforcement_set", + "description": "Set repo contract enforcement policy to off, warn, or deny.", + "inputSchema": { + "type": "object", + "properties": { + "repo": {"type": "string"}, + "mode": {"type": "string", "enum": ["off", "warn", "deny"]}, + "agent_id": {"type": "string"}, + "reason": {"type": "string"}, + }, + "required": ["mode"], + }, + }, + "dhee_contract_enforcement_status": { + "name": "dhee_contract_enforcement_status", + "description": "Show the effective repo contract enforcement policy, including env-forced deny and diagnostics.", + "inputSchema": {"type": "object", "properties": {"repo": {"type": "string"}}}, + }, + "dhee_contract_runtime_doctor": { + "name": "dhee_contract_runtime_doctor", + "description": "Doctor the contract runtime and report protected, partially_protected, or unprotected with bypass risks.", + "inputSchema": {"type": "object", "properties": {"repo": {"type": "string"}}}, + }, + "dhee_update_capsule_create": { + "name": "dhee_update_capsule_create", + "description": "Create a sanitized repo-shareable UpdateCapsule under .dhee/context/capsules and index it as kind=update_capsule.", + "inputSchema": { + "type": "object", + "properties": { + "repo": {"type": "string"}, + "since": {"type": "string"}, + "task_id": {"type": "string"}, + "out": {"type": "string"}, + "title": {"type": "string"}, + "summary": {"type": "string"}, + "commands": {"type": "array", "items": {"type": "string"}}, + "evidence": {"type": "array", "items": {"type": "object"}}, + }, + }, + }, + "dhee_update_capsule_list": { + "name": "dhee_update_capsule_list", + "description": "List update capsules in a repo.", + "inputSchema": {"type": "object", "properties": {"repo": {"type": "string"}}}, + }, + "dhee_update_capsule_get": { + "name": "dhee_update_capsule_get", + "description": "Get one update capsule's markdown and machine JSON.", + "inputSchema": { + "type": "object", + "properties": {"repo": {"type": "string"}, "capsule_id": {"type": "string"}}, + "required": ["capsule_id"], + }, + }, + "dhee_update_capsule_import": { + "name": "dhee_update_capsule_import", + "description": "Import a sanitized update capsule into a repo and index it as kind=update_capsule.", + "inputSchema": { + "type": "object", + "properties": {"repo": {"type": "string"}, "path": {"type": "string"}, "allow_private": {"type": "boolean"}}, + "required": ["path"], + }, + }, + "dhee_update_capsule_interpret": { + "name": "dhee_update_capsule_interpret", + "description": "Interpret a compiled update capsule on this repo and return a reproduction plan without auto-applying edits.", + "inputSchema": { + "type": "object", + "properties": {**deepcopy(_CAPSULE_REF_PROPERTIES), "strict": {"type": "boolean"}}, + }, + }, +} + + +def tool_specs(names: Iterable[str] = CONTEXT_COMPILER_TOOL_NAMES) -> List[Dict[str, Any]]: + return [deepcopy(TOOL_SPECS[name]) for name in names] + + +def make_tools(tool_cls: Any, names: Sequence[str] = CONTEXT_COMPILER_TOOL_NAMES) -> List[Any]: + return [tool_cls(**spec) for spec in tool_specs(names)] diff --git a/dhee/mcp_server.py b/dhee/mcp_server.py index c3dd805..1298bde 100644 --- a/dhee/mcp_server.py +++ b/dhee/mcp_server.py @@ -48,6 +48,7 @@ from mcp.server.stdio import stdio_server from mcp.types import Tool, TextContent +from dhee.mcp_registry import CONTEXT_COMPILER_TOOL_NAMES, make_tools from dhee.memory.main import FullMemory from dhee.configs.base import ( MemoryConfig, @@ -406,6 +407,89 @@ def get_buddhi(): }, }, ), + Tool( + name="dhee_scene_world_route", + description=( + "Predict likely outcomes for candidate next actions using the optional " + "SceneWorld world-model sidecar. Use before choosing a high-stakes " + "agent action when DHEE_SCENE_WORLD_ENABLED=1." + ), + inputSchema={ + "type": "object", + "properties": { + "task": {"type": "string", "description": "Current task or scene"}, + "query": {"type": "string", "description": "Alias for task"}, + "repo": {"type": "string", "description": "Repo/workspace path"}, + "user_id": {"type": "string", "description": "User identifier"}, + "harness": {"type": "string", "description": "Harness/runtime id"}, + "top_k": {"type": "integer", "description": "Number of ranked actions to return"}, + "record": {"type": "boolean", "description": "Record the route trace when route logging is configured"}, + }, + }, + ), + Tool( + name="dhee_scene_compile", + description="Compile a private TemporalScene card from evidence pointers, memory rows, agent outputs, browser captures, or admitted derivatives.", + inputSchema={ + "type": "object", + "properties": { + "evidence": {"type": "array", "items": {"type": "object"}, "description": "Evidence rows or pointers to compile."}, + "query": {"type": "string", "description": "Optional query/task used when include_recent_memories is true."}, + "task": {"type": "string", "description": "Current user task or scene goal."}, + "title": {"type": "string", "description": "Optional scene title."}, + "repo": {"type": "string", "description": "Repo/workspace path to attach as a scene ref."}, + "user_id": {"type": "string", "description": "User id (default: default)."}, + "privacy_scope": {"type": "string", "description": "Scene privacy scope (default: personal)."}, + "store_dir": {"type": "string", "description": "Optional scene store override."}, + "save": {"type": "boolean", "description": "Persist scene to the private scene store (default: true)."}, + "include_recent_memories": {"type": "boolean", "description": "If evidence is empty, search recent/relevant memory and compile from those results."}, + "include_repo_context": {"type": "boolean", "description": "Collect repo-shared context entries as evidence."}, + "include_session": {"type": "boolean", "description": "Fetch latest compact session digest as evidence when available."}, + "include_shared_task_results": {"type": "boolean", "description": "Fetch active shared-task result packets as evidence when available."}, + "include_artifacts": {"type": "boolean", "description": "Fetch recent artifact summaries as evidence when available."}, + "include_live_sources": {"type": "boolean", "description": "Fetch session, shared-task results, and artifacts for requested sources."}, + "sources": {"type": "array", "items": {"type": "string"}, "description": "Evidence sources: evidence, memory, repo_context, session, shared_task_results, artifacts."}, + "session": {"type": "object", "description": "Optional session digest to compile as evidence."}, + "shared_task_results": {"description": "Optional shared-task result rows or response object."}, + "artifacts": {"description": "Optional artifact summaries or response object."}, + "limit": {"type": "integer", "description": "Memory/evidence limit when include_recent_memories is true."}, + }, + }, + ), + Tool( + name="dhee_scene_search", + description="Search private TemporalScene cards and return prompt-safe summaries with evidence refs only.", + inputSchema={ + "type": "object", + "properties": { + "query": {"type": "string"}, + "user_id": {"type": "string"}, + "repo": {"type": "string"}, + "limit": {"type": "integer"}, + "store_dir": {"type": "string"}, + "include_personal": {"type": "boolean"}, + }, + "required": ["query"], + }, + ), + Tool( + name="dhee_context_pack", + description="Build a hard-budget context pack from ranked scene cards. Raw evidence expands only by pointer.", + inputSchema={ + "type": "object", + "properties": { + "query": {"type": "string"}, + "user_id": {"type": "string"}, + "repo": {"type": "string"}, + "token_budget": {"type": "integer"}, + "limit": {"type": "integer"}, + "store_dir": {"type": "string"}, + "include_personal": {"type": "boolean"}, + }, + "required": ["query"], + }, + ), + *make_tools(Tool, CONTEXT_COMPILER_TOOL_NAMES), Tool( name="get_last_session", description=( @@ -1310,6 +1394,396 @@ def _handle_dhee_context(memory, args): return result +def _handle_dhee_scene_world_route(_memory, args): + task = str(args.get("task") or args.get("query") or "") + if not task.strip(): + return {"error": "task is required"} + try: + from dhee.hooks.scene_world import route_task + + return route_task( + task, + repo=args.get("repo"), + user_id=_default_user_id(args), + harness=str(args.get("harness") or os.environ.get("DHEE_HARNESS") or _default_agent_id(args)), + top_k=_bounded_limit(args, "top_k", 4, 8), + record=args.get("record") if "record" in args else None, + ) + except Exception as exc: + return {"enabled": False, "status": "error", "reason": f"{type(exc).__name__}: {exc}"} + + +def _scene_evidence_from_args(memory, args: Dict[str, Any]) -> List[Dict[str, Any]]: + from dhee.temporal_scenes import collect_live_scene_sources, collect_scene_evidence + + sources = set(str(source) for source in (args.get("sources") or ["evidence"])) + if args.get("include_recent_memories"): + sources.add("memory") + if args.get("include_repo_context"): + sources.add("repo_context") + if args.get("include_session"): + sources.add("session") + if args.get("include_shared_task_results"): + sources.add("shared_task_results") + if args.get("include_artifacts"): + sources.add("artifacts") + if args.get("include_live_sources"): + sources.update({"session", "shared_task_results", "artifacts"}) + needs_live = bool(args.get("include_live_sources")) or any( + source in sources + for source in ("session", "session_digest", "shared_task_results", "shared_task", "artifacts", "artifact") + ) + live: Dict[str, Any] = {} + if needs_live: + live_db = None + if any(source in sources for source in ("shared_task_results", "shared_task", "artifacts", "artifact")): + live_db = get_db() + live = collect_live_scene_sources( + db=live_db, + repo=args.get("repo"), + user_id=_default_user_id(args), + agent_id=_default_agent_id(args), + limit=_bounded_limit(args, "limit", 8, 50), + include_session=("session" in sources or "session_digest" in sources) and not args.get("session"), + include_shared_task_results=("shared_task_results" in sources or "shared_task" in sources) and not args.get("shared_task_results"), + include_artifacts=("artifacts" in sources or "artifact" in sources) and not args.get("artifacts"), + ) + mem = memory or (get_memory() if "memory" in sources else None) + return collect_scene_evidence( + evidence=args.get("evidence") or [], + memory=mem, + query=str(args.get("query") or args.get("task") or ""), + user_id=_default_user_id(args), + repo=args.get("repo"), + session=args.get("session") or live.get("session"), + shared_task_results=args.get("shared_task_results") or live.get("shared_task_results"), + artifacts=args.get("artifacts") or live.get("artifacts"), + sources=sources, + limit=_bounded_limit(args, "limit", 8, 50), + ) + + +def _handle_dhee_scene_compile(memory, args): + from dhee.temporal_scenes import compile_scene + + evidence = _scene_evidence_from_args(memory, args) + if not evidence: + return {"error": "evidence is required unless include_recent_memories returns results"} + scene = compile_scene( + evidence, + user_id=_default_user_id(args), + repo=args.get("repo"), + task=str(args.get("task") or args.get("query") or ""), + privacy_scope=str(args.get("privacy_scope") or "personal"), + title=args.get("title"), + store_dir=args.get("store_dir"), + save=args.get("save") is not False, + ) + return { + "format": "dhee_scene_compile.v1", + "scene": scene.to_dict(), + "card": scene.to_card(), + } + + +def _handle_dhee_scene_search(_memory, args): + from dhee.temporal_scenes import search_scenes + + query = str(args.get("query") or "").strip() + if not query: + return {"error": "query is required"} + scenes = search_scenes( + query, + user_id=_default_user_id(args), + repo=args.get("repo"), + limit=_bounded_limit(args, "limit", 5, 30), + store_dir=args.get("store_dir"), + include_personal=args.get("include_personal") is not False, + ) + return { + "format": "dhee_scene_search.v1", + "results": [scene.to_card() for scene in scenes], + } + + +def _handle_dhee_context_pack(_memory, args): + from dhee.temporal_scenes import build_context_pack + + query = str(args.get("query") or "").strip() + if not query: + return {"error": "query is required"} + try: + budget = int(args.get("token_budget") or 1200) + except (TypeError, ValueError): + budget = 1200 + return build_context_pack( + query, + user_id=_default_user_id(args), + repo=args.get("repo"), + token_budget=max(128, min(20_000, budget)), + limit=_bounded_limit(args, "limit", 5, 30), + store_dir=args.get("store_dir"), + include_personal=args.get("include_personal") is not False, + ) + + +def _handle_dhee_task_contract_compile(_memory, args): + from dhee.task_contracts import compile_task_contract + + goal = str(args.get("goal") or args.get("task") or args.get("query") or "").strip() + if not goal: + return {"error": "goal, task, or query is required"} + return compile_task_contract( + goal, + repo=args.get("repo"), + mode=str(args.get("mode") or "patch"), + risk=args.get("risk"), + allowed_write_paths=args.get("allowed_write_paths"), + forbidden_paths=args.get("forbidden_paths"), + must_run=args.get("must_run"), + success_criteria=args.get("success_criteria"), + context_budget=args.get("context_budget"), + memory_pointers=args.get("memory_pointers"), + recent_failures=args.get("recent_failures"), + ) + + +def _task_goal_from_args(args: Dict[str, Any]) -> str: + return str(args.get("goal") or args.get("task") or args.get("query") or "").strip() + + +def _handle_dhee_task_contract_create(_memory, args): + from dhee.task_contracts import create_task_contract + + goal = _task_goal_from_args(args) + if not goal: + return {"error": "goal, task, or query is required"} + return create_task_contract( + goal, + repo=args.get("repo"), + out=args.get("out"), + mode=str(args.get("mode") or "patch"), + risk=args.get("risk"), + allowed_write_paths=args.get("allowed_write_paths"), + forbidden_paths=args.get("forbidden_paths"), + must_run=args.get("must_run"), + success_criteria=args.get("success_criteria"), + context_budget=args.get("context_budget"), + memory_pointers=args.get("memory_pointers"), + recent_failures=args.get("recent_failures"), + ) + + +def _handle_dhee_task_contract_list(_memory, args): + from dhee.task_contracts import list_task_contracts + + return { + "format": "dhee_task_contract_list.v1", + "results": list_task_contracts(repo=args.get("repo")), + } + + +def _handle_dhee_task_contract_get(_memory, args): + from dhee.task_contracts import get_task_contract + + task_id = str(args.get("task_id") or args.get("id") or "") + if not task_id: + return {"error": "task_id is required"} + return get_task_contract(task_id, repo=args.get("repo")) + + +def _handle_dhee_task_contract_import(_memory, args): + from dhee.task_contracts import import_task_contract + + path = str(args.get("path") or "") + if not path: + return {"error": "path is required"} + return import_task_contract(path, repo=args.get("repo")) + + +def _handle_dhee_task_contract_interpret(_memory, args): + from dhee.task_contracts import interpret_task_contract + + task_contract = args.get("contract") or args.get("path") or args.get("task_id") or args.get("id") + if not task_contract: + return {"error": "contract, path, or task_id is required"} + return interpret_task_contract( + task_contract, + repo=args.get("repo"), + strict=bool(args.get("strict") or False), + ) + + +def _contract_ref_from_args(args: Dict[str, Any]) -> Any: + return args.get("contract") or args.get("path") or args.get("task_id") or args.get("id") + + +def _handle_dhee_contract_supervise_action(_memory, args): + from dhee.contract_supervisor import supervise_action + + task_contract = _contract_ref_from_args(args) + action = args.get("action") or args.get("proposed_action") + if not task_contract: + return {"error": "contract, path, or task_id is required"} + if not isinstance(action, dict): + return {"error": "action or proposed_action object is required"} + return supervise_action( + task_contract, + action, + repo=args.get("repo"), + strict=bool(args.get("strict") or False), + ) + + +def _handle_dhee_contract_record_observation(_memory, args): + from dhee.contract_supervisor import record_observation_transition + + task_contract = _contract_ref_from_args(args) + action = args.get("action") + if not task_contract: + return {"error": "contract, path, or task_id is required"} + if not isinstance(action, dict): + return {"error": "action object is required"} + return record_observation_transition( + task_contract, + action, + args.get("observation") or "", + repo=args.get("repo"), + outcome=str(args.get("outcome") or "observed"), + next_action=args.get("next_action") if isinstance(args.get("next_action"), dict) else None, + strict=bool(args.get("strict") or False), + ) + + +def _handle_dhee_contract_proof_bundle(_memory, args): + from dhee.contract_supervisor import build_proof_bundle + + task_contract = _contract_ref_from_args(args) + if not task_contract: + return {"error": "contract, path, or task_id is required"} + persist = args.get("persist") + return build_proof_bundle( + task_contract, + repo=args.get("repo"), + strict=bool(args.get("strict") or False), + persist=True if persist is None else bool(persist), + ) + + +def _handle_dhee_contract_runtime_activate(_memory, args): + from dhee.contract_runtime import activate_contract_runtime + + task_contract = _contract_ref_from_args(args) + if not task_contract: + return {"error": "contract, path, or task_id is required"} + return activate_contract_runtime( + task_contract, + repo=args.get("repo"), + strict=bool(args.get("strict") or False), + force=bool(args.get("force") or False), + agent_id=args.get("agent_id"), + harness=args.get("harness"), + ) + + +def _handle_dhee_contract_runtime_status(_memory, args): + from dhee.contract_runtime import contract_runtime_status + + return contract_runtime_status(repo=args.get("repo")) + + +def _handle_dhee_contract_runtime_deactivate(_memory, args): + from dhee.contract_runtime import deactivate_contract_runtime + + return deactivate_contract_runtime( + repo=args.get("repo"), + agent_id=args.get("agent_id"), + reason=str(args.get("reason") or "manual"), + ) + + +def _handle_dhee_contract_enforcement_set(_memory, args): + from dhee.contract_runtime import set_contract_enforcement + + return set_contract_enforcement( + str(args.get("mode") or ""), + repo=args.get("repo"), + agent_id=args.get("agent_id"), + reason=args.get("reason"), + ) + + +def _handle_dhee_contract_enforcement_status(_memory, args): + from dhee.contract_runtime import contract_enforcement_status + + return contract_enforcement_status(repo=args.get("repo")) + + +def _handle_dhee_contract_runtime_doctor(_memory, args): + from dhee.contract_runtime import contract_runtime_doctor + + return contract_runtime_doctor(repo=args.get("repo")) + + +def _handle_dhee_update_capsule_create(_memory, args): + from dhee.update_capsules import create_update_capsule + + return create_update_capsule( + repo=args.get("repo"), + since=args.get("since"), + task_id=args.get("task_id"), + out=args.get("out"), + title=args.get("title"), + summary=args.get("summary"), + commands=args.get("commands"), + evidence=args.get("evidence"), + ) + + +def _handle_dhee_update_capsule_list(_memory, args): + from dhee.update_capsules import list_update_capsules + + return { + "format": "dhee_update_capsule_list.v1", + "results": list_update_capsules(repo=args.get("repo")), + } + + +def _handle_dhee_update_capsule_get(_memory, args): + from dhee.update_capsules import get_update_capsule + + capsule_id = str(args.get("capsule_id") or args.get("id") or "") + if not capsule_id: + return {"error": "capsule_id is required"} + return get_update_capsule(capsule_id, repo=args.get("repo")) + + +def _handle_dhee_update_capsule_import(_memory, args): + from dhee.update_capsules import import_update_capsule + + path = str(args.get("path") or "") + if not path: + return {"error": "path is required"} + return import_update_capsule( + path, + repo=args.get("repo"), + allow_private=bool(args.get("allow_private") or False), + ) + + +def _handle_dhee_update_capsule_interpret(_memory, args): + from dhee.update_capsules import interpret_update_capsule + + capsule = args.get("capsule") or args.get("path") or args.get("capsule_id") or args.get("id") + if not capsule: + return {"error": "capsule, path, or capsule_id is required"} + return interpret_update_capsule( + capsule, + repo=args.get("repo"), + strict=bool(args.get("strict") or False), + ) + + def _handle_get_last_session(_memory, args): from dhee.core.kernel import get_last_session session = get_last_session( @@ -1744,6 +2218,23 @@ def _handle_dhee_tools_list(_memory, _arguments: Dict[str, Any]) -> Dict[str, An "dhee_context_checkpoint", "dhee_context_rollover", "dhee_context_provision", + "dhee_scene_world_route", + "dhee_scene_compile", + "dhee_scene_search", + "dhee_context_pack", + "dhee_task_contract_compile", + "dhee_task_contract_create", + "dhee_task_contract_list", + "dhee_task_contract_get", + "dhee_task_contract_import", + "dhee_task_contract_interpret", + "dhee_contract_supervise_action", + "dhee_contract_record_observation", + "dhee_update_capsule_create", + "dhee_update_capsule_list", + "dhee_update_capsule_get", + "dhee_update_capsule_import", + "dhee_update_capsule_interpret", "dhee_shell", "dhee_read", "dhee_grep", @@ -2247,6 +2738,30 @@ def _handle_dhee_expand_result(_memory, arguments: Dict[str, Any]) -> Dict[str, "get_memory": _handle_get_memory, "get_all_memories": _handle_get_all_memories, "dhee_context": _handle_dhee_context, + "dhee_scene_world_route": _handle_dhee_scene_world_route, + "dhee_scene_compile": _handle_dhee_scene_compile, + "dhee_scene_search": _handle_dhee_scene_search, + "dhee_context_pack": _handle_dhee_context_pack, + "dhee_task_contract_compile": _handle_dhee_task_contract_compile, + "dhee_task_contract_create": _handle_dhee_task_contract_create, + "dhee_task_contract_list": _handle_dhee_task_contract_list, + "dhee_task_contract_get": _handle_dhee_task_contract_get, + "dhee_task_contract_import": _handle_dhee_task_contract_import, + "dhee_task_contract_interpret": _handle_dhee_task_contract_interpret, + "dhee_contract_supervise_action": _handle_dhee_contract_supervise_action, + "dhee_contract_record_observation": _handle_dhee_contract_record_observation, + "dhee_contract_proof_bundle": _handle_dhee_contract_proof_bundle, + "dhee_contract_runtime_activate": _handle_dhee_contract_runtime_activate, + "dhee_contract_runtime_status": _handle_dhee_contract_runtime_status, + "dhee_contract_runtime_deactivate": _handle_dhee_contract_runtime_deactivate, + "dhee_contract_enforcement_set": _handle_dhee_contract_enforcement_set, + "dhee_contract_enforcement_status": _handle_dhee_contract_enforcement_status, + "dhee_contract_runtime_doctor": _handle_dhee_contract_runtime_doctor, + "dhee_update_capsule_create": _handle_dhee_update_capsule_create, + "dhee_update_capsule_list": _handle_dhee_update_capsule_list, + "dhee_update_capsule_get": _handle_dhee_update_capsule_get, + "dhee_update_capsule_import": _handle_dhee_update_capsule_import, + "dhee_update_capsule_interpret": _handle_dhee_update_capsule_interpret, "get_last_session": _handle_get_last_session, "save_session_digest": _handle_save_session_digest, "get_memory_stats": _handle_get_memory_stats, @@ -2296,6 +2811,8 @@ def _handle_dhee_expand_result(_memory, arguments: Dict[str, Any]) -> Dict[str, _MEMORY_FREE_TOOLS = { "get_last_session", "save_session_digest", "record_outcome", "reflect", "store_intention", + "dhee_scene_world_route", "dhee_scene_compile", "dhee_scene_search", "dhee_context_pack", + *CONTEXT_COMPILER_TOOL_NAMES, "dhee_submit_learning", "dhee_search_learnings", "dhee_promote_learning", "dhee_context_status", "dhee_context_state", "dhee_context_checkpoint", "dhee_context_rollover", "dhee_context_provision", "dhee_tools_list", "dhee_shell", "dhee_list_assets", "dhee_get_asset", "dhee_sync_codex_artifacts", "dhee_why", "dhee_thread_state", "dhee_shared_task", "dhee_shared_task_results", "dhee_inbox", "dhee_broadcast", "dhee_handoff", @@ -2303,6 +2820,7 @@ def _handle_dhee_expand_result(_memory, arguments: Dict[str, Any]) -> Dict[str, } + @server.list_tools() async def list_tools() -> List[Tool]: return list(TOOLS) diff --git a/dhee/mcp_slim.py b/dhee/mcp_slim.py index 2be825a..579e8d2 100644 --- a/dhee/mcp_slim.py +++ b/dhee/mcp_slim.py @@ -73,6 +73,8 @@ def stdio_server(): # type: ignore[no-redef] logger = logging.getLogger(__name__) +from dhee.mcp_registry import CONTEXT_COMPILER_TOOL_NAMES, make_tools + _MCP_CONTEXT_FIRST_INSTRUCTIONS = ( "Dhee is the native memory and context-router. At the start of substantive " "repo/workspace tasks, use Dhee context/recall before reconstructing from " @@ -124,6 +126,10 @@ def _get_db(): return _get_plugin().memory.db +def _default_user_id(args: Dict[str, Any]) -> str: + return str(args.get("user_id") or os.environ.get("DHEE_USER_ID") or "default") + + def _default_agent_id(args: Dict[str, Any]) -> str: return str(args.get("agent_id") or os.environ.get("DHEE_AGENT_ID") or "agent") @@ -216,6 +222,89 @@ def _default_agent_id(args: Dict[str, Any]) -> str: }, }, ), + Tool( + name="dhee_scene_world_route", + description=( + "Predict likely outcomes for candidate next actions using the optional " + "SceneWorld world-model sidecar. Use before choosing a high-stakes " + "agent action when DHEE_SCENE_WORLD_ENABLED=1." + ), + inputSchema={ + "type": "object", + "properties": { + "task": {"type": "string", "description": "Current task or scene"}, + "query": {"type": "string", "description": "Alias for task"}, + "repo": {"type": "string", "description": "Repo/workspace path"}, + "user_id": {"type": "string", "description": "User identifier"}, + "harness": {"type": "string", "description": "Harness/runtime id"}, + "top_k": {"type": "integer", "description": "Number of ranked actions to return"}, + "record": {"type": "boolean", "description": "Record the route trace when route logging is configured"}, + }, + }, + ), + Tool( + name="dhee_scene_compile", + description="Compile a private TemporalScene card from evidence pointers or admitted derivatives.", + inputSchema={ + "type": "object", + "properties": { + "evidence": {"type": "array", "items": {"type": "object"}}, + "query": {"type": "string"}, + "task": {"type": "string"}, + "title": {"type": "string"}, + "repo": {"type": "string"}, + "user_id": {"type": "string"}, + "privacy_scope": {"type": "string"}, + "store_dir": {"type": "string"}, + "save": {"type": "boolean"}, + "include_recent_memories": {"type": "boolean"}, + "include_repo_context": {"type": "boolean"}, + "include_session": {"type": "boolean"}, + "include_shared_task_results": {"type": "boolean"}, + "include_artifacts": {"type": "boolean"}, + "include_live_sources": {"type": "boolean"}, + "sources": {"type": "array", "items": {"type": "string"}}, + "session": {"type": "object"}, + "shared_task_results": {}, + "artifacts": {}, + "limit": {"type": "integer"}, + }, + }, + ), + Tool( + name="dhee_scene_search", + description="Search private TemporalScene cards and return prompt-safe summaries with evidence refs only.", + inputSchema={ + "type": "object", + "properties": { + "query": {"type": "string"}, + "user_id": {"type": "string"}, + "repo": {"type": "string"}, + "limit": {"type": "integer"}, + "store_dir": {"type": "string"}, + "include_personal": {"type": "boolean"}, + }, + "required": ["query"], + }, + ), + Tool( + name="dhee_context_pack", + description="Build a hard-budget context pack from scene cards. Raw evidence expands only by pointer.", + inputSchema={ + "type": "object", + "properties": { + "query": {"type": "string"}, + "user_id": {"type": "string"}, + "repo": {"type": "string"}, + "token_budget": {"type": "integer"}, + "limit": {"type": "integer"}, + "store_dir": {"type": "string"}, + "include_personal": {"type": "boolean"}, + }, + "required": ["query"], + }, + ), + *make_tools(Tool, CONTEXT_COMPILER_TOOL_NAMES), Tool( name="dhee_submit_learning", description="Submit an auditable learning candidate. Candidates are not injected until promoted.", @@ -1016,6 +1105,404 @@ def _handle_dhee_context_provision(args: Dict[str, Any]) -> Dict[str, Any]: return _runtime_context(args, "provision", {"task": task}) or _context_store(args).provision(task) +def _handle_dhee_scene_world_route(args: Dict[str, Any]) -> Dict[str, Any]: + task = str(args.get("task") or args.get("query") or "") + if not task.strip(): + return {"error": "task is required"} + try: + from dhee.hooks.scene_world import route_task + + return route_task( + task, + repo=args.get("repo"), + user_id=args.get("user_id", "default"), + harness=str(args.get("harness") or os.environ.get("DHEE_HARNESS") or _default_agent_id(args)), + top_k=_bounded_limit(args, "top_k", 4, 8), + record=args.get("record") if "record" in args else None, + ) + except Exception as exc: + return {"enabled": False, "status": "error", "reason": f"{type(exc).__name__}: {exc}"} + + +def _scene_evidence_from_args(args: Dict[str, Any]) -> List[Dict[str, Any]]: + from dhee.temporal_scenes import collect_live_scene_sources, collect_scene_evidence + + sources = set(str(source) for source in (args.get("sources") or ["evidence"])) + if args.get("include_recent_memories"): + sources.add("memory") + if args.get("include_repo_context"): + sources.add("repo_context") + if args.get("include_session"): + sources.add("session") + if args.get("include_shared_task_results"): + sources.add("shared_task_results") + if args.get("include_artifacts"): + sources.add("artifacts") + if args.get("include_live_sources"): + sources.update({"session", "shared_task_results", "artifacts"}) + needs_live = bool(args.get("include_live_sources")) or any( + source in sources + for source in ("session", "session_digest", "shared_task_results", "shared_task", "artifacts", "artifact") + ) + live: Dict[str, Any] = {} + if needs_live: + live_db = None + if any(source in sources for source in ("shared_task_results", "shared_task", "artifacts", "artifact")): + try: + live_db = _get_db() + except Exception: + live_db = None + live = collect_live_scene_sources( + db=live_db, + repo=args.get("repo"), + user_id=_default_user_id(args), + agent_id=_default_agent_id(args), + limit=_bounded_limit(args, "limit", 8, 50), + include_session=("session" in sources or "session_digest" in sources) and not args.get("session"), + include_shared_task_results=("shared_task_results" in sources or "shared_task" in sources) and not args.get("shared_task_results"), + include_artifacts=("artifacts" in sources or "artifact" in sources) and not args.get("artifacts"), + ) + memory = None + if "memory" in sources: + try: + memory = _get_plugin().memory + except Exception: + memory = None + return collect_scene_evidence( + evidence=args.get("evidence") or [], + memory=memory, + query=str(args.get("query") or args.get("task") or ""), + user_id=_default_user_id(args), + repo=args.get("repo"), + session=args.get("session") or live.get("session"), + shared_task_results=args.get("shared_task_results") or live.get("shared_task_results"), + artifacts=args.get("artifacts") or live.get("artifacts"), + sources=sources, + limit=_bounded_limit(args, "limit", 8, 50), + ) + + +def _handle_dhee_scene_compile(args: Dict[str, Any]) -> Dict[str, Any]: + from dhee.temporal_scenes import compile_scene + + evidence = _scene_evidence_from_args(args) + if not evidence: + return {"error": "evidence is required unless include_recent_memories returns results"} + scene = compile_scene( + evidence, + user_id=_default_user_id(args), + repo=args.get("repo"), + task=str(args.get("task") or args.get("query") or ""), + privacy_scope=str(args.get("privacy_scope") or "personal"), + title=args.get("title"), + store_dir=args.get("store_dir"), + save=args.get("save") is not False, + ) + return { + "format": "dhee_scene_compile.v1", + "scene": scene.to_dict(), + "card": scene.to_card(), + } + + +def _handle_dhee_scene_search(args: Dict[str, Any]) -> Dict[str, Any]: + from dhee.temporal_scenes import search_scenes + + query = str(args.get("query") or "").strip() + if not query: + return {"error": "query is required"} + scenes = search_scenes( + query, + user_id=_default_user_id(args), + repo=args.get("repo"), + limit=_bounded_limit(args, "limit", 5, 30), + store_dir=args.get("store_dir"), + include_personal=args.get("include_personal") is not False, + ) + return { + "format": "dhee_scene_search.v1", + "results": [scene.to_card() for scene in scenes], + } + + +def _handle_dhee_context_pack(args: Dict[str, Any]) -> Dict[str, Any]: + from dhee.temporal_scenes import build_context_pack + + query = str(args.get("query") or "").strip() + if not query: + return {"error": "query is required"} + try: + budget = int(args.get("token_budget") or 1200) + except (TypeError, ValueError): + budget = 1200 + return build_context_pack( + query, + user_id=_default_user_id(args), + repo=args.get("repo"), + token_budget=max(128, min(20_000, budget)), + limit=_bounded_limit(args, "limit", 5, 30), + store_dir=args.get("store_dir"), + include_personal=args.get("include_personal") is not False, + ) + + +def _handle_dhee_task_contract_compile(args: Dict[str, Any]) -> Dict[str, Any]: + from dhee.task_contracts import compile_task_contract + + goal = str(args.get("goal") or args.get("task") or args.get("query") or "").strip() + if not goal: + return {"error": "goal, task, or query is required"} + return compile_task_contract( + goal, + repo=args.get("repo"), + mode=str(args.get("mode") or "patch"), + risk=args.get("risk"), + allowed_write_paths=args.get("allowed_write_paths"), + forbidden_paths=args.get("forbidden_paths"), + must_run=args.get("must_run"), + success_criteria=args.get("success_criteria"), + context_budget=args.get("context_budget"), + memory_pointers=args.get("memory_pointers"), + recent_failures=args.get("recent_failures"), + ) + + +def _task_goal_from_args(args: Dict[str, Any]) -> str: + return str(args.get("goal") or args.get("task") or args.get("query") or "").strip() + + +def _handle_dhee_task_contract_create(args: Dict[str, Any]) -> Dict[str, Any]: + from dhee.task_contracts import create_task_contract + + goal = _task_goal_from_args(args) + if not goal: + return {"error": "goal, task, or query is required"} + return create_task_contract( + goal, + repo=args.get("repo"), + out=args.get("out"), + mode=str(args.get("mode") or "patch"), + risk=args.get("risk"), + allowed_write_paths=args.get("allowed_write_paths"), + forbidden_paths=args.get("forbidden_paths"), + must_run=args.get("must_run"), + success_criteria=args.get("success_criteria"), + context_budget=args.get("context_budget"), + memory_pointers=args.get("memory_pointers"), + recent_failures=args.get("recent_failures"), + ) + + +def _handle_dhee_task_contract_list(args: Dict[str, Any]) -> Dict[str, Any]: + from dhee.task_contracts import list_task_contracts + + return { + "format": "dhee_task_contract_list.v1", + "results": list_task_contracts(repo=args.get("repo")), + } + + +def _handle_dhee_task_contract_get(args: Dict[str, Any]) -> Dict[str, Any]: + from dhee.task_contracts import get_task_contract + + task_id = str(args.get("task_id") or args.get("id") or "") + if not task_id: + return {"error": "task_id is required"} + return get_task_contract(task_id, repo=args.get("repo")) + + +def _handle_dhee_task_contract_import(args: Dict[str, Any]) -> Dict[str, Any]: + from dhee.task_contracts import import_task_contract + + path = str(args.get("path") or "") + if not path: + return {"error": "path is required"} + return import_task_contract(path, repo=args.get("repo")) + + +def _handle_dhee_task_contract_interpret(args: Dict[str, Any]) -> Dict[str, Any]: + from dhee.task_contracts import interpret_task_contract + + task_contract = args.get("contract") or args.get("path") or args.get("task_id") or args.get("id") + if not task_contract: + return {"error": "contract, path, or task_id is required"} + return interpret_task_contract( + task_contract, + repo=args.get("repo"), + strict=bool(args.get("strict") or False), + ) + + +def _contract_ref_from_args(args: Dict[str, Any]) -> Any: + return args.get("contract") or args.get("path") or args.get("task_id") or args.get("id") + + +def _handle_dhee_contract_supervise_action(args: Dict[str, Any]) -> Dict[str, Any]: + from dhee.contract_supervisor import supervise_action + + task_contract = _contract_ref_from_args(args) + action = args.get("action") or args.get("proposed_action") + if not task_contract: + return {"error": "contract, path, or task_id is required"} + if not isinstance(action, dict): + return {"error": "action or proposed_action object is required"} + return supervise_action( + task_contract, + action, + repo=args.get("repo"), + strict=bool(args.get("strict") or False), + ) + + +def _handle_dhee_contract_record_observation(args: Dict[str, Any]) -> Dict[str, Any]: + from dhee.contract_supervisor import record_observation_transition + + task_contract = _contract_ref_from_args(args) + action = args.get("action") + if not task_contract: + return {"error": "contract, path, or task_id is required"} + if not isinstance(action, dict): + return {"error": "action object is required"} + return record_observation_transition( + task_contract, + action, + args.get("observation") or "", + repo=args.get("repo"), + outcome=str(args.get("outcome") or "observed"), + next_action=args.get("next_action") if isinstance(args.get("next_action"), dict) else None, + strict=bool(args.get("strict") or False), + ) + + +def _handle_dhee_contract_proof_bundle(args: Dict[str, Any]) -> Dict[str, Any]: + from dhee.contract_supervisor import build_proof_bundle + + task_contract = _contract_ref_from_args(args) + if not task_contract: + return {"error": "contract, path, or task_id is required"} + persist = args.get("persist") + return build_proof_bundle( + task_contract, + repo=args.get("repo"), + strict=bool(args.get("strict") or False), + persist=True if persist is None else bool(persist), + ) + + +def _handle_dhee_contract_runtime_activate(args: Dict[str, Any]) -> Dict[str, Any]: + from dhee.contract_runtime import activate_contract_runtime + + task_contract = _contract_ref_from_args(args) + if not task_contract: + return {"error": "contract, path, or task_id is required"} + return activate_contract_runtime( + task_contract, + repo=args.get("repo"), + strict=bool(args.get("strict") or False), + force=bool(args.get("force") or False), + agent_id=args.get("agent_id"), + harness=args.get("harness"), + ) + + +def _handle_dhee_contract_runtime_status(args: Dict[str, Any]) -> Dict[str, Any]: + from dhee.contract_runtime import contract_runtime_status + + return contract_runtime_status(repo=args.get("repo")) + + +def _handle_dhee_contract_runtime_deactivate(args: Dict[str, Any]) -> Dict[str, Any]: + from dhee.contract_runtime import deactivate_contract_runtime + + return deactivate_contract_runtime( + repo=args.get("repo"), + agent_id=args.get("agent_id"), + reason=str(args.get("reason") or "manual"), + ) + + +def _handle_dhee_contract_enforcement_set(args: Dict[str, Any]) -> Dict[str, Any]: + from dhee.contract_runtime import set_contract_enforcement + + return set_contract_enforcement( + str(args.get("mode") or ""), + repo=args.get("repo"), + agent_id=args.get("agent_id"), + reason=args.get("reason"), + ) + + +def _handle_dhee_contract_enforcement_status(args: Dict[str, Any]) -> Dict[str, Any]: + from dhee.contract_runtime import contract_enforcement_status + + return contract_enforcement_status(repo=args.get("repo")) + + +def _handle_dhee_contract_runtime_doctor(args: Dict[str, Any]) -> Dict[str, Any]: + from dhee.contract_runtime import contract_runtime_doctor + + return contract_runtime_doctor(repo=args.get("repo")) + + +def _handle_dhee_update_capsule_create(args: Dict[str, Any]) -> Dict[str, Any]: + from dhee.update_capsules import create_update_capsule + + return create_update_capsule( + repo=args.get("repo"), + since=args.get("since"), + task_id=args.get("task_id"), + out=args.get("out"), + title=args.get("title"), + summary=args.get("summary"), + commands=args.get("commands"), + evidence=args.get("evidence"), + ) + + +def _handle_dhee_update_capsule_list(args: Dict[str, Any]) -> Dict[str, Any]: + from dhee.update_capsules import list_update_capsules + + return { + "format": "dhee_update_capsule_list.v1", + "results": list_update_capsules(repo=args.get("repo")), + } + + +def _handle_dhee_update_capsule_get(args: Dict[str, Any]) -> Dict[str, Any]: + from dhee.update_capsules import get_update_capsule + + capsule_id = str(args.get("capsule_id") or args.get("id") or "") + if not capsule_id: + return {"error": "capsule_id is required"} + return get_update_capsule(capsule_id, repo=args.get("repo")) + + +def _handle_dhee_update_capsule_import(args: Dict[str, Any]) -> Dict[str, Any]: + from dhee.update_capsules import import_update_capsule + + path = str(args.get("path") or "") + if not path: + return {"error": "path is required"} + return import_update_capsule( + path, + repo=args.get("repo"), + allow_private=bool(args.get("allow_private") or False), + ) + + +def _handle_dhee_update_capsule_interpret(args: Dict[str, Any]) -> Dict[str, Any]: + from dhee.update_capsules import interpret_update_capsule + + capsule = args.get("capsule") or args.get("path") or args.get("capsule_id") or args.get("id") + if not capsule: + return {"error": "capsule, path, or capsule_id is required"} + return interpret_update_capsule( + capsule, + repo=args.get("repo"), + strict=bool(args.get("strict") or False), + ) + + def _handle_dhee_tools_list(_args: Dict[str, Any]) -> Dict[str, Any]: default_tools = [tool.name for tool in TOOLS] advanced_tools = [ @@ -1211,6 +1698,30 @@ def _handle_dhee_handoff(args: Dict[str, Any]) -> Dict[str, Any]: "dhee_context_checkpoint": _handle_dhee_context_checkpoint, "dhee_context_rollover": _handle_dhee_context_rollover, "dhee_context_provision": _handle_dhee_context_provision, + "dhee_scene_world_route": _handle_dhee_scene_world_route, + "dhee_scene_compile": _handle_dhee_scene_compile, + "dhee_scene_search": _handle_dhee_scene_search, + "dhee_context_pack": _handle_dhee_context_pack, + "dhee_task_contract_compile": _handle_dhee_task_contract_compile, + "dhee_task_contract_create": _handle_dhee_task_contract_create, + "dhee_task_contract_list": _handle_dhee_task_contract_list, + "dhee_task_contract_get": _handle_dhee_task_contract_get, + "dhee_task_contract_import": _handle_dhee_task_contract_import, + "dhee_task_contract_interpret": _handle_dhee_task_contract_interpret, + "dhee_contract_supervise_action": _handle_dhee_contract_supervise_action, + "dhee_contract_record_observation": _handle_dhee_contract_record_observation, + "dhee_contract_proof_bundle": _handle_dhee_contract_proof_bundle, + "dhee_contract_runtime_activate": _handle_dhee_contract_runtime_activate, + "dhee_contract_runtime_status": _handle_dhee_contract_runtime_status, + "dhee_contract_runtime_deactivate": _handle_dhee_contract_runtime_deactivate, + "dhee_contract_enforcement_set": _handle_dhee_contract_enforcement_set, + "dhee_contract_enforcement_status": _handle_dhee_contract_enforcement_status, + "dhee_contract_runtime_doctor": _handle_dhee_contract_runtime_doctor, + "dhee_update_capsule_create": _handle_dhee_update_capsule_create, + "dhee_update_capsule_list": _handle_dhee_update_capsule_list, + "dhee_update_capsule_get": _handle_dhee_update_capsule_get, + "dhee_update_capsule_import": _handle_dhee_update_capsule_import, + "dhee_update_capsule_interpret": _handle_dhee_update_capsule_interpret, "dhee_tools_list": _handle_dhee_tools_list, "dhee_shell": _handle_dhee_shell, "dhee_inbox": _handle_dhee_inbox, @@ -1229,6 +1740,7 @@ def _handle_dhee_handoff(args: Dict[str, Any]) -> Dict[str, Any]: # MCP Protocol # --------------------------------------------------------------------------- + @server.list_tools() async def list_tools() -> List[Tool]: return list(TOOLS) diff --git a/dhee/memory/admission.py b/dhee/memory/admission.py new file mode 100644 index 0000000..91c5a3f --- /dev/null +++ b/dhee/memory/admission.py @@ -0,0 +1,519 @@ +"""Memory admission policy for passive agent observations. + +Dhee is the memory layer, so agents should be able to submit rich candidates +without each agent re-implementing quality, retention, and forgetting rules. +This module keeps the hot path deterministic and local: no LLM calls, no network, +and no screenshots/raw media stored here. +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from datetime import date, timedelta +from typing import Any, Dict, List, Optional + + +PASSIVE_SOURCES = { + "chotu_screen_memory", + "screen_memory", + "screen_activity", + "screen_observation", + "passive_observation", + "agent_observation", + "macos_active_window", + "desktop_observer", +} + +PASSIVE_TYPES = { + "screen_activity", + "screen_observation", + "interest_signal", + "passive_observation", + "observation", +} + +IGNORED_APPS = { + "Control Center", + "Dock", + "Notification Center", + "SystemUIServer", + "UserNotificationCenter", + "Window Server", + "loginwindow", +} + +IGNORED_BUNDLES = { + "com.apple.controlcenter", + "com.apple.dock", + "com.apple.UserNotificationCenter", + "com.apple.notificationcenterui", + "com.apple.loginwindow", + "com.apple.systemuiserver", +} + +GENERIC_TITLES = { + "arc", + "chrome", + "codex", + "google chrome", + "new tab", + "safari", + "youtube", +} + +HIGH_CHURN_MARKERS = { + "claude", + "codex", + "com.microsoft.vscode", + "cursor", + "visual studio code", + "windsurf", +} + +INTEREST_MARKERS = { + "chatgpt", + "claude", + "codex", + "course", + "github", + "tutorial", + "video", + "youtu.be", + "youtube", +} + +COMMON_WORDS = { + "about", + "after", + "again", + "agent", + "also", + "answer", + "are", + "because", + "building", + "can", + "chat", + "check", + "code", + "context", + "currently", + "data", + "dhee", + "doing", + "file", + "for", + "from", + "github", + "have", + "how", + "memory", + "model", + "new", + "not", + "now", + "open", + "page", + "repo", + "screen", + "search", + "should", + "that", + "the", + "this", + "use", + "user", + "what", + "when", + "with", + "working", + "youtube", +} + + +@dataclass(frozen=True) +class MemoryAdmissionDecision: + applies: bool + should_store: bool + retention_policy: str + confidence: float + score: float + ocr_quality: float + reasons: List[str] + promotion_reason: str + skip_reason: Optional[str] = None + include_ocr_excerpt: bool = True + + def to_metadata(self) -> Dict[str, Any]: + return { + "applies": self.applies, + "should_store": self.should_store, + "retention_policy": self.retention_policy, + "confidence": self.confidence, + "score": self.score, + "ocr_quality": self.ocr_quality, + "reasons": list(self.reasons), + "promotion_reason": self.promotion_reason, + "skip_reason": self.skip_reason, + "include_ocr_excerpt": self.include_ocr_excerpt, + } + + +def evaluate_memory_candidate( + content: str, + metadata: Optional[Dict[str, Any]] = None, + *, + explicit_remember: bool = False, +) -> MemoryAdmissionDecision: + """Decide whether a memory candidate should enter Dhee. + + Explicit user memories pass through. Passive observations from screen, + browser, or wearable agents get admission-scored so Dhee stores useful + semantic context and rejects transient UI/OCR noise. + """ + + metadata = metadata or {} + if explicit_remember or not _should_apply_admission(content, metadata): + return _decision( + applies=False, + should_store=True, + retention_policy=str(metadata.get("retention_policy") or "durable"), + confidence=_coerce_float(metadata.get("confidence"), 1.0), + score=1.0, + ocr_quality=1.0, + promotion_reason="explicit_or_non_passive", + reasons=["bypass"], + ) + + evidence = _evidence(metadata) + app = _first_str(evidence.get("app"), metadata.get("app"), metadata.get("source_app")) + bundle = _first_str(evidence.get("bundle_id"), metadata.get("bundle_id")) + title = _first_str(evidence.get("title"), metadata.get("title")) + dwell_seconds = _coerce_int(evidence.get("dwell_seconds"), metadata.get("dwell_seconds"), 0) + ocr_text = _first_str( + metadata.get("ocr_text"), + evidence.get("ocr_text"), + _visible_text_from_content(content), + ) + has_vision = bool( + metadata.get("vision_summary") + or evidence.get("vision_summary_sha256") + or evidence.get("screen_image_available") + ) + + if app in IGNORED_APPS or bundle in IGNORED_BUNDLES: + return _decision( + applies=True, + should_store=False, + retention_policy="ephemeral", + confidence=0.0, + score=0.0, + ocr_quality=0.0, + promotion_reason="ignored_system_surface", + skip_reason="ignored_system_surface", + reasons=["ignored_app"], + include_ocr_excerpt=False, + ) + + ocr_quality = _ocr_quality_score(ocr_text) + common_hits = _ocr_common_hits(ocr_text) + title_quality = _title_quality(title, app) + interest = _looks_like_interest(app, title, ocr_text, metadata) + high_churn = _is_high_churn(app, bundle, title) + has_specific_context = title_quality >= 0.55 or ocr_quality >= 0.55 or has_vision + + score = 0.18 + score += min(dwell_seconds / 180, 1.0) * 0.22 + score += title_quality * 0.24 + score += ocr_quality * 0.30 + if has_vision: + score += 0.22 + if interest: + score += 0.08 + if high_churn and title_quality < 0.45 and not has_vision: + score -= 0.14 + score = round(max(0.0, min(score, 1.0)), 3) + + reasons: List[str] = [] + if title_quality >= 0.55: + reasons.append("specific_title") + if ocr_quality >= 0.55: + reasons.append("readable_ocr") + if has_vision: + reasons.append("vision_summary") + if dwell_seconds >= 120: + reasons.append("long_dwell") + elif dwell_seconds >= 30: + reasons.append("meaningful_dwell") + if interest: + reasons.append("interest_signal") + + if high_churn and title_quality < 0.45 and not has_vision and dwell_seconds < 60 and common_hits < 2: + return _decision( + applies=True, + should_store=False, + retention_policy="ephemeral", + confidence=0.0, + score=score, + ocr_quality=ocr_quality, + promotion_reason="high_churn_ocr_noise", + skip_reason="low_quality_signal", + reasons=reasons, + include_ocr_excerpt=False, + ) + if not has_specific_context and dwell_seconds < 30: + return _decision( + applies=True, + should_store=False, + retention_policy="ephemeral", + confidence=0.0, + score=score, + ocr_quality=ocr_quality, + promotion_reason="no_specific_context", + skip_reason="low_quality_signal", + reasons=reasons, + include_ocr_excerpt=False, + ) + if ocr_text and ocr_quality < 0.22 and title_quality < 0.45 and not has_vision: + return _decision( + applies=True, + should_store=False, + retention_policy="ephemeral", + confidence=0.0, + score=score, + ocr_quality=ocr_quality, + promotion_reason="ocr_noise", + skip_reason="low_ocr_quality", + reasons=reasons, + include_ocr_excerpt=False, + ) + + durable = ( + (has_vision and dwell_seconds >= 30) + or (dwell_seconds >= 180 and score >= 0.52) + or (interest and dwell_seconds >= 60 and score >= 0.58) + or (ocr_quality >= 0.74 and dwell_seconds >= 90) + ) + retention_policy = "durable" if durable else "session" + confidence = round(max(0.45, min(0.94, 0.42 + score * 0.48)), 3) + return _decision( + applies=True, + should_store=True, + retention_policy=retention_policy, + confidence=confidence, + score=score, + ocr_quality=ocr_quality, + promotion_reason="durable_quality_gate" if durable else "session_until_promoted", + reasons=reasons, + include_ocr_excerpt=ocr_quality >= 0.38 and not has_vision, + ) + + +def admission_expiration_date(retention_policy: str) -> Optional[str]: + """Return a coarse ISO date for non-durable admission retention.""" + + policy = (retention_policy or "").lower() + if policy == "ephemeral": + return (date.today() + timedelta(days=1)).isoformat() + if policy == "session": + return (date.today() + timedelta(days=7)).isoformat() + if policy == "short": + return (date.today() + timedelta(days=30)).isoformat() + return None + + +def forget_reason_for_memory(memory: Dict[str, Any]) -> Optional[str]: + """Return a reason when an existing memory should be forgotten.""" + + metadata = memory.get("metadata") if isinstance(memory.get("metadata"), dict) else {} + if not _should_apply_admission(str(memory.get("memory") or ""), metadata): + return None + admission = metadata.get("dhee_admission") if isinstance(metadata, dict) else None + if isinstance(admission, dict): + try: + score = float(admission.get("score")) + except (TypeError, ValueError): + score = 1.0 + if admission.get("should_store") is False: + return f"admission:{admission.get('skip_reason') or 'rejected'}" + if score < 0.25: + return "admission:low_quality" + return None + + decision = evaluate_memory_candidate( + str(memory.get("memory") or ""), + metadata, + explicit_remember=False, + ) + if decision.applies and not decision.should_store: + return f"admission:{decision.skip_reason or 'rejected'}" + return None + + +def sanitize_admitted_content(content: str, decision: MemoryAdmissionDecision) -> str: + """Trim noisy raw OCR from admitted passive memories when Dhee does not need it.""" + + if not decision.applies or decision.include_ocr_excerpt: + return content + return _strip_visible_text(content) + + +def _should_apply_admission(content: str, metadata: Dict[str, Any]) -> bool: + if metadata.get("admission") is False or metadata.get("dhee_admission_bypass"): + return False + evidence = _evidence(metadata) + source = str(metadata.get("source") or metadata.get("source_app") or "").strip().lower() + mem_type = str(metadata.get("type") or metadata.get("memory_type") or "").strip().lower() + kind = str(evidence.get("kind") or "").strip().lower() + if source in PASSIVE_SOURCES or mem_type in PASSIVE_TYPES or kind in PASSIVE_SOURCES: + return True + if any(key in evidence for key in ("app", "bundle_id", "title", "dwell_seconds", "ocr_text_sha256")): + return True + lowered = content.lower() + return "visible screen activity" in lowered or "active screen" in lowered + + +def _decision(**kwargs: Any) -> MemoryAdmissionDecision: + return MemoryAdmissionDecision(**kwargs) + + +def _evidence(metadata: Dict[str, Any]) -> Dict[str, Any]: + evidence = metadata.get("evidence") + return evidence if isinstance(evidence, dict) else {} + + +def _first_str(*values: Any) -> str: + for value in values: + if value is None: + continue + text = str(value).strip() + if text: + return text + return "" + + +def _coerce_float(value: Any, default: float) -> float: + try: + return float(value) + except (TypeError, ValueError): + return default + + +def _coerce_int(*values: Any) -> int: + default = int(values[-1]) if values else 0 + for value in values[:-1]: + try: + return int(value) + except (TypeError, ValueError): + continue + return default + + +def _visible_text_from_content(content: str) -> str: + for marker in ("Visible text excerpt:", "Visible text:", "Selected text:"): + if marker in content: + return content.split(marker, 1)[1].split("Visual summary:", 1)[0].strip() + return "" + + +def _strip_visible_text(content: str) -> str: + for marker in ("Visible text excerpt:", "Visible text:"): + if marker not in content: + continue + before, after = content.split(marker, 1) + if "Visual summary:" in after: + _, rest = after.split("Visual summary:", 1) + return (before.rstrip() + "\nVisual summary:\n" + rest.strip()).strip() + return before.rstrip() + return content + + +def _ocr_quality_score(text: str) -> float: + compact = " ".join(str(text or "").split()) + if len(compact) < 24: + return 0.0 + chars = len(compact) + alnum_ratio = sum(ch.isalnum() for ch in compact) / chars + allowed = ".,:;!?()[]{}'\"/-_@#%+&" + weird_ratio = sum(not (ch.isalnum() or ch.isspace() or ch in allowed) for ch in compact) / chars + tokens = re.findall(r"[A-Za-z][A-Za-z0-9'_-]{2,}", compact) + if not tokens: + return 0.0 + good_tokens = [token for token in tokens if _looks_like_word(token)] + token_quality = len(good_tokens) / len(tokens) + unique_ratio = len({token.lower() for token in tokens}) / len(tokens) + common_ratio = min(_ocr_common_hits(compact) / 8, 1.0) + length_score = min(chars / 500, 1.0) + score = ( + 0.16 + + alnum_ratio * 0.18 + + token_quality * 0.28 + + unique_ratio * 0.14 + + common_ratio * 0.12 + + length_score * 0.12 + - weird_ratio * 0.35 + ) + if len(tokens) < 6: + score -= 0.12 + return round(max(0.0, min(score, 1.0)), 3) + + +def _ocr_common_hits(text: str) -> int: + return sum( + 1 + for token in re.findall(r"[A-Za-z][A-Za-z0-9'_-]{2,}", str(text or "")) + if token.lower() in COMMON_WORDS + ) + + +def _title_quality(title: str, app: str) -> float: + normalized = str(title or "").strip() + if _is_generic_title(normalized, app): + return 0.0 + words = re.findall(r"[A-Za-z0-9][A-Za-z0-9'_-]{1,}", normalized) + if not words: + return 0.0 + score = min(len(normalized) / 80, 0.45) + min(len(words) / 8, 0.35) + 0.2 + if any(marker in normalized.lower() for marker in ("chatgpt", "github", "google", "youtube")): + score += 0.08 + return round(max(0.0, min(score, 1.0)), 3) + + +def _is_generic_title(title: str, app: str) -> bool: + normalized_title = str(title or "").strip().lower() + normalized_app = str(app or "").strip().lower() + if not normalized_title: + return True + return normalized_title == normalized_app or normalized_title in GENERIC_TITLES + + +def _looks_like_word(token: str) -> bool: + letters = re.sub(r"[^a-z]", "", token.lower()) + if len(letters) < 3: + return True + if not re.search(r"[aeiou]", letters): + return False + if re.search(r"[^aeiou]{6,}", letters): + return False + if len(set(letters)) <= 2 and len(letters) >= 6: + return False + return True + + +def _is_high_churn(app: str, bundle: str, title: str) -> bool: + haystack = " ".join((app or "", bundle or "", title or "")).lower() + return any(marker in haystack for marker in HIGH_CHURN_MARKERS) + + +def _looks_like_interest(app: str, title: str, ocr_text: str, metadata: Dict[str, Any]) -> bool: + haystack = " ".join( + ( + app or "", + title or "", + ocr_text or "", + str(metadata.get("source") or ""), + str(metadata.get("type") or ""), + ) + ).lower() + return any(marker in haystack for marker in INTEREST_MARKERS) diff --git a/dhee/memory/write_pipeline.py b/dhee/memory/write_pipeline.py index 94fd173..625b5d5 100644 --- a/dhee/memory/write_pipeline.py +++ b/dhee/memory/write_pipeline.py @@ -20,6 +20,11 @@ from dhee.core.traces import initialize_traces from dhee.memory.cost import estimate_token_count, estimate_output_tokens from dhee.memory.episodic import index_episodic_events_for_memory as _index_episodic +from dhee.memory.admission import ( + admission_expiration_date, + evaluate_memory_candidate, + sanitize_admitted_content, +) from dhee.memory.retrieval_helpers import ( attach_bitemporal_metadata, normalize_bitemporal_value, @@ -466,6 +471,33 @@ def _add_llm_cost(input_tokens: float) -> None: if explicit_remember and explicit_intent and explicit_intent.content: content = explicit_intent.content + admission = evaluate_memory_candidate( + content, + mem_metadata, + explicit_remember=explicit_remember, + ) + if admission.applies: + mem_metadata["dhee_admission"] = admission.to_metadata() + mem_metadata["dhee_passive_observation"] = True + if not explicit_remember: + mem_metadata["dhee_lite_path"] = True + mem_metadata["retention_policy"] = admission.retention_policy + mem_metadata["confidence"] = admission.confidence + mem_metadata["admission_score"] = admission.score + mem_metadata["admission_promotion_reason"] = admission.promotion_reason + if not admission.should_store: + return { + "event": "SKIP", + "reason": admission.skip_reason or "admission_rejected", + "memory": content, + "admission": admission.to_metadata(), + } + if expiration_date is None: + expiration_date = admission_expiration_date(admission.retention_policy) + if admission.retention_policy != "durable": + initial_strength = min(initial_strength, max(0.2, admission.confidence)) + content = sanitize_admitted_content(content, admission) + blocked = detect_sensitive_categories(content) # allow_sensitive: explicit caller opt-in, or caller explicitly provided # the content (infer=False / user_provided=True). PII detection is a @@ -494,7 +526,11 @@ def _add_llm_cost(input_tokens: float) -> None: # --- Deferred enrichment: lite path (0 LLM calls) --- enrichment_config = getattr(self._config, "enrichment", None) - if enrichment_config and enrichment_config.defer_enrichment: + use_lite_path = bool( + (enrichment_config and enrichment_config.defer_enrichment) + or mem_metadata.get("dhee_lite_path") + ) + if use_lite_path: return self.process_single_memory_lite( content=content, mem_metadata=mem_metadata, @@ -1032,6 +1068,7 @@ def _do_category(): "namespace": namespace_value, "vector_nodes": len(vectors), "memory_type": memory_type, + "admission": mem_metadata.get("dhee_admission"), } def process_single_memory_lite( @@ -1298,6 +1335,7 @@ def process_single_memory_lite( "vector_nodes": 1, "memory_type": memory_type, "enrichment_status": "pending", + "admission": mem_metadata.get("dhee_admission"), } def process_memory_batch( diff --git a/dhee/release_hygiene.py b/dhee/release_hygiene.py new file mode 100644 index 0000000..5a7da4c --- /dev/null +++ b/dhee/release_hygiene.py @@ -0,0 +1,392 @@ +"""Release hygiene checks for premium Dhee builds. + +This module is intentionally deterministic. It does not decide whether a +change is good; it proves whether the repo is in a releasable state and records +the intended scope when it is not. +""" + +from __future__ import annotations + +import os +import subprocess +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Sequence + +from dhee.runtime_io import read_json_checked, write_json_atomic + + +RELEASE_INTENT_SCHEMA = "dhee.release_intent.v1" +RELEASE_CHECK_SCHEMA = "dhee.release_check.v1" +RELEASE_INTENT_REL_PATH = ".dhee/context/release_intent.json" + + +def _now_iso() -> str: + return datetime.now(timezone.utc).replace(microsecond=0).isoformat() + + +def _run_git( + repo: Path, + args: Sequence[str], + *, + text: bool = True, +) -> subprocess.CompletedProcess: + return subprocess.run( + ["git", "-C", str(repo), *args], + capture_output=True, + text=text, + check=False, + ) + + +def _repo_root(repo: str | os.PathLike[str] | None) -> Dict[str, Any]: + start = Path(repo or os.getcwd()).expanduser() + proc = _run_git(start, ["rev-parse", "--show-toplevel"]) + if proc.returncode != 0: + return { + "ok": False, + "repo": str(start), + "diagnostic": { + "code": "GIT_REPO_UNAVAILABLE", + "message": (proc.stderr or proc.stdout or "Not inside a git repository.").strip(), + "path": str(start), + }, + } + return {"ok": True, "repo": proc.stdout.strip()} + + +def _git_text(repo_root: Path, args: Sequence[str]) -> str: + proc = _run_git(repo_root, args) + return proc.stdout.strip() if proc.returncode == 0 else "" + + +def _decode_git_path(raw: bytes) -> str: + return raw.decode("utf-8", errors="surrogateescape") + + +def _parse_porcelain_z(raw: bytes) -> List[Dict[str, Any]]: + """Parse `git status --porcelain=v1 -z` output. + + Git's NUL format is the only reliable way to handle spaces and unusual + filenames. Rename/copy records contain an extra source path; both paths are + kept so scope checks cannot miss a moved file. + """ + + entries: List[Dict[str, Any]] = [] + parts = raw.split(b"\0") + idx = 0 + while idx < len(parts): + record = parts[idx] + idx += 1 + if not record: + continue + status = _decode_git_path(record[:2]) + path = _decode_git_path(record[3:]) if len(record) > 3 else "" + paths = [path] if path else [] + if ("R" in status or "C" in status) and idx < len(parts) and parts[idx]: + paths.append(_decode_git_path(parts[idx])) + idx += 1 + entries.append( + { + "status": status, + "path": path, + "paths": paths, + "kind": "untracked" if status == "??" else "ignored" if status == "!!" else "tracked_change", + "staged": bool(status[:1].strip() and status[:1] not in {"?", "!"}), + "unstaged": bool(status[1:2].strip() and status[1:2] not in {"?", "!"}), + } + ) + return entries + + +def _normalize_repo_path(repo_root: Path, value: str) -> str: + raw = str(value or "").strip() + if not raw: + return "" + path = Path(raw).expanduser() + if path.is_absolute(): + try: + rel = path.resolve(strict=False).relative_to(repo_root.resolve(strict=False)) + text = rel.as_posix() + except ValueError: + return path.as_posix() + else: + text = Path(raw).as_posix() + while text.startswith("./"): + text = text[2:] + text = text.rstrip("/") + return text or "." + + +def _normalize_paths(repo_root: Path, paths: Optional[Iterable[str]]) -> List[str]: + normalized: List[str] = [] + seen = set() + for item in paths or []: + text = _normalize_repo_path(repo_root, str(item)) + if text and text not in seen: + normalized.append(text) + seen.add(text) + return normalized + + +def _matches_intent(path: str, intended_paths: Sequence[str]) -> bool: + clean = path.rstrip("/") + for intended in intended_paths: + prefix = intended.rstrip("/") + if prefix == ".": + return True + if clean == prefix or clean.startswith(prefix + "/"): + return True + return False + + +def _blocker(code: str, message: str, **extra: Any) -> Dict[str, Any]: + return {"code": code, "message": message, **extra} + + +def _unique_paths(entries: Sequence[Dict[str, Any]]) -> List[str]: + seen = set() + paths: List[str] = [] + for entry in entries: + for path in entry.get("paths") or []: + if path and path not in seen: + paths.append(path) + seen.add(path) + return paths + + +def release_intent_path(repo: str | os.PathLike[str] | None = None) -> Path: + root = _repo_root(repo) + if not root.get("ok"): + return Path(repo or os.getcwd()).expanduser() / RELEASE_INTENT_REL_PATH + return Path(root["repo"]) / RELEASE_INTENT_REL_PATH + + +def load_release_intent(repo: str | os.PathLike[str] | None = None) -> Dict[str, Any]: + root = _repo_root(repo) + if not root.get("ok"): + return { + "ok": False, + "exists": False, + "path": str(Path(repo or os.getcwd()).expanduser() / RELEASE_INTENT_REL_PATH), + "intended_paths": [], + "diagnostics": [root["diagnostic"]], + } + + repo_root = Path(root["repo"]) + path = repo_root / RELEASE_INTENT_REL_PATH + checked = read_json_checked(path, expected_schema=RELEASE_INTENT_SCHEMA) + diagnostics = checked.get("diagnostics") or [] + data = checked.get("data") or {} + intended_paths = _normalize_paths(repo_root, data.get("intended_paths") or []) + return { + "ok": bool(checked.get("ok")) or not bool(checked.get("exists")), + "exists": bool(checked.get("exists")), + "path": str(path), + "schema_version": data.get("schema_version"), + "reason": data.get("reason") or "", + "created_by": data.get("created_by") or "", + "generated_at": data.get("generated_at") or "", + "intended_paths": intended_paths, + "diagnostics": diagnostics if checked.get("exists") else [], + } + + +def write_release_intent( + repo: str | os.PathLike[str] | None, + paths: Sequence[str], + *, + reason: str = "", + agent_id: str = "cli", +) -> Dict[str, Any]: + root = _repo_root(repo) + if not root.get("ok"): + return {"ok": False, "diagnostics": [root["diagnostic"]]} + repo_root = Path(root["repo"]) + intended_paths = _normalize_paths(repo_root, paths) + payload = { + "schema_version": RELEASE_INTENT_SCHEMA, + "generated_at": _now_iso(), + "created_by": agent_id, + "reason": reason or "", + "intended_paths": intended_paths, + } + path = repo_root / RELEASE_INTENT_REL_PATH + write_result = write_json_atomic(path, payload) + return { + "ok": bool(write_result.get("ok")), + "repo": str(repo_root), + "path": str(path), + "intent": payload, + "diagnostics": [] if write_result.get("ok") else [write_result.get("diagnostic")], + } + + +def git_status_entries(repo: str | os.PathLike[str] | None = None) -> Dict[str, Any]: + root = _repo_root(repo) + if not root.get("ok"): + return {"ok": False, "repo": root.get("repo"), "entries": [], "diagnostics": [root["diagnostic"]]} + repo_root = Path(root["repo"]) + proc = _run_git(repo_root, ["status", "--porcelain=v1", "-z", "--untracked-files=all"], text=False) + if proc.returncode != 0: + return { + "ok": False, + "repo": str(repo_root), + "entries": [], + "diagnostics": [ + { + "code": "GIT_STATUS_FAILED", + "message": _decode_git_path(proc.stderr or proc.stdout or b"git status failed").strip(), + "path": str(repo_root), + } + ], + } + return {"ok": True, "repo": str(repo_root), "entries": _parse_porcelain_z(proc.stdout or b""), "diagnostics": []} + + +def release_check( + repo: str | os.PathLike[str] | None = None, + *, + intended_paths: Optional[Sequence[str]] = None, + require_clean: bool = True, + expected_artifacts: Optional[Sequence[str]] = None, +) -> Dict[str, Any]: + root = _repo_root(repo) + if not root.get("ok"): + return { + "schema_version": RELEASE_CHECK_SCHEMA, + "generated_at": _now_iso(), + "repo": root.get("repo"), + "status": "blocked", + "release_allowed": False, + "release_blockers": [root["diagnostic"]], + "warnings": [], + } + + repo_root = Path(root["repo"]) + intent = load_release_intent(repo_root) + cli_intended = _normalize_paths(repo_root, intended_paths) + combined_intended = _normalize_paths(repo_root, [*(intent.get("intended_paths") or []), *cli_intended]) + status = git_status_entries(repo_root) + entries = status.get("entries") or [] + dirty_paths = _unique_paths(entries) + unexpected_dirty_paths = [ + path for path in dirty_paths if not _matches_intent(path, combined_intended) + ] + + blockers: List[Dict[str, Any]] = [] + warnings: List[Dict[str, Any]] = [] + if intent.get("exists") and not intent.get("ok"): + blockers.append( + _blocker( + "RELEASE_INTENT_UNREADABLE", + "Release intent exists but cannot be trusted.", + diagnostics=intent.get("diagnostics") or [], + ) + ) + if not status.get("ok"): + blockers.extend(status.get("diagnostics") or []) + if dirty_paths and require_clean: + blockers.append( + _blocker( + "GIT_WORKTREE_DIRTY", + "Release tag is blocked until git status is clean.", + paths=dirty_paths, + ) + ) + if unexpected_dirty_paths: + blockers.append( + _blocker( + "UNEXPECTED_DIRTY_PATHS", + "Dirty paths are outside the documented release intent.", + paths=unexpected_dirty_paths, + ) + ) + if dirty_paths and not combined_intended: + warnings.append( + _blocker( + "RELEASE_INTENT_MISSING", + "Dirty work has no release intent file or --intended-path scope.", + ) + ) + + missing_artifacts: List[str] = [] + for artifact in expected_artifacts or []: + rel = _normalize_repo_path(repo_root, artifact) + if not (repo_root / rel).exists(): + missing_artifacts.append(rel) + if missing_artifacts: + blockers.append( + _blocker( + "MISSING_RELEASE_ARTIFACT", + "Expected release artifact is missing.", + paths=missing_artifacts, + ) + ) + + clean = not dirty_paths + release_allowed = not blockers + return { + "schema_version": RELEASE_CHECK_SCHEMA, + "generated_at": _now_iso(), + "repo": str(repo_root), + "status": "ready" if release_allowed else "blocked", + "release_allowed": release_allowed, + "require_clean": require_clean, + "git": { + "clean": clean, + "branch": _git_text(repo_root, ["rev-parse", "--abbrev-ref", "HEAD"]), + "head": _git_text(repo_root, ["rev-parse", "HEAD"]), + "dirty_count": len(dirty_paths), + "dirty_paths": dirty_paths, + "entries": entries, + }, + "intent": { + **intent, + "cli_intended_paths": cli_intended, + "combined_intended_paths": combined_intended, + }, + "unexpected_dirty_paths": unexpected_dirty_paths, + "release_blockers": blockers, + "warnings": warnings, + "summary": ( + "Release allowed: git tree is clean." + if release_allowed + else "Release blocked: fix blockers before tagging." + ), + } + + +def format_release_check(report: Dict[str, Any]) -> str: + lines = [ + f"Dhee release check: {report.get('status')}", + f" repo {report.get('repo') or ''}", + ] + git = report.get("git") or {} + if git: + head = (git.get("head") or "")[:12] + lines.append(f" branch {git.get('branch') or '(unknown)'} {head}") + lines.append(f" clean {'yes' if git.get('clean') else 'no'} ({git.get('dirty_count', 0)} dirty path(s))") + intent = report.get("intent") or {} + intended = intent.get("combined_intended_paths") or [] + if intended: + lines.append(f" intent {', '.join(intended)}") + elif git.get("dirty_count"): + lines.append(" intent none") + blockers = report.get("release_blockers") or [] + if blockers: + lines.append(" blockers") + for blocker in blockers: + lines.append(f" - {blocker.get('code')}: {blocker.get('message')}") + paths = blocker.get("paths") or [] + if paths: + preview = ", ".join(str(path) for path in paths[:8]) + suffix = f" (+{len(paths) - 8} more)" if len(paths) > 8 else "" + lines.append(f" paths: {preview}{suffix}") + warnings = report.get("warnings") or [] + if warnings: + lines.append(" warnings") + for warning in warnings: + lines.append(f" - {warning.get('code')}: {warning.get('message')}") + lines.append(f" verdict {'release allowed' if report.get('release_allowed') else 'do not tag'}") + return "\n".join(lines) diff --git a/dhee/router/handlers.py b/dhee/router/handlers.py index 667995c..e8a6c69 100644 --- a/dhee/router/handlers.py +++ b/dhee/router/handlers.py @@ -320,6 +320,11 @@ def handle_dhee_read(arguments: Dict[str, Any]) -> Dict[str, Any]: file_path = str(arguments.get("file_path", "")).strip() if not file_path: return {"error": "file_path is required"} + from dhee.contract_runtime import guard_router_call, record_router_observation, router_refusal, router_result_runtime + + contract_guard = guard_router_call("dhee_read", arguments) + if not contract_guard.get("allowed", True): + return router_refusal(contract_guard) offset_raw = arguments.get("offset") limit_raw = arguments.get("limit") @@ -502,7 +507,7 @@ def handle_dhee_read(arguments: Dict[str, Any]) -> Dict[str, Any]: }, content_hash=_stable_context_hash("routed_read", file_path, offset, limit, content), ) - return { + result = { "ptr": stored.ptr, "digest": rendered, "line_count": d.line_count, @@ -514,6 +519,11 @@ def handle_dhee_read(arguments: Dict[str, Any]) -> Dict[str, Any]: "task_source": state_route.get("source") if state_route else ("explicit" if (task_query or task_intent) else ""), "focus_count": len(d.focus), } + observation = record_router_observation(contract_guard, result) + runtime_info = router_result_runtime(contract_guard, observation) + if runtime_info: + result["contract_runtime"] = runtime_info + return result def handle_dhee_bash(arguments: Dict[str, Any]) -> Dict[str, Any]: @@ -529,6 +539,13 @@ def handle_dhee_bash(arguments: Dict[str, Any]) -> Dict[str, Any]: cmd = str(arguments.get("command", "")).strip() if not cmd: return {"error": "command is required"} + from dhee.contract_runtime import command_preview, guard_router_call, record_router_observation, router_refusal, router_result_runtime + + contract_guard = guard_router_call("dhee_bash", arguments) + if not contract_guard.get("allowed", True): + denial = router_refusal(contract_guard) + denial["command_preview"] = command_preview(cmd) + return denial try: timeout = float(arguments.get("timeout", BASH_DEFAULT_TIMEOUT)) @@ -680,7 +697,7 @@ def handle_dhee_bash(arguments: Dict[str, Any]) -> Dict[str, Any]: }, content_hash=_stable_context_hash("routed_bash", cwd or os.getcwd(), cmd, exit_code, raw_blob), ) - return { + result = { "ptr": stored.ptr, "digest": rendered, "exit_code": exit_code, @@ -692,6 +709,11 @@ def handle_dhee_bash(arguments: Dict[str, Any]) -> Dict[str, Any]: "inlined": inlined, "preflight": preflight, } + observation = record_router_observation(contract_guard, result) + runtime_info = router_result_runtime(contract_guard, observation) + if runtime_info: + result["contract_runtime"] = runtime_info + return result def handle_dhee_agent(arguments: Dict[str, Any]) -> Dict[str, Any]: @@ -770,6 +792,11 @@ def handle_dhee_grep(arguments: Dict[str, Any]) -> Dict[str, Any]: path = arguments.get("path") if not isinstance(path, str) or not path: path = "." + from dhee.contract_runtime import guard_router_call, record_router_observation, router_refusal, router_result_runtime + + contract_guard = guard_router_call("dhee_grep", {**arguments, "path": path}) + if not contract_guard.get("allowed", True): + return router_refusal(contract_guard) glob = arguments.get("glob") if glob is not None and not isinstance(glob, str): @@ -899,7 +926,7 @@ def handle_dhee_grep(arguments: Dict[str, Any]) -> Dict[str, Any]: }, content_hash=_stable_context_hash("routed_grep", path, pattern, glob, raw), ) - return { + result = { "ptr": stored.ptr, "digest": rendered, "match_count": digest.match_count, @@ -909,6 +936,11 @@ def handle_dhee_grep(arguments: Dict[str, Any]) -> Dict[str, Any]: "engine": digest.engine, "inlined": inlined, } + observation = record_router_observation(contract_guard, result) + runtime_info = router_result_runtime(contract_guard, observation) + if runtime_info: + result["contract_runtime"] = runtime_info + return result def _slice_by_range(content: str, range_spec: Any) -> tuple[str, dict[str, Any]]: diff --git a/dhee/router/pre_tool_gate.py b/dhee/router/pre_tool_gate.py index 0cea718..a79851e 100644 --- a/dhee/router/pre_tool_gate.py +++ b/dhee/router/pre_tool_gate.py @@ -27,6 +27,7 @@ from __future__ import annotations import os +import json import re from pathlib import Path from typing import Any @@ -89,14 +90,56 @@ def _deny(reason: str, steer: str) -> dict[str, Any]: } +def _truthy(value: Any) -> bool: + return str(value or "").strip().lower() in {"1", "true", "yes", "on"} + + +def _candidate_repo(inp: dict[str, Any]) -> Path: + raw = inp.get("cwd") or inp.get("file_path") or inp.get("path") or os.getcwd() + path = Path(str(raw or ".")).expanduser() + if not path.is_absolute(): + path = Path(os.getcwd()) / path + if path.suffix or (path.exists() and path.is_file()): + path = path.parent + path = path.resolve() + for current in [path, *path.parents]: + if (current / ".git").exists() or (current / ".dhee").exists(): + return current + return path + + +def _fallback_enforcement_mode(inp: dict[str, Any]) -> str: + if _truthy(os.environ.get("DHEE_REQUIRE_ACTIVE_CONTRACT")): + return "deny" + repo = _candidate_repo(inp) + policy_path = repo / ".dhee" / "context" / "task_runs" / "enforcement.json" + if not policy_path.exists(): + return "off" + try: + data = json.loads(policy_path.read_text(encoding="utf-8")) + except Exception: + return "deny" + mode = str((data if isinstance(data, dict) else {}).get("mode") or "").strip().lower() + return mode if mode in {"off", "warn", "deny"} else "deny" + + +def _enforcement_mode_for_input(inp: dict[str, Any]) -> str: + if _truthy(os.environ.get("DHEE_REQUIRE_ACTIVE_CONTRACT")): + return "deny" + try: + from dhee.contract_runtime import contract_enforcement_status + + return str(contract_enforcement_status(repo=str(_candidate_repo(inp))).get("mode") or "off") + except Exception: + return _fallback_enforcement_mode(inp) + + def evaluate(payload: dict[str, Any]) -> dict[str, Any]: """Decide whether to allow or deny a native tool call. Returns ``{}`` for allow (pass-through). Returns a deny block when enforcement is on and heuristics fire. """ - if not _enforce_on(): - return {} if not isinstance(payload, dict): return {} @@ -105,6 +148,13 @@ def evaluate(payload: dict[str, Any]) -> dict[str, Any]: if not isinstance(tool_input, dict): tool_input = {} + contract_denial = _evaluate_contract_supervisor(str(tool), tool_input) + if contract_denial: + return contract_denial + + if not _enforce_on(): + return {} + if tool == "Read": return _evaluate_read(tool_input) if tool == "Bash": @@ -114,6 +164,37 @@ def evaluate(payload: dict[str, Any]) -> dict[str, Any]: return {} +def _evaluate_contract_supervisor(tool: str, inp: dict[str, Any]) -> dict[str, Any]: + if tool not in {"Read", "Bash", "Grep", "Edit", "Write", "MultiEdit", "NotebookEdit"}: + return {} + try: + from dhee.contract_runtime import guard_router_call, router_refusal + + guard = guard_router_call(tool, inp) + if guard.get("allowed", True): + return {} + refusal = router_refusal(guard) + codes = ", ".join(refusal.get("violation_codes") or []) + reason = f"Contract supervisor denied {tool}: {codes or refusal.get('message')}" + steer = ( + "Activate a task contract with `dhee context task activate ` " + "or satisfy the compiled proof obligations before retrying. " + f"Violation codes: {codes or 'none'}. " + f"Decision: {json.dumps(refusal, sort_keys=True, default=str)[:1200]}" + ) + return _deny(reason, steer) + except Exception as exc: + if _enforcement_mode_for_input(inp) == "deny": + reason = f"Contract supervisor unavailable for {tool}: {type(exc).__name__}" + steer = ( + "Dhee contract enforcement is in deny mode, so native coding tools " + "cannot run while the supervisor is unavailable. " + f"Violation codes: CONTRACT_SUPERVISOR_UNAVAILABLE. Error: {exc}" + ) + return _deny(reason, steer) + return {} + + def _evaluate_grep(inp: dict[str, Any]) -> dict[str, Any]: """Steer native Grep onto dhee_grep. diff --git a/dhee/runtime_io.py b/dhee/runtime_io.py new file mode 100644 index 0000000..47967f4 --- /dev/null +++ b/dhee/runtime_io.py @@ -0,0 +1,328 @@ +"""Durable, corruption-aware runtime file I/O for Dhee. + +The contract runtime is a safety boundary, so its state files must not behave +like casual cache files. JSON writes are atomic, JSONL appends are locked, and +read failures are returned as structured diagnostics instead of being treated as +missing state. +""" + +from __future__ import annotations + +import json +import os +import tempfile +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Callable, Dict, Iterable, List, Optional + +try: # pragma: no cover - Windows fallback; CI and target runtime are POSIX. + import fcntl +except Exception: # pragma: no cover + fcntl = None # type: ignore[assignment] + + +JsonSanitizer = Callable[[Any], Any] + + +def _now_iso() -> str: + return datetime.now(timezone.utc).replace(microsecond=0).isoformat() + + +def _diagnostic(path: Path, code: str, message: str, **extra: Any) -> Dict[str, Any]: + return { + "code": code, + "message": message, + "path": str(path), + "created_at": _now_iso(), + **extra, + } + + +def _fsync_dir(path: Path) -> None: + try: + fd = os.open(str(path), os.O_RDONLY) + except OSError: + return + try: + os.fsync(fd) + finally: + os.close(fd) + + +def write_json_atomic( + path: str | os.PathLike[str], + data: Any, + *, + sanitize: Optional[JsonSanitizer] = None, +) -> Dict[str, Any]: + """Atomically replace *path* with JSON data and fsync the file + directory.""" + + target = Path(path).expanduser() + target.parent.mkdir(parents=True, exist_ok=True) + payload = sanitize(data) if sanitize else data + fd = -1 + tmp_name = "" + try: + fd, tmp_name = tempfile.mkstemp( + prefix=f".{target.name}.", + suffix=".tmp", + dir=str(target.parent), + text=True, + ) + with os.fdopen(fd, "w", encoding="utf-8") as fh: + fd = -1 + json.dump(payload, fh, indent=2, sort_keys=True, default=str) + fh.write("\n") + fh.flush() + os.fsync(fh.fileno()) + os.replace(tmp_name, target) + _fsync_dir(target.parent) + return { + "ok": True, + "path": str(target), + "bytes": target.stat().st_size if target.exists() else None, + } + except Exception as exc: + if fd >= 0: + try: + os.close(fd) + except OSError: + pass + if tmp_name: + try: + Path(tmp_name).unlink(missing_ok=True) + except OSError: + pass + return { + "ok": False, + "path": str(target), + "diagnostic": _diagnostic( + target, + "ATOMIC_JSON_WRITE_FAILED", + f"{type(exc).__name__}: {exc}", + ), + } + + +def append_jsonl_locked( + path: str | os.PathLike[str], + item: Any, + *, + sanitize: Optional[JsonSanitizer] = None, +) -> Dict[str, Any]: + """Append one JSONL record while holding a sibling lock file.""" + + target = Path(path).expanduser() + target.parent.mkdir(parents=True, exist_ok=True) + lock_path = target.with_name(target.name + ".lock") + payload = sanitize(item) if sanitize else item + try: + with lock_path.open("a+", encoding="utf-8") as lock_fh: + if fcntl is not None: + fcntl.flock(lock_fh.fileno(), fcntl.LOCK_EX) + try: + with target.open("a", encoding="utf-8") as fh: + fh.write(json.dumps(payload, sort_keys=True, default=str) + "\n") + fh.flush() + os.fsync(fh.fileno()) + finally: + if fcntl is not None: + fcntl.flock(lock_fh.fileno(), fcntl.LOCK_UN) + return {"ok": True, "path": str(target)} + except Exception as exc: + return { + "ok": False, + "path": str(target), + "diagnostic": _diagnostic( + target, + "LOCKED_JSONL_APPEND_FAILED", + f"{type(exc).__name__}: {exc}", + ), + } + + +def quarantine_corrupt_file( + path: str | os.PathLike[str], + reason: str, + *, + code: str = "CORRUPT_RUNTIME_FILE_QUARANTINED", +) -> Dict[str, Any]: + """Move a corrupt runtime file aside and return quarantine metadata.""" + + source = Path(path).expanduser() + if not source.exists(): + return { + "ok": False, + "path": str(source), + "diagnostic": _diagnostic(source, "QUARANTINE_SOURCE_MISSING", "File is already missing."), + } + stamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + quarantine_path = source.with_name(f"{source.name}.corrupt.{stamp}") + suffix = 1 + while quarantine_path.exists(): + quarantine_path = source.with_name(f"{source.name}.corrupt.{stamp}.{suffix}") + suffix += 1 + try: + os.replace(source, quarantine_path) + _fsync_dir(source.parent) + return { + "ok": True, + "path": str(source), + "quarantine_path": str(quarantine_path), + "diagnostic": _diagnostic(source, code, reason, quarantine_path=str(quarantine_path)), + } + except Exception as exc: + return { + "ok": False, + "path": str(source), + "diagnostic": _diagnostic( + source, + "QUARANTINE_FAILED", + f"{type(exc).__name__}: {exc}", + ), + } + + +def read_json_checked( + path: str | os.PathLike[str], + *, + expected_schema: Optional[str] = None, + quarantine: bool = False, +) -> Dict[str, Any]: + """Read JSON and return data plus structured diagnostics. + + Missing files are not corrupt. Decode errors, non-object data for runtime + state, and schema mismatches are diagnostics the caller can surface or block + on. + """ + + target = Path(path).expanduser() + if not target.exists(): + return { + "ok": False, + "exists": False, + "path": str(target), + "data": None, + "diagnostics": [ + _diagnostic(target, "RUNTIME_FILE_MISSING", "Runtime file does not exist.") + ], + } + try: + text = target.read_text(encoding="utf-8") + except Exception as exc: + diagnostic = _diagnostic(target, "RUNTIME_FILE_READ_FAILED", f"{type(exc).__name__}: {exc}") + return {"ok": False, "exists": True, "path": str(target), "data": None, "diagnostics": [diagnostic]} + try: + data = json.loads(text) + except Exception as exc: + diagnostic = _diagnostic(target, "RUNTIME_JSON_CORRUPT", f"{type(exc).__name__}: {exc}") + diagnostics = [diagnostic] + quarantine_result = quarantine_corrupt_file(target, diagnostic["message"]) if quarantine else None + if quarantine_result: + diagnostics.append(quarantine_result.get("diagnostic") or {}) + return { + "ok": False, + "exists": True, + "path": str(target), + "data": None, + "diagnostics": diagnostics, + "quarantine": quarantine_result, + } + if not isinstance(data, dict): + diagnostic = _diagnostic(target, "RUNTIME_JSON_NOT_OBJECT", "Runtime JSON root must be an object.") + diagnostics = [diagnostic] + quarantine_result = quarantine_corrupt_file(target, diagnostic["message"]) if quarantine else None + if quarantine_result: + diagnostics.append(quarantine_result.get("diagnostic") or {}) + return { + "ok": False, + "exists": True, + "path": str(target), + "data": None, + "diagnostics": diagnostics, + "quarantine": quarantine_result, + } + diagnostics: List[Dict[str, Any]] = [] + if expected_schema: + observed = data.get("schema_version") or data.get("format") + if observed != expected_schema: + diagnostics.append( + _diagnostic( + target, + "RUNTIME_SCHEMA_MISMATCH", + "Runtime JSON schema version does not match the expected schema.", + expected_schema=expected_schema, + observed_schema=observed, + ) + ) + return { + "ok": not diagnostics, + "exists": True, + "path": str(target), + "data": data, + "diagnostics": diagnostics, + } + + +def read_jsonl_checked( + path: str | os.PathLike[str], + *, + quarantine_on_corrupt: bool = False, +) -> Dict[str, Any]: + """Read JSONL records and surface every corrupt line as a diagnostic.""" + + target = Path(path).expanduser() + if not target.exists(): + return { + "ok": True, + "exists": False, + "path": str(target), + "records": [], + "diagnostics": [], + } + records: List[Dict[str, Any]] = [] + diagnostics: List[Dict[str, Any]] = [] + try: + lines: Iterable[str] = target.read_text(encoding="utf-8").splitlines() + except Exception as exc: + diagnostic = _diagnostic(target, "RUNTIME_JSONL_READ_FAILED", f"{type(exc).__name__}: {exc}") + return {"ok": False, "exists": True, "path": str(target), "records": [], "diagnostics": [diagnostic]} + for line_no, line in enumerate(lines, start=1): + if not line.strip(): + continue + try: + data = json.loads(line) + except Exception as exc: + diagnostics.append( + _diagnostic( + target, + "RUNTIME_JSONL_LINE_CORRUPT", + f"{type(exc).__name__}: {exc}", + line=line_no, + ) + ) + continue + if not isinstance(data, dict): + diagnostics.append( + _diagnostic( + target, + "RUNTIME_JSONL_LINE_NOT_OBJECT", + "Runtime JSONL line root must be an object.", + line=line_no, + ) + ) + continue + records.append(data) + quarantine_result = None + if diagnostics and quarantine_on_corrupt: + quarantine_result = quarantine_corrupt_file(target, "JSONL file contains corrupt records.") + if quarantine_result: + diagnostics.append(quarantine_result.get("diagnostic") or {}) + return { + "ok": not diagnostics, + "exists": True, + "path": str(target), + "records": records, + "diagnostics": diagnostics, + "quarantine": quarantine_result, + } diff --git a/dhee/simple.py b/dhee/simple.py index ac726ff..57fb5ff 100644 --- a/dhee/simple.py +++ b/dhee/simple.py @@ -20,6 +20,8 @@ import logging import os +import re +import sqlite3 import tempfile from pathlib import Path from typing import Any, Dict, List, Optional, Union @@ -78,6 +80,35 @@ def _get_embedding_dims(provider: str) -> int: return 384 +def _existing_sqlite_vec_dims(db_path: Path, collection_name: str) -> Optional[int]: + """Return the stored sqlite-vec dimension for an existing collection. + + Dhee is a long-lived memory layer. A user's memory database may outlive + provider/model changes, so we must not silently switch from a 384-dim local + store to a 2048-dim hosted embedder and make writes fail. + """ + if not db_path.exists(): + return None + vec_table = f"vec_{collection_name}" + try: + conn = sqlite3.connect(str(db_path)) + try: + row = conn.execute( + "SELECT sql FROM sqlite_master WHERE name = ?", + (vec_table,), + ).fetchone() + finally: + conn.close() + except sqlite3.Error: + return None + if not row or not row[0]: + return None + match = re.search(r"embedding\s+float\[(\d+)\]", str(row[0]), re.IGNORECASE) + if not match: + return None + return int(match.group(1)) + + def _has_api_key() -> bool: try: from dhee.cli_config import get_api_key @@ -87,7 +118,10 @@ def _has_api_key() -> bool: os.environ.get("GEMINI_API_KEY") or os.environ.get("OPENAI_API_KEY") or (get_api_key and (get_api_key("gemini") or get_api_key("openai"))) - ) +) + +DEFAULT_NVIDIA_LLM_MODEL = "nvidia/nemotron-3-nano-omni-30b-a3b-reasoning" +DEFAULT_NVIDIA_EMBEDDER_MODEL = "nvidia/llama-nemotron-embed-vl-1b-v2" def _get_data_dir() -> Path: @@ -139,8 +173,17 @@ def __init__( self._data_dir = Path(data_dir) if data_dir else _get_data_dir() self._data_dir.mkdir(parents=True, exist_ok=True) - # Build configuration - embedding_dims = _get_embedding_dims(self._provider) + # Build configuration. If a persistent sqlite-vec collection already + # exists, keep its dimension stable and use a compatible local embedder + # when the selected provider would produce a different vector size. + provider_embedding_dims = _get_embedding_dims(self._provider) + existing_embedding_dims = None + if not in_memory: + existing_embedding_dims = _existing_sqlite_vec_dims( + self._data_dir / "sqlite_vec.db", + collection_name, + ) + embedding_dims = existing_embedding_dims or provider_embedding_dims if in_memory: vector_config = VectorStoreConfig( @@ -162,12 +205,30 @@ def __init__( llm_provider = "mock" if self._provider == "mock" else self._provider embedder_provider = "simple" if self._provider == "mock" else self._provider + if existing_embedding_dims and existing_embedding_dims != provider_embedding_dims: + embedder_provider = "simple" + logger.info( + "Using simple embedder with existing sqlite-vec dimension %s " + "instead of %s provider dimension %s", + existing_embedding_dims, + self._provider, + provider_embedding_dims, + ) + llm_kwargs: Dict[str, Any] = {} embedder_kwargs: Dict[str, Any] = {} + if llm_provider == "nvidia": + llm_kwargs["config"] = { + "model": DEFAULT_NVIDIA_LLM_MODEL, + "temperature": 0.2, + "max_tokens": 4096, + } if embedder_provider == "simple": embedder_kwargs["config"] = {"embedding_dims": embedding_dims} + elif embedder_provider == "nvidia": + embedder_kwargs["config"] = {"model": DEFAULT_NVIDIA_EMBEDDER_MODEL} config = MemoryConfig( - llm=LLMConfig(provider=llm_provider), + llm=LLMConfig(provider=llm_provider, **llm_kwargs), embedder=EmbedderConfig(provider=embedder_provider, **embedder_kwargs), vector_store=vector_config, fade=FadeMemConfig( @@ -348,6 +409,46 @@ def delete(self, memory_id: str) -> None: """ self._memory.delete(memory_id) + def sweep_admission( + self, + user_id: str = "default", + agent_id: Optional[str] = None, + limit: int = 10_000, + dry_run: bool = True, + ) -> Dict[str, Any]: + """Find or delete memories that current Dhee admission rules reject. + + This is the retroactive half of Dhee's memory hygiene: agents can improve + admission policy over time, then use this sweep to remove legacy passive + observations that no longer meet the bar. + """ + from dhee.memory.admission import forget_reason_for_memory + + memories = self.get_all(user_id=user_id, agent_id=agent_id, limit=limit) + candidates: List[Dict[str, Any]] = [] + deleted: List[Dict[str, Any]] = [] + for memory in memories: + reason = forget_reason_for_memory(memory) + if not reason: + continue + item = { + "id": memory.get("id"), + "reason": reason, + "memory": str(memory.get("memory") or "")[:240], + } + candidates.append(item) + if not dry_run and memory.get("id"): + self.delete(str(memory["id"])) + deleted.append(item) + return { + "dry_run": dry_run, + "scanned_count": len(memories), + "candidate_count": len(candidates), + "deleted_count": len(deleted), + "candidates": candidates, + "deleted": deleted, + } + def forget( self, user_id: Optional[str] = None, @@ -524,8 +625,21 @@ def remember( if isinstance(result, dict): rs = result.get("results", []) if rs: - memory_id = rs[0].get("id") + first = rs[0] + event = first.get("event") + if event in {"SKIP", "BLOCKED"}: + return { + "stored": False, + "event": event, + "reason": first.get("reason"), + "admission": first.get("admission"), + "queued": False, + "degraded": False, + } + memory_id = first.get("id") response["id"] = memory_id + if first.get("admission") is not None: + response["admission"] = first.get("admission") if tier == "shruti": response["tier"] = "shruti" @@ -539,6 +653,21 @@ def remember( response["detected_intention"] = intention.to_dict() return response + def sweep_admission( + self, + user_id: Optional[str] = None, + agent_id: Optional[str] = None, + limit: int = 10_000, + dry_run: bool = True, + ) -> Dict[str, Any]: + """Find or delete memories rejected by the current admission policy.""" + return self._engram.sweep_admission( + user_id=user_id or self._user_id, + agent_id=agent_id, + limit=limit, + dry_run=dry_run, + ) + # ------------------------------------------------------------------ # Tool 2: recall # ------------------------------------------------------------------ diff --git a/dhee/task_contracts.py b/dhee/task_contracts.py new file mode 100644 index 0000000..a493524 --- /dev/null +++ b/dhee/task_contracts.py @@ -0,0 +1,1780 @@ +"""Deterministic task contracts and Chotu action plans. + +Dhee's compiler role is to turn a noisy user request plus repo state into a +bounded, machine-checkable action contract. This module is intentionally +heuristic and deterministic: it does not ask an LLM to write a plan. +""" + +from __future__ import annotations + +import hashlib +import ast +import json +import os +import re +import subprocess +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple + +from dhee import repo_link + + +TASK_CONTRACT_SCHEMA = "dhee.task_contract.v1" +ACTION_PLAN_SCHEMA = "dhee.chotu_action_plan.v1" +ACTION_BYTECODE_SCHEMA = "dhee.chotu_action_bytecode.v1" +CONTRACT_COMPILER_SCHEMA = "dhee.contract_compiler.v1" +CONTEXT_LEDGER_SCHEMA = "dhee.context_ledger.v1" +REPO_INTELLIGENCE_SCHEMA = "dhee.repo_intelligence.v1" +VERIFICATION_CARD_SCHEMA = "dhee.verification_card.v1" +CONTAMINATION_STATUS_SCHEMA = "dhee.contamination_status.v1" +TASK_INTERPRETATION_SCHEMA = "dhee.task_contract_interpretation.v1" +TASK_CONTRACT_KIND = "task_contract" + +ACTION_TYPES = { + "READ_FILE", + "SEARCH_CODE", + "LSP_SYMBOL", + "RUN_TEST", + "EDIT_FILE", + "ASK_USER", + "SPAWN_SUBAGENT", + "WRITE_MEMORY_NOTE", + "SUBMIT_PATCH", +} + +DEFAULT_CONTEXT_BUDGET = { + "state_card_tokens": 1500, + "retrieved_memory_tokens": 3000, + "repo_context_tokens": 6000, + "tool_output_tokens": 2000, +} +DEFAULT_FORBIDDEN_PATHS = [".env", ".env.*", "secrets/", "prod-config/"] +DEFAULT_FORBIDDEN_ACTIONS = [ + "git reset --hard", + "git checkout --", + "rm -rf", + "write secrets", + "edit generated capsule imports without request", +] +DEFAULT_SUCCESS_CRITERIA = [ + "target tests pass", + "no unrelated files changed", + "diff is reviewable", + "memory note created if failure pattern is reusable", +] +_SECRET_PATTERNS = [ + re.compile(r"sk-[A-Za-z0-9_-]{20,}"), + re.compile(r"gh[pousr]_[A-Za-z0-9_]{20,}"), + re.compile(r"xox[baprs]-[A-Za-z0-9-]{20,}"), + re.compile(r"(?i)\b(api[_-]?key|token|secret|password|passwd)\b\s*[:=]\s*['\"]?[^'\"\s]{8,}"), + re.compile(r"eyJ[A-Za-z0-9_-]{20,}\.[A-Za-z0-9_-]{10,}\.[A-Za-z0-9_-]{10,}"), +] +_LOCAL_PATH_RE = re.compile(r"(/Users/[^\s\"']+|/home/[^\s\"']+|[A-Za-z]:\\\\[^\s\"']+)") +_EXCLUDED_DIRS = { + ".git", + ".dhee", + ".hg", + ".svn", + ".mypy_cache", + ".pytest_cache", + ".ruff_cache", + ".tox", + ".venv", + "venv", + "env", + "__pycache__", + "node_modules", + "dist", + "build", + ".next", +} +_STOP_WORDS = { + "a", + "an", + "and", + "are", + "bug", + "fix", + "for", + "in", + "is", + "it", + "of", + "on", + "please", + "test", + "tests", + "the", + "this", + "to", + "with", +} +_ACTION_OPERAND_FIELDS = ( + "path", + "command", + "query", + "scope", + "symbol", + "question", + "role", + "task", + "category", + "summary", +) +_ACTION_SEMANTICS = { + "SEARCH_CODE": { + "phase": "discover", + "op": "repo.search", + "capabilities": ["repo.search"], + "effects": ["context.matches_observed"], + }, + "READ_FILE": { + "phase": "inspect", + "op": "fs.read", + "capabilities": ["fs.read"], + "effects": ["context.file_observed"], + }, + "LSP_SYMBOL": { + "phase": "inspect", + "op": "lsp.symbol", + "capabilities": ["lsp.lookup"], + "effects": ["context.symbol_observed"], + }, + "RUN_TEST": { + "phase": "verify", + "op": "shell.test", + "capabilities": ["shell.test"], + "effects": ["verification.test_observed"], + }, + "EDIT_FILE": { + "phase": "mutate", + "op": "fs.patch", + "capabilities": ["fs.write"], + "effects": ["repo.diff_mutated"], + }, + "ASK_USER": { + "phase": "clarify", + "op": "user.ask", + "capabilities": ["user.ask"], + "effects": ["context.user_input_observed"], + }, + "SPAWN_SUBAGENT": { + "phase": "delegate", + "op": "agent.spawn", + "capabilities": ["agent.spawn"], + "effects": ["context.parallel_result_observed"], + }, + "WRITE_MEMORY_NOTE": { + "phase": "learn", + "op": "memory.write", + "capabilities": ["memory.write"], + "effects": ["memory.lesson_recorded"], + }, + "SUBMIT_PATCH": { + "phase": "submit", + "op": "patch.submit", + "capabilities": ["patch.submit"], + "effects": ["handoff.patch_submitted"], + }, +} +_COMPILER_PASSES = [ + {"name": "issue_parse", "kind": "analysis", "purpose": "Normalize the user issue into goal, constraints, and ambiguity signals."}, + {"name": "repo_index", "kind": "analysis", "purpose": "Build a git-SHA scoped repo brain: symbols, imports, tests, dependencies, and risk signals."}, + {"name": "env_probe", "kind": "analysis", "purpose": "Infer deterministic setup and execution commands without running arbitrary code."}, + {"name": "test_discovery", "kind": "analysis", "purpose": "Select fail-to-pass, pass-to-pass, nearest, smoke, static, and security checks."}, + {"name": "localization", "kind": "analysis", "purpose": "Localize candidate files and symbols with evidence pointers and confidence."}, + {"name": "context_pack", "kind": "budgeting", "purpose": "Compile ranked context items with why, pointer, token cost, freshness, confidence, and expected utility."}, + {"name": "patch_strategy", "kind": "planning", "purpose": "Emit allowed patch families and edit proof obligations, not free-form plans."}, + {"name": "verification_plan", "kind": "verification", "purpose": "Create a VerificationCard the runtime can check before submit."}, + {"name": "replay_plan", "kind": "verification", "purpose": "Define branchable checkpoints for localization, edit, failed-test, and submit boundaries."}, + {"name": "memory_policy", "kind": "safety", "purpose": "Permit pointer-backed lessons only after verification and contamination checks."}, +] + + +def _stable_hash(data: Any, length: int = 16) -> str: + raw = json.dumps(data, sort_keys=True, default=str, separators=(",", ":")) + return hashlib.sha256(raw.encode("utf-8")).hexdigest()[:length] + + +def _json_dumps(data: Any) -> str: + return json.dumps(data, indent=2, sort_keys=True, default=str) + + +def _now_iso() -> str: + return datetime.now(timezone.utc).replace(microsecond=0).isoformat() + + +def _sanitize_text(text: str) -> str: + value = str(text or "") + home = str(Path.home()) + if home: + value = value.replace(home, "$HOME") + value = _LOCAL_PATH_RE.sub("", value) + for pattern in _SECRET_PATTERNS: + value = pattern.sub("", value) + return value + + +def _sanitize_obj(value: Any) -> Any: + if isinstance(value, str): + return _sanitize_text(value) + if isinstance(value, list): + return [_sanitize_obj(item) for item in value] + if isinstance(value, tuple): + return [_sanitize_obj(item) for item in value] + if isinstance(value, dict): + return {str(key): _sanitize_obj(item) for key, item in value.items()} + return value + + +def _resolve_repo_root(repo: str | os.PathLike[str] | None) -> Path: + base = Path(repo or os.getcwd()).expanduser().resolve() + proc = subprocess.run( + ["git", "-C", str(base), "rev-parse", "--show-toplevel"], + text=True, + capture_output=True, + check=False, + ) + if proc.returncode == 0 and proc.stdout.strip(): + return Path(proc.stdout.strip()).resolve() + return base + + +def _git_out(repo_root: Path, args: Sequence[str], default: str = "") -> str: + proc = subprocess.run( + ["git", "-C", str(repo_root), *args], + text=True, + capture_output=True, + check=False, + ) + if proc.returncode != 0: + return default + return proc.stdout.strip() + + +def _repo_slug(repo_root: Path) -> str: + remote = _git_out(repo_root, ["remote", "get-url", "origin"], default="") + if remote: + value = remote.rstrip("/") + value = re.sub(r"\.git$", "", value) + if ":" in value and "/" in value: + value = value.split(":", 1)[1] + else: + parts = value.split("/") + value = "/".join(parts[-2:]) if len(parts) >= 2 else parts[-1] + if value: + return value + return repo_root.name + + +def _tokens(text: str) -> List[str]: + out: List[str] = [] + for token in re.findall(r"[A-Za-z0-9_]+", str(text or "").lower()): + if len(token) < 3 or token in _STOP_WORDS: + continue + if token not in out: + out.append(token) + return out + + +def _path_tokens(path: str) -> str: + return str(path).replace("_", " ").replace("-", " ").replace("/", " ").lower() + + +def _branch_state(repo_root: Path) -> Dict[str, Any]: + status = _git_out(repo_root, ["status", "--porcelain=v1", "--untracked-files=all"], default="") + staged: List[str] = [] + unstaged: List[str] = [] + untracked: List[str] = [] + changed: List[str] = [] + for line in status.splitlines(): + if not line: + continue + code = line[:2] + path = (line[3:] if len(line) > 2 and line[2] == " " else line[2:]).strip() + if " -> " in path: + _old, path = path.split(" -> ", 1) + changed.append(path) + if code.startswith("??"): + untracked.append(path) + else: + if code[0] != " ": + staged.append(path) + if len(code) > 1 and code[1] != " ": + unstaged.append(path) + return { + "branch": _git_out(repo_root, ["branch", "--show-current"], default=""), + "head_commit": _git_out(repo_root, ["rev-parse", "--short", "HEAD"], default=""), + "dirty": bool(changed), + "staged": staged, + "unstaged": unstaged, + "untracked": untracked, + "changed_paths": sorted(set(changed)), + } + + +def _iter_repo_files(repo_root: Path, limit: int = 4_000) -> List[str]: + files: List[str] = [] + for root, dirnames, filenames in os.walk(repo_root): + dirnames[:] = [ + name + for name in dirnames + if name not in _EXCLUDED_DIRS and not name.endswith(".egg-info") + ] + for filename in filenames: + path = Path(root) / filename + try: + rel = os.path.relpath(path, repo_root).replace(os.sep, "/") + except ValueError: + continue + files.append(rel) + if len(files) >= limit: + return sorted(files) + return sorted(files) + + +def _score_file(path: str, tokens: Sequence[str]) -> int: + haystack = _path_tokens(path) + name = _path_tokens(Path(path).name) + score = 0 + for token in tokens: + if token in name: + score += 4 + elif token in haystack: + score += 2 + if path.startswith("tests/"): + score += 1 + return score + + +def _relevant_files(repo_root: Path, goal: str, branch_state: Dict[str, Any], limit: int = 12) -> List[str]: + tokens = _tokens(goal) + scored: List[Tuple[int, str]] = [] + for path in _iter_repo_files(repo_root): + score = _score_file(path, tokens) + if score > 0: + scored.append((score, path)) + for path in branch_state.get("changed_paths") or []: + if path and not str(path).startswith(".dhee/"): + scored.append((100, str(path))) + scored.sort(key=lambda item: (-item[0], item[1])) + out: List[str] = [] + for _score, path in scored: + if path not in out: + out.append(path) + if len(out) >= limit: + break + return out + + +def _affected_modules(relevant_files: Sequence[str], branch_state: Dict[str, Any]) -> List[str]: + modules: List[str] = [] + for path in list(relevant_files) + list(branch_state.get("changed_paths") or []): + if not path or str(path).startswith(".dhee/"): + continue + parts = Path(path).parts + if len(parts) > 1: + module = "/".join(parts[:2]) if parts[0] not in {"tests", "docs"} else parts[0] + else: + module = parts[0] + if module not in modules: + modules.append(module) + return modules[:12] + + +def _known_architecture(repo_root: Path) -> Dict[str, Any]: + files = set(_iter_repo_files(repo_root, limit=1_000)) + package_roots = [ + path + for path in sorted({item.split("/", 1)[0] for item in files if "/" in item}) + if (repo_root / path / "__init__.py").exists() + ] + entrypoints = [ + path + for path in ("dhee/mcp_server.py", "dhee/mcp_slim.py", "dhee/cli.py", "pyproject.toml", "package.json") + if (repo_root / path).exists() + ] + return { + "language": "python" if "pyproject.toml" in files or any(item.endswith(".py") for item in files) else "unknown", + "test_framework": "pytest" if "pytest.ini" in files or "pyproject.toml" in files or any(item.startswith("tests/test_") for item in files) else "unknown", + "package_roots": package_roots[:8], + "entrypoints": entrypoints, + } + + +def _infer_test_commands(repo_root: Path, goal: str, relevant_files: Sequence[str], must_run: Optional[Iterable[str]]) -> List[str]: + if must_run: + return [str(cmd) for cmd in must_run if str(cmd).strip()] + tokens = _tokens(goal) + files = _iter_repo_files(repo_root, limit=4_000) + commands: List[str] = [] + for path in files: + if not path.startswith("tests/") or not path.endswith(".py"): + continue + if any(token in _path_tokens(path) for token in tokens): + commands.append(f"pytest {path}") + for rel in relevant_files: + stem = Path(rel).stem + if not stem or stem == "__init__": + continue + expected = f"tests/test_{stem}.py" + if expected in files: + command = f"pytest {expected}" + if command not in commands: + commands.append(command) + if not commands: + commands.append("pytest") + return commands[:6] + + +def _default_allowed_write_paths(repo_root: Path, affected_modules: Sequence[str]) -> List[str]: + paths: List[str] = [] + for module in affected_modules: + first = module.split("/", 1)[0] + candidate = f"{first}/" if (repo_root / first).is_dir() else module + if candidate not in paths and not candidate.startswith(".dhee"): + paths.append(candidate) + if (repo_root / "tests").is_dir() and "tests/" not in paths: + paths.append("tests/") + if not paths: + if (repo_root / "dhee").is_dir(): + paths.append("dhee/") + if (repo_root / "tests").is_dir(): + paths.append("tests/") + return paths or ["."] + + +def _infer_risk(goal: str, branch_state: Dict[str, Any], recent_failures: Sequence[Dict[str, Any]]) -> str: + risky = {"auth", "security", "secret", "token", "migration", "database", "prod", "delete", "billing"} + tokens = set(_tokens(goal)) + if tokens & risky: + return "high" + if branch_state.get("dirty") or recent_failures: + return "medium" + return "low" + + +def _repo_memory_pointers(repo_root: Path, goal: str, explicit: Optional[Iterable[Dict[str, Any]]] = None, limit: int = 8) -> List[Dict[str, Any]]: + if explicit is not None: + out: List[Dict[str, Any]] = [] + for item in explicit: + if not isinstance(item, dict): + continue + enriched = dict(item) + enriched.setdefault("why_included", "Explicitly supplied to the task compiler.") + enriched.setdefault("evidence_pointer", enriched.get("ref") or enriched.get("id") or "explicit") + enriched.setdefault("token_cost", _estimate_tokens(enriched.get("title") or enriched.get("content") or "")) + enriched.setdefault("freshness", "explicit") + enriched.setdefault("confidence", 0.8) + enriched.setdefault("expected_utility", 0.7) + out.append(enriched) + if len(out) >= limit: + break + return out + tokens = _tokens(goal) + pointers: List[Dict[str, Any]] = [] + try: + entries = repo_link.list_entries(repo_root) + except Exception: + entries = [] + for entry in reversed(entries): + text = f"{entry.kind} {entry.title} {entry.content}".lower() + if tokens and not any(token in text for token in tokens): + continue + pointers.append({ + "kind": entry.kind, + "id": entry.id, + "title": entry.title, + "ref": f"repo_context:{entry.id}", + "evidence_pointer": f"repo_context:{entry.id}", + "content_hash": entry.content_hash, + "why_included": "Repo-shared context matched the compiled task tokens.", + "token_cost": _estimate_tokens(f"{entry.title}\n{entry.content}"), + "freshness": getattr(entry, "updated_at", None) or getattr(entry, "created_at", None) or "unknown", + "confidence": 0.72, + "expected_utility": 0.68, + }) + if len(pointers) >= limit: + break + return pointers + + +def _estimate_tokens(value: Any) -> int: + text = str(value or "") + if not text: + return 0 + return max(1, int(len(text) / 3.8)) + + +def _repo_brain_root(repo_root: Path) -> Path: + return repo_link.repo_context_dir(repo_root) / "repo_brain" + + +def _python_symbol_index(repo_root: Path, files: Sequence[str]) -> Tuple[List[Dict[str, Any]], Dict[str, List[str]], List[Dict[str, Any]]]: + symbols: List[Dict[str, Any]] = [] + imports: Dict[str, List[str]] = {} + call_edges: List[Dict[str, Any]] = [] + for rel in files: + if not str(rel).endswith(".py"): + continue + path = repo_root / rel + if not path.exists() or path.stat().st_size > 512_000: + continue + try: + tree = ast.parse(path.read_text(encoding="utf-8", errors="replace")) + except Exception: + continue + imported: List[str] = [] + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + symbols.append({ + "path": rel, + "name": node.name, + "kind": "class" if isinstance(node, ast.ClassDef) else "function", + "line": getattr(node, "lineno", 0), + }) + for child in ast.walk(node): + if isinstance(child, ast.Call): + target = "" + if isinstance(child.func, ast.Name): + target = child.func.id + elif isinstance(child.func, ast.Attribute): + target = child.func.attr + if target: + call_edges.append({ + "path": rel, + "caller": node.name, + "callee": target, + "line": getattr(child, "lineno", 0), + }) + elif isinstance(node, ast.Import): + imported.extend(alias.name for alias in node.names if alias.name) + elif isinstance(node, ast.ImportFrom): + module = "." * int(node.level or 0) + str(node.module or "") + if module: + imported.append(module) + if imported: + imports[rel] = sorted(set(imported)) + return symbols[:500], imports, call_edges[:1_000] + + +def _test_map(files: Sequence[str], relevant_files: Sequence[str], must_run: Sequence[str]) -> Dict[str, Any]: + tests = [path for path in files if str(path).startswith("tests/") and str(path).endswith(".py")] + source_to_tests: Dict[str, List[str]] = {} + for rel in relevant_files: + stem = Path(rel).stem + candidates = [test for test in tests if stem and stem in _path_tokens(test)] + if candidates: + source_to_tests[rel] = candidates[:8] + return { + "tests": tests[:300], + "source_to_tests": source_to_tests, + "must_run": list(must_run), + } + + +def _setup_commands(files: Sequence[str]) -> List[str]: + commands: List[str] = [] + file_set = set(files) + if "pyproject.toml" in file_set or "setup.py" in file_set: + commands.append('pip install -e ".[dev]"') + if "requirements.txt" in file_set: + commands.append("pip install -r requirements.txt") + if "package.json" in file_set: + commands.append("npm install") + return commands[:6] + + +def _risky_files(files: Sequence[str]) -> List[Dict[str, Any]]: + risky_names = ("auth", "secret", "token", "security", "migration", "payment", "billing", "prod", "config") + out: List[Dict[str, Any]] = [] + for rel in files: + lower = str(rel).lower() + reasons = [name for name in risky_names if name in lower] + if reasons: + out.append({"path": rel, "reasons": reasons}) + if len(out) >= 80: + break + return out + + +def _historical_failure_signatures(repo_root: Path, goal: str, limit: int = 20) -> List[Dict[str, Any]]: + tokens = _tokens(goal) + try: + entries = repo_link.list_entries(repo_root) + except Exception: + return [] + out: List[Dict[str, Any]] = [] + for entry in reversed(entries): + text = f"{entry.kind} {entry.title} {entry.content}".lower() + if "fail" not in text and "error" not in text and "regression" not in text: + continue + if tokens and not any(token in text for token in tokens): + continue + out.append({ + "ref": f"repo_context:{entry.id}", + "title": entry.title, + "kind": entry.kind, + "content_hash": entry.content_hash, + }) + if len(out) >= limit: + break + return out + + +def _compile_repo_intelligence( + repo_root: Path, + *, + goal: str, + relevant_files: Sequence[str], + must_run: Sequence[str], +) -> Dict[str, Any]: + branch = _branch_state(repo_root) + files = _iter_repo_files(repo_root, limit=4_000) + python_files = [path for path in files if path.endswith(".py")] + focus = list(dict.fromkeys(list(relevant_files) + python_files[:300])) + symbols, imports, call_graph = _python_symbol_index(repo_root, focus) + data = { + "schema_version": REPO_INTELLIGENCE_SCHEMA, + "repo": _repo_slug(repo_root), + "head_commit": branch.get("head_commit"), + "generated_at": _now_iso(), + "symbols": symbols, + "imports": imports, + "call_graph": call_graph, + "test_map": _test_map(files, relevant_files, must_run), + "dependency_graph": { + "python_imports": imports, + "package_roots": _known_architecture(repo_root).get("package_roots") or [], + }, + "setup_commands": _setup_commands(files), + "flaky_tests": [], + "risky_files": _risky_files(files), + "historical_failure_signatures": _historical_failure_signatures(repo_root, goal), + } + repo_link._ensure_repo_skeleton(repo_root) + path = _repo_brain_root(repo_root) / f"{branch.get('head_commit') or 'no_head'}_{_stable_hash(goal, 8)}.json" + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(_json_dumps(_sanitize_obj(data)) + "\n", encoding="utf-8") + return { + "schema_version": REPO_INTELLIGENCE_SCHEMA, + "ref": f"repo_brain:{path.name}", + "path": os.path.relpath(path, repo_root).replace(os.sep, "/"), + "head_commit": branch.get("head_commit"), + "symbol_count": len(symbols), + "import_file_count": len(imports), + "call_edge_count": len(call_graph), + "test_count": len(data["test_map"].get("tests") or []), + "risky_file_count": len(data["risky_files"]), + "historical_failure_count": len(data["historical_failure_signatures"]), + } + + +def _context_item( + *, + kind: str, + title: str, + evidence_pointer: str, + why_included: str, + token_cost: int, + freshness: str, + confidence: float, + expected_utility: float, + metadata: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + return { + "kind": kind, + "title": title, + "evidence_pointer": evidence_pointer, + "why_included": why_included, + "token_cost": int(token_cost), + "freshness": freshness, + "confidence": round(float(confidence), 3), + "expected_utility": round(float(expected_utility), 3), + "metadata": metadata or {}, + } + + +def _context_ledger(contract: Dict[str, Any], repo_intelligence: Dict[str, Any]) -> Dict[str, Any]: + items: List[Dict[str, Any]] = [] + branch = contract.get("current_branch_state") or {} + items.append(_context_item( + kind="branch_state", + title=f"Branch {branch.get('branch') or '(detached)'} at {branch.get('head_commit') or 'unknown'}", + evidence_pointer="git:branch_state", + why_included="Every supervised action must know the branch, dirty state, and rollback point.", + token_cost=_estimate_tokens(branch), + freshness=str(contract.get("created_at") or "compile_time"), + confidence=0.95, + expected_utility=0.9, + metadata={"dirty": bool(branch.get("dirty")), "changed_paths": branch.get("changed_paths") or []}, + )) + items.append(_context_item( + kind="repo_intelligence", + title="Git-SHA scoped repo brain", + evidence_pointer=str(repo_intelligence.get("ref") or ""), + why_included="Symbols, imports, test map, dependencies, risky files, and failure signatures constrain localization.", + token_cost=_estimate_tokens(repo_intelligence), + freshness=f"head:{repo_intelligence.get('head_commit') or 'unknown'}", + confidence=0.82, + expected_utility=0.86, + metadata={key: repo_intelligence.get(key) for key in ("symbol_count", "test_count", "risky_file_count", "historical_failure_count")}, + )) + for path in contract.get("relevant_files") or []: + items.append(_context_item( + kind="file", + title=str(path), + evidence_pointer=f"repo_file:{path}", + why_included="Localized by issue tokens, dirty state, or source-to-test mapping.", + token_cost=_estimate_tokens(path), + freshness=f"head:{branch.get('head_commit') or 'unknown'}", + confidence=0.74, + expected_utility=0.78, + metadata={"allowed_write": _path_under_allowed(str(path), contract.get("allowed_write_paths") or [])}, + )) + for command in contract.get("must_run") or []: + items.append(_context_item( + kind="test_command", + title=str(command), + evidence_pointer=f"command:{_stable_hash(command, 10)}", + why_included="Required verifier command compiled from the issue and nearby tests.", + token_cost=_estimate_tokens(command), + freshness="compile_time", + confidence=0.8, + expected_utility=0.88, + metadata={"must_run": True}, + )) + for pointer in contract.get("memory_pointers") or []: + items.append(_context_item( + kind=f"memory:{pointer.get('kind') or 'repo_context'}", + title=str(pointer.get("title") or pointer.get("id") or "memory pointer"), + evidence_pointer=str(pointer.get("evidence_pointer") or pointer.get("ref") or pointer.get("id") or ""), + why_included=str(pointer.get("why_included") or "Matched task tokens."), + token_cost=int(pointer.get("token_cost") or 1), + freshness=str(pointer.get("freshness") or "unknown"), + confidence=float(pointer.get("confidence") or 0.6), + expected_utility=float(pointer.get("expected_utility") or 0.5), + metadata={"content_hash": pointer.get("content_hash")}, + )) + return { + "schema_version": CONTEXT_LEDGER_SCHEMA, + "budget": contract.get("context_budget") or {}, + "total_token_cost": sum(int(item.get("token_cost") or 0) for item in items), + "items": sorted(items, key=lambda item: (-float(item.get("expected_utility") or 0), int(item.get("token_cost") or 0))), + "policy": { + "top_k_memory_injection": False, + "raw_evidence_expansion": "by_pointer_only", + "include_requires_why_and_pointer": True, + }, + } + + +def _verification_card(contract: Dict[str, Any], repo_intelligence: Dict[str, Any]) -> Dict[str, Any]: + relevant_tests = [ + path for path in contract.get("relevant_files") or [] + if str(path).startswith("tests/") and str(path).endswith(".py") + ] + nearest_tests = sorted(set(relevant_tests + (repo_intelligence.get("test_map") or {}).get("tests", [])[:6])) + smoke_targets = [ + path for path in contract.get("relevant_files") or [] + if str(path).endswith(".py") and not str(path).startswith("tests/") + ][:8] + import_smoke = [f"python3 -m py_compile {' '.join(smoke_targets)}"] if smoke_targets else [] + public_api_risk = "medium" if any(Path(str(path)).name == "__init__.py" for path in contract.get("relevant_files") or []) else "low" + risky_paths = {item.get("path") for item in (repo_intelligence.get("risky_files") or [])} + diff_risk = "high" if any(path in risky_paths for path in contract.get("relevant_files") or []) else contract.get("risk", "medium") + failure_text = json.dumps(contract.get("recent_failures") or [], sort_keys=True, default=str) + return { + "schema_version": VERIFICATION_CARD_SCHEMA, + "fail_to_pass_tests": list(contract.get("must_run") or []), + "pass_to_pass_tests": [cmd for cmd in list(contract.get("must_run") or []) if cmd not in failure_text], + "nearest_tests": nearest_tests[:12], + "import_smoke_tests": import_smoke, + "static_checks": import_smoke, + "security_checks": [ + "verify no forbidden path changed", + "verify no secret-like token introduced", + "verify benchmark contamination status is clean before submit", + ], + "diff_risk": diff_risk, + "public_api_risk": public_api_risk, + "submit_requirements": [ + "all fail_to_pass_tests observed as passed", + "edit proof obligations satisfied for every EDIT_FILE", + "contamination status is clean or explicitly quarantined", + "replay checkpoint exists before submit", + ], + } + + +def _contamination_status(goal: str, memory_pointers: Sequence[Dict[str, Any]]) -> Dict[str, Any]: + text = str(goal or "").lower() + benchmark_mode = bool(os.environ.get("DHEE_BENCHMARK_MODE")) or any(token in text for token in ("swe-bench", "swe bench", "benchmark", "eval")) + risky_refs = [ + pointer for pointer in memory_pointers + if any(marker in f"{pointer.get('title')} {pointer.get('kind')}".lower() for marker in ("gold", "solution", "hidden test", "eval")) + ] + status = "quarantined" if benchmark_mode and risky_refs else "clean" + return { + "schema_version": CONTAMINATION_STATUS_SCHEMA, + "benchmark_mode": benchmark_mode, + "status": status, + "rules": [ + "no gold patches", + "no hidden tests", + "no issue-to-solution memory", + "no prior evaluated solution recall", + "all memories carry provenance", + "eval memories are quarantined", + ], + "quarantined_refs": [pointer.get("evidence_pointer") or pointer.get("ref") for pointer in risky_refs], + } + + +def _lifecycle( + *, + precondition: str, + execution: Dict[str, Any], + observation: str, + postcondition: str, + memory_update: str, +) -> Dict[str, Any]: + return { + "precondition": precondition, + "execution": execution, + "observation": observation, + "postcondition": postcondition, + "memory_update": memory_update, + } + + +def _action(action_type: str, reason: str, lifecycle: Dict[str, Any], **payload: Any) -> Dict[str, Any]: + return { + "type": action_type, + **payload, + "reason": reason, + **lifecycle, + } + + +def _action_operands(action: Dict[str, Any]) -> Dict[str, Any]: + operands: Dict[str, Any] = {} + for field in _ACTION_OPERAND_FIELDS: + value = action.get(field) + if value not in (None, "", [], {}): + operands[field] = value + if action.get("timeout_sec") and action.get("type") == "RUN_TEST": + operands["timeout_sec"] = action.get("timeout_sec") + return operands + + +def _action_target(action: Dict[str, Any]) -> str: + operands = _action_operands(action) + for field in ("path", "command", "query", "symbol", "question", "category", "summary"): + if operands.get(field): + return str(operands[field]) + return str(action.get("type") or "") + + +def _action_id(action: Dict[str, Any], index: int) -> str: + return "act_" + _stable_hash({ + "index": index, + "type": action.get("type"), + "operands": _action_operands(action), + }, 14) + + +def _compile_action_bytecode(actions: Sequence[Dict[str, Any]], contract: Dict[str, Any]) -> List[Dict[str, Any]]: + """Lower lifecycle actions into a tiny portable bytecode graph.""" + + lowered: List[Dict[str, Any]] = [] + for index, raw in enumerate(actions, start=1): + action = dict(raw) + action_type = str(action.get("type") or "") + semantics = _ACTION_SEMANTICS.get(action_type, {}) + operands = _action_operands(action) + action_id = str(action.get("action_id") or _action_id(action, index)) + arg_hash = _stable_hash({"type": action_type, "operands": operands}, 12) + action.update({ + "step": index, + "action_id": action_id, + "phase": semantics.get("phase", "unknown"), + "capabilities": list(semantics.get("capabilities") or []), + "effects": list(semantics.get("effects") or []), + "operands": operands, + "requires": [], + "soft_requires": [], + "bytecode": { + "schema_version": ACTION_BYTECODE_SCHEMA, + "op": semantics.get("op", action_type.lower()), + "arg_hash": arg_hash, + "operands": operands, + "requires": [], + "soft_requires": [], + "emits": list(semantics.get("effects") or []), + }, + }) + lowered.append(action) + + first_search_id = next((action["action_id"] for action in lowered if action.get("type") == "SEARCH_CODE"), None) + read_by_path = { + str(action.get("path")): action["action_id"] + for action in lowered + if action.get("type") == "READ_FILE" and action.get("path") + } + run_test_ids = [ + action["action_id"] + for action in lowered + if action.get("type") == "RUN_TEST" + ] + + for action in lowered: + hard: List[str] = [] + soft: List[str] = [] + action_type = action.get("type") + if action_type == "READ_FILE" and first_search_id: + soft.append(first_search_id) + elif action_type == "EDIT_FILE": + read_id = read_by_path.get(str(action.get("path") or "")) + if read_id: + hard.append(read_id) + elif action_type == "RUN_TEST": + for path in contract.get("relevant_files") or []: + read_id = read_by_path.get(str(path)) + if read_id and read_id not in soft: + soft.append(read_id) + elif action_type == "WRITE_MEMORY_NOTE": + soft.extend(run_test_ids) + elif action_type == "SUBMIT_PATCH": + hard.extend(run_test_ids) + action["requires"] = hard + action["soft_requires"] = soft + action["bytecode"]["requires"] = hard + action["bytecode"]["soft_requires"] = soft + + return lowered + + +def _compiler_manifest(contract: Dict[str, Any], actions: Sequence[Dict[str, Any]]) -> Dict[str, Any]: + action_digest = _stable_hash([ + { + "action_id": action.get("action_id"), + "type": action.get("type"), + "operands": action.get("operands") or _action_operands(action), + "requires": action.get("requires") or [], + } + for action in actions + ], 20) + return { + "schema_version": CONTRACT_COMPILER_SCHEMA, + "compiler": "dhee.context-compiler", + "version": 1, + "deterministic": True, + "source_language": "messy_user_task+repo_state", + "target_runtime": "dhee.contract_supervisor", + "passes": list(_COMPILER_PASSES), + "artifact_hash": action_digest, + "constraints": { + "auto_execute": False, + "validate_before_execute": True, + "raw_evidence_by_pointer": True, + "personal_context_bodies_excluded": True, + }, + "stats": { + "action_count": len(actions), + "hard_dependency_edges": sum(len(action.get("requires") or []) for action in actions), + "soft_dependency_edges": sum(len(action.get("soft_requires") or []) for action in actions), + "must_run_count": len(contract.get("must_run") or contract.get("test_commands") or []), + }, + } + + +def _compile_actions(contract: Dict[str, Any]) -> List[Dict[str, Any]]: + goal = contract["goal"] + repo_root = contract["repo_root"] + relevant_files = contract.get("relevant_files") or [] + must_run = contract.get("must_run") or [] + tokens = _tokens(goal) + query = " ".join(tokens) if tokens else goal + actions: List[Dict[str, Any]] = [ + _action( + "SEARCH_CODE", + "Locate implementation and tests before reading or editing.", + _lifecycle( + precondition="Repo is available and readable.", + execution={"tool": "dhee_grep", "query": query, "scope": repo_root}, + observation="Compact match summary plus pointer to full results.", + postcondition="Relevant implementation and test files are identified or the task is marked under-specified.", + memory_update="Do not write memory yet; only note a reusable repo pattern after confirmation.", + ), + query=query, + scope=".", + ) + ] + + for path in relevant_files[:6]: + actions.append( + _action( + "READ_FILE", + "Read the concrete file before proposing edits.", + _lifecycle( + precondition=f"`{path}` exists in the target repo.", + execution={"tool": "dhee_read", "path": path}, + observation="Pointer-backed file excerpt with line references.", + postcondition="Relevant symbols, invariants, and edit boundaries are known.", + memory_update="No memory write unless the file reveals a reusable architecture rule.", + ), + path=path, + ) + ) + + if not relevant_files: + actions.append( + _action( + "ASK_USER", + "The compiler could not identify a bounded file/module scope.", + _lifecycle( + precondition="Search produced no strong repo-local targets.", + execution={"prompt": "Ask for the failing command, file, or error message."}, + observation="User supplies missing scope or confirms broad investigation.", + postcondition="Task can be recompiled with concrete targets.", + memory_update="Store nothing unless the missing-scope pattern repeats.", + ), + question="Which failing command, file, or error should Dhee target first?", + blocking=True, + ) + ) + + for command in must_run[:6]: + actions.append( + _action( + "RUN_TEST", + "Classify current failure before and after edits.", + _lifecycle( + precondition="Dependency environment exists and command is safe for the sandbox.", + execution={"tool": "dhee_bash", "command": command, "timeout_sec": 120}, + observation="Compact failing stacktrace and pointer to full log.", + postcondition="Failure is classified as target failure, unrelated failure, or environment failure.", + memory_update="Store failure signature if it is reusable across future tasks.", + ), + command=command, + timeout_sec=120, + ) + ) + + actions.append( + _action( + "WRITE_MEMORY_NOTE", + "Capture reusable lessons without bloating the active prompt.", + _lifecycle( + precondition="A repeated failure pattern, architectural invariant, or repo workflow was confirmed.", + execution={"category": "failure_pattern"}, + observation="Short note with evidence pointers, not raw logs.", + postcondition="Future contracts can retrieve the lesson by pointer.", + memory_update="Write compact lesson only after verification.", + ), + category="failure_pattern", + content="If this task reveals a reusable failure signature, store the minimal signature plus test command and fix boundary.", + ) + ) + actions.append( + _action( + "SUBMIT_PATCH", + "Finish only after contract success criteria are satisfied.", + _lifecycle( + precondition="Edits are complete, tests have run, and unrelated diffs were avoided.", + execution={"summary": "Summarize changed behavior and tests run."}, + observation="Reviewable patch summary with test results.", + postcondition="User can review or ship the patch.", + memory_update="Checkpoint decisions, files touched, and any reusable lesson pointers.", + ), + summary="Submit a scoped patch for the compiled task contract.", + tests=must_run, + ) + ) + return actions + + +def _task_contract_root(repo_root: Path) -> Path: + return repo_link.repo_context_dir(repo_root) / "task_contracts" + + +def _safe_repo_path(repo_root: Path, rel_path: str) -> Optional[Path]: + raw = Path(str(rel_path or "")) + if not rel_path or raw.is_absolute() or ".." in raw.parts: + return None + try: + root = repo_root.resolve() + path = (root / raw).resolve() + if os.path.commonpath([str(root), str(path)]) != str(root): + return None + return path + except (OSError, ValueError): + return None + + +def _is_forbidden_path(path: str, forbidden_paths: Iterable[str]) -> bool: + normalized = str(path or "").replace("\\", "/").lstrip("./") + for pattern in forbidden_paths or []: + item = str(pattern or "").replace("\\", "/").lstrip("./") + if not item: + continue + if item.endswith("/"): + if normalized.startswith(item): + return True + continue + if item.endswith(".*"): + prefix = item[:-1] + if normalized.startswith(prefix): + return True + continue + if normalized == item or normalized.startswith(item.rstrip("/") + "/"): + return True + return False + + +def _path_under_allowed(path: str, allowed_paths: Iterable[str]) -> bool: + normalized = str(path or "").replace("\\", "/").lstrip("./") + allowed = [str(item or "").replace("\\", "/").lstrip("./") for item in allowed_paths or []] + if not allowed or "." in allowed: + return True + for item in allowed: + if not item: + continue + if item.endswith("/"): + if normalized.startswith(item): + return True + elif normalized == item or normalized.startswith(item.rstrip("/") + "/"): + return True + return False + + +def _command_is_safe(command: str) -> bool: + text = str(command or "").strip() + if not text: + return False + lowered = text.lower() + dangerous = ("rm -rf", "git reset --hard", "git checkout --", "sudo ", "curl | sh", "chmod 777") + if any(item in lowered for item in dangerous): + return False + safe_prefixes = ("pytest", "python -m pytest", "python3 -m pytest", "npm test", "npm run test", "pnpm test", "uv run pytest") + return lowered.startswith(safe_prefixes) + + +def _resolve_contract(compiled_or_contract: Dict[str, Any]) -> Dict[str, Any]: + if "contract" in compiled_or_contract and "actions" in compiled_or_contract: + compiled = dict(compiled_or_contract) + contract = compiled.get("contract") or {} + actions = compiled.get("actions") or [] + if actions and any(not action.get("action_id") or not action.get("bytecode") for action in actions if isinstance(action, dict)): + compiled["actions"] = _compile_action_bytecode(actions, contract) + compiled["actions_schema"] = ACTION_BYTECODE_SCHEMA + if compiled.get("actions") and not compiled.get("compiler"): + compiled["compiler"] = _compiler_manifest(contract, compiled.get("actions") or []) + return compiled + if compiled_or_contract.get("schema_version") == TASK_CONTRACT_SCHEMA: + wrapper = { + "format": "dhee_task_contract_compile.v1", + "contract": compiled_or_contract, + "actions_schema": ACTION_BYTECODE_SCHEMA, + "actions": compiled_or_contract.get("actions") or [], + } + wrapper["actions"] = _compile_action_bytecode(wrapper["actions"], compiled_or_contract) if wrapper["actions"] else [] + if wrapper["actions"]: + wrapper["compiler"] = _compiler_manifest(compiled_or_contract, wrapper["actions"]) + wrapper["validation"] = validate_task_contract(wrapper) + return wrapper + return compiled_or_contract + + +def compile_task_contract( + goal: str, + *, + repo: str | os.PathLike[str] | None = None, + mode: str = "patch", + risk: Optional[str] = None, + allowed_write_paths: Optional[Iterable[str]] = None, + forbidden_paths: Optional[Iterable[str]] = None, + must_run: Optional[Iterable[str]] = None, + success_criteria: Optional[Iterable[str]] = None, + context_budget: Optional[Dict[str, int]] = None, + memory_pointers: Optional[Iterable[Dict[str, Any]]] = None, + recent_failures: Optional[Iterable[Dict[str, Any]]] = None, +) -> Dict[str, Any]: + """Compile a messy task request into a deterministic action contract.""" + + if not str(goal or "").strip(): + raise ValueError("goal is required") + repo_root = _resolve_repo_root(repo) + branch_state = _branch_state(repo_root) + relevant_files = _relevant_files(repo_root, goal, branch_state) + affected_modules = _affected_modules(relevant_files, branch_state) + failures = [dict(item) for item in (recent_failures or []) if isinstance(item, dict)] + contract = { + "schema_version": TASK_CONTRACT_SCHEMA, + "task_id": "task_" + datetime.now(timezone.utc).strftime("%Y_%m_%d_") + _stable_hash({ + "goal": goal, + "repo": str(repo_root), + "head": branch_state.get("head_commit"), + }, 8), + "created_at": _now_iso(), + "goal": str(goal).strip(), + "repo": _repo_slug(repo_root), + "repo_root": str(repo_root), + "mode": mode or "patch", + "risk": risk or _infer_risk(goal, branch_state, failures), + "affected_modules": affected_modules, + "known_architecture": _known_architecture(repo_root), + "recent_failures": failures, + "test_commands": _infer_test_commands(repo_root, goal, relevant_files, must_run), + "relevant_files": relevant_files, + "allowed_write_paths": list(allowed_write_paths or _default_allowed_write_paths(repo_root, affected_modules)), + "forbidden_paths": list(forbidden_paths or DEFAULT_FORBIDDEN_PATHS), + "forbidden_actions": DEFAULT_FORBIDDEN_ACTIONS, + "success_criteria": list(success_criteria or DEFAULT_SUCCESS_CRITERIA), + "rollback_plan": [ + "Review `git diff --stat` before and after edits.", + "Keep edits inside allowed_write_paths.", + "If tests regress outside target scope, stop and report the failing command.", + ], + "memory_pointers": _repo_memory_pointers(repo_root, goal, memory_pointers), + "current_branch_state": branch_state, + "context_budget": dict(context_budget or DEFAULT_CONTEXT_BUDGET), + } + contract["must_run"] = contract["test_commands"] + repo_intelligence = _compile_repo_intelligence( + repo_root, + goal=contract["goal"], + relevant_files=contract["relevant_files"], + must_run=contract["must_run"], + ) + contract["repo_intelligence"] = repo_intelligence + contract["compiled_context"] = _context_ledger(contract, repo_intelligence) + contract["verification_card"] = _verification_card(contract, repo_intelligence) + contract["contamination_status"] = _contamination_status(contract["goal"], contract["memory_pointers"]) + contract["patch_families"] = [ + { + "name": "minimal_fix", + "intent": "Smallest behavior change that satisfies fail-to-pass tests.", + "risk": "low", + "requires_isolated_worktree": False, + }, + { + "name": "semantic_fix", + "intent": "Correct the underlying invariant while preserving public API behavior.", + "risk": "medium", + "requires_isolated_worktree": False, + }, + { + "name": "edge_case_fix", + "intent": "Address boundary conditions revealed by nearby tests or failure signatures.", + "risk": "medium", + "requires_isolated_worktree": False, + }, + { + "name": "regression_safe_fix", + "intent": "Prefer broader pass-to-pass verification before submit.", + "risk": "medium", + "requires_isolated_worktree": True, + }, + { + "name": "alternative_hypothesis", + "intent": "Branch and test a competing localization only when evidence is weak.", + "risk": "high", + "requires_isolated_worktree": True, + }, + ] + contract["edit_proof_obligations"] = [ + "file was read", + "edit span localized", + "invariant stated", + "related test selected", + "rollback point exists", + ] + contract["replay_plan"] = { + "checkpoints": ["after_localization", "before_edit", "after_failing_test", "before_submit"], + "storage": ".dhee/context/task_runs//checkpoints/", + "failed_attempts_are_assets": True, + } + contract["memory_policy"] = { + "generic_memory_injection": False, + "survived_lessons_only": True, + "raw_logs_by_pointer_only": True, + "skill_promotion_requires_ab_test": True, + } + actions = _compile_action_bytecode(_compile_actions(contract), contract) + validation = validate_task_contract({"contract": contract, "actions": actions}) + return { + "format": "dhee_task_contract_compile.v1", + "contract": contract, + "compiler": _compiler_manifest(contract, actions), + "actions_schema": ACTION_BYTECODE_SCHEMA, + "actions": actions, + "validation": validation, + } + + +def _write_task_contract(compiled: Dict[str, Any], out_dir: Path) -> Dict[str, str]: + out_dir.mkdir(parents=True, exist_ok=True) + data = _sanitize_obj(compiled) + md = _sanitize_text(render_task_contract(data)) + json_path = out_dir / "contract.json" + md_path = out_dir / "contract.md" + json_path.write_text(_json_dumps(data) + "\n", encoding="utf-8") + md_path.write_text(md, encoding="utf-8") + return {"json": str(json_path), "markdown": str(md_path), "dir": str(out_dir)} + + +def create_task_contract( + goal: str, + *, + repo: str | os.PathLike[str] | None = None, + out: str | os.PathLike[str] | None = None, + mode: str = "patch", + risk: Optional[str] = None, + allowed_write_paths: Optional[Iterable[str]] = None, + forbidden_paths: Optional[Iterable[str]] = None, + must_run: Optional[Iterable[str]] = None, + success_criteria: Optional[Iterable[str]] = None, + context_budget: Optional[Dict[str, int]] = None, + memory_pointers: Optional[Iterable[Dict[str, Any]]] = None, + recent_failures: Optional[Iterable[Dict[str, Any]]] = None, +) -> Dict[str, Any]: + """Compile and persist a portable task contract under .dhee/context.""" + + repo_root = _resolve_repo_root(repo) + repo_link._ensure_repo_skeleton(repo_root) + compiled = compile_task_contract( + goal, + repo=repo_root, + mode=mode, + risk=risk, + allowed_write_paths=allowed_write_paths, + forbidden_paths=forbidden_paths, + must_run=must_run, + success_criteria=success_criteria, + context_budget=context_budget, + memory_pointers=memory_pointers, + recent_failures=recent_failures, + ) + compiled = _sanitize_obj(compiled) + task_id = str((compiled.get("contract") or {}).get("task_id") or ("task_" + _stable_hash(compiled, 16))) + target_dir = Path(out).expanduser().resolve() if out else _task_contract_root(repo_root) / task_id + paths = _write_task_contract(compiled, target_dir) + md = Path(paths["markdown"]).read_text(encoding="utf-8") + rel_dir = os.path.relpath(target_dir, repo_root) if str(target_dir).startswith(str(repo_root)) else str(target_dir) + entry = repo_link.add_entry( + repo_root, + kind=TASK_CONTRACT_KIND, + title=f"Task contract {task_id}", + content=md, + meta={ + "task_id": task_id, + "contract_dir": rel_dir, + "goal": (compiled.get("contract") or {}).get("goal"), + "mode": (compiled.get("contract") or {}).get("mode"), + "risk": (compiled.get("contract") or {}).get("risk"), + "must_run": (compiled.get("contract") or {}).get("must_run") or [], + "portable": True, + }, + ) + return { + "format": "dhee_task_contract_create.v1", + "contract": compiled["contract"], + "compiler": compiled.get("compiler"), + "actions_schema": compiled.get("actions_schema"), + "actions": compiled.get("actions") or [], + "validation": compiled.get("validation") or validate_task_contract(compiled), + "paths": paths, + "entry": entry.to_json(), + } + + +def list_task_contracts(*, repo: str | os.PathLike[str] | None = None) -> List[Dict[str, Any]]: + repo_root = _resolve_repo_root(repo) + root = _task_contract_root(repo_root) + if not root.exists(): + return [] + out: List[Dict[str, Any]] = [] + for json_path in sorted(root.glob("*/contract.json")): + try: + data = json.loads(json_path.read_text(encoding="utf-8")) + except Exception: + continue + contract = data.get("contract") or {} + out.append({ + "task_id": contract.get("task_id"), + "goal": contract.get("goal"), + "mode": contract.get("mode"), + "risk": contract.get("risk"), + "must_run": contract.get("must_run") or [], + "path": str(json_path.parent), + "created_at": contract.get("created_at"), + }) + return out + + +def get_task_contract( + task_id: str, + *, + repo: str | os.PathLike[str] | None = None, +) -> Dict[str, Any]: + repo_root = _resolve_repo_root(repo) + root = _task_contract_root(repo_root) + matches = [ + path + for path in root.glob("*/contract.json") + if path.parent.name == task_id or path.parent.name.startswith(str(task_id)) + ] + if not matches: + raise FileNotFoundError(f"Task contract {task_id!r} not found") + json_path = matches[0] + md_path = json_path.with_name("contract.md") + data = json.loads(json_path.read_text(encoding="utf-8")) + return { + "format": "dhee_task_contract_get.v1", + "compiled": data, + "contract": (data.get("contract") or {}), + "compiler": data.get("compiler"), + "actions": data.get("actions") or [], + "markdown": md_path.read_text(encoding="utf-8") if md_path.exists() else render_task_contract(data), + "paths": {"json": str(json_path), "markdown": str(md_path), "dir": str(json_path.parent)}, + } + + +def _read_contract_source(path: str | os.PathLike[str]) -> Tuple[Dict[str, Any], str, Path]: + source = Path(path).expanduser().resolve() + if source.is_dir(): + json_path = source / "contract.json" + md_path = source / "contract.md" + source_dir = source + elif source.suffix == ".json": + json_path = source + md_path = source.with_name("contract.md") + source_dir = source.parent + elif source.suffix == ".md": + json_path = source.with_name("contract.json") + md_path = source + source_dir = source.parent + else: + raise ValueError("Import path must be a task contract directory, contract.json, or contract.md") + if not json_path.exists(): + raise FileNotFoundError(f"Missing contract.json near {source}") + data = json.loads(json_path.read_text(encoding="utf-8")) + md = md_path.read_text(encoding="utf-8") if md_path.exists() else render_task_contract(data) + return data, md, source_dir + + +def import_task_contract( + path: str | os.PathLike[str], + *, + repo: str | os.PathLike[str] | None = None, +) -> Dict[str, Any]: + repo_root = _resolve_repo_root(repo) + repo_link._ensure_repo_skeleton(repo_root) + data, md, _source_dir = _read_contract_source(path) + data = _sanitize_obj(_resolve_contract(data)) + validation = validate_task_contract(data) + if not validation["ok"]: + codes = ", ".join(str(item.get("code")) for item in validation["diagnostics"] if item.get("level") == "error") + raise ValueError(f"Task contract import rejected: invalid contract ({codes or 'validation failed'})") + data["validation"] = validation + contract = data.get("contract") or {} + task_id = str(contract.get("task_id") or ("task_" + _stable_hash(data, 16))) + contract["task_id"] = task_id + data["contract"] = contract + dest = _task_contract_root(repo_root) / task_id + paths = _write_task_contract(data, dest) + entry = repo_link.add_entry( + repo_root, + kind=TASK_CONTRACT_KIND, + title=f"Imported task contract {task_id}", + content=_sanitize_text(md), + meta={ + "task_id": task_id, + "contract_dir": os.path.relpath(dest, repo_root), + "goal": contract.get("goal"), + "imported": True, + "portable": True, + }, + ) + return { + "format": "dhee_task_contract_import.v1", + "contract": contract, + "compiler": data.get("compiler"), + "actions": data.get("actions") or [], + "validation": validation, + "paths": paths, + "entry": entry.to_json(), + } + + +def _load_task_contract( + task_contract: str | os.PathLike[str] | Dict[str, Any], + *, + repo: str | os.PathLike[str] | None = None, +) -> Dict[str, Any]: + if isinstance(task_contract, dict): + return _resolve_contract(task_contract) + value = str(task_contract) + source = Path(value).expanduser() + if source.exists(): + data, _md, _source_dir = _read_contract_source(source) + return _resolve_contract(data) + return get_task_contract(value, repo=repo)["compiled"] + + +def _action_state(repo_root: Path, contract: Dict[str, Any], action: Dict[str, Any]) -> Dict[str, Any]: + action_type = str(action.get("type") or "") + allowed_paths = contract.get("allowed_write_paths") or [] + forbidden_paths = contract.get("forbidden_paths") or [] + diagnostics: List[Dict[str, Any]] = [] + state = "ready" + target = action.get("path") or action.get("command") or action.get("query") or action.get("category") or action.get("summary") + + if action_type == "READ_FILE": + path = str(action.get("path") or "") + resolved = _safe_repo_path(repo_root, path) + if resolved is None: + state = "blocked" + diagnostics.append({"level": "error", "code": "UNSAFE_READ_PATH", "path": path, "message": "READ_FILE path is absolute or escapes the repo."}) + elif not resolved.exists(): + state = "blocked" + diagnostics.append({"level": "error", "code": "READ_PATH_MISSING", "path": path, "message": "Required read target is missing in this checkout."}) + elif _is_forbidden_path(path, forbidden_paths): + state = "blocked" + diagnostics.append({"level": "error", "code": "READ_PATH_FORBIDDEN", "path": path, "message": "Action targets a forbidden path."}) + elif action_type == "EDIT_FILE": + path = str(action.get("path") or "") + resolved = _safe_repo_path(repo_root, path) + if resolved is None: + state = "blocked" + diagnostics.append({"level": "error", "code": "UNSAFE_EDIT_PATH", "path": path, "message": "EDIT_FILE path is absolute or escapes the repo."}) + elif _is_forbidden_path(path, forbidden_paths): + state = "blocked" + diagnostics.append({"level": "error", "code": "EDIT_PATH_FORBIDDEN", "path": path, "message": "Action targets a forbidden path."}) + elif not _path_under_allowed(path, allowed_paths): + state = "blocked" + diagnostics.append({"level": "error", "code": "EDIT_PATH_OUTSIDE_ALLOWED", "path": path, "message": "Action is outside allowed_write_paths."}) + elif action_type == "RUN_TEST": + command = str(action.get("command") or "") + if not _command_is_safe(command): + state = "blocked" + diagnostics.append({"level": "error", "code": "UNSAFE_TEST_COMMAND", "command": command, "message": "RUN_TEST command is empty or outside the safe test command allowlist."}) + elif action_type == "ASK_USER" and action.get("blocking"): + state = "needs_input" + diagnostics.append({"level": "warning", "code": "BLOCKING_USER_INPUT", "message": "Action requires user input before execution."}) + elif action_type in {"WRITE_MEMORY_NOTE", "SUBMIT_PATCH"}: + state = "deferred" + elif action_type == "SEARCH_CODE": + if not str(action.get("query") or "").strip(): + state = "blocked" + diagnostics.append({"level": "error", "code": "EMPTY_SEARCH_QUERY", "message": "SEARCH_CODE action has no query."}) + + return { + "action_id": action.get("action_id"), + "type": action_type, + "phase": action.get("phase"), + "requires": action.get("requires") or [], + "soft_requires": action.get("soft_requires") or [], + "capabilities": action.get("capabilities") or [], + "effects": action.get("effects") or [], + "target": target, + "state": state, + "diagnostics": diagnostics, + "precondition": action.get("precondition"), + "postcondition": action.get("postcondition"), + } + + +def interpret_task_contract( + task_contract: str | os.PathLike[str] | Dict[str, Any], + *, + repo: str | os.PathLike[str] | None = None, + strict: bool = False, +) -> Dict[str, Any]: + """Interpret a compiled task contract on the receiving machine.""" + + repo_root = _resolve_repo_root(repo) + compiled = _sanitize_obj(_load_task_contract(task_contract, repo=repo_root)) + validation = validate_task_contract(compiled) + contract = compiled.get("contract") or {} + actions = compiled.get("actions") or [] + diagnostics = list(validation.get("diagnostics") or []) + current_repo = _repo_slug(repo_root) + if contract.get("repo") and current_repo != contract.get("repo"): + diagnostics.append({ + "level": "error" if strict else "warning", + "code": "REPO_ID_MISMATCH", + "message": "Compiled contract repo id differs from this checkout.", + "compiled_repo": contract.get("repo"), + "target_repo": current_repo, + }) + current_branch_state = _branch_state(repo_root) + if current_branch_state.get("dirty"): + diagnostics.append({ + "level": "warning", + "code": "TARGET_WORKTREE_DIRTY", + "message": "Target worktree is dirty; receiving agent must avoid mixing unrelated edits.", + "changed_paths": current_branch_state.get("changed_paths") or [], + }) + for path in contract.get("allowed_write_paths") or []: + if _is_forbidden_path(str(path), contract.get("forbidden_paths") or []): + diagnostics.append({ + "level": "error", + "code": "ALLOWED_PATH_FORBIDDEN", + "path": path, + "message": "Contract has an allowed_write_path that overlaps forbidden_paths.", + }) + + action_states = [_action_state(repo_root, contract, action) for action in actions] + for state in action_states: + diagnostics.extend(state.get("diagnostics") or []) + + states = {state.get("state") for state in action_states} + if not validation["ok"] or any(item.get("level") == "error" for item in diagnostics): + readiness = "blocked" + elif "needs_input" in states: + readiness = "needs_input" + elif states and states <= {"deferred"}: + readiness = "deferred" + else: + readiness = "ready" + + execution_plan: List[Dict[str, Any]] = [] + for index, action in enumerate(actions, start=1): + state = action_states[index - 1] if index - 1 < len(action_states) else {} + execution_plan.append({ + "step": index, + "action_id": action.get("action_id"), + "type": action.get("type"), + "phase": action.get("phase"), + "state": state.get("state"), + "target": state.get("target"), + "requires": action.get("requires") or [], + "soft_requires": action.get("soft_requires") or [], + "capabilities": action.get("capabilities") or [], + "effects": action.get("effects") or [], + "reason": action.get("reason"), + "precondition": action.get("precondition"), + "execution": action.get("execution"), + "observation": action.get("observation"), + "postcondition": action.get("postcondition"), + "memory_update": action.get("memory_update"), + }) + + return { + "format": TASK_INTERPRETATION_SCHEMA, + "repo": str(repo_root), + "compiled_repo": contract.get("repo"), + "target_repo": current_repo, + "task_id": contract.get("task_id"), + "goal": contract.get("goal"), + "readiness": readiness, + "validation": validation, + "current_branch_state": current_branch_state, + "action_states": action_states, + "execution_plan": execution_plan, + "diagnostics": diagnostics, + "policy": { + "auto_execute": False, + "requires_agent_tool_execution": True, + "strict": bool(strict), + }, + } + + +def validate_task_contract(compiled: Dict[str, Any]) -> Dict[str, Any]: + diagnostics: List[Dict[str, Any]] = [] + contract = compiled.get("contract") if isinstance(compiled, dict) else None + actions = compiled.get("actions") if isinstance(compiled, dict) else None + if not isinstance(contract, dict): + diagnostics.append({"level": "error", "code": "MISSING_CONTRACT", "message": "Compiled task is missing contract."}) + contract = {} + for field in ( + "task_id", + "goal", + "repo", + "mode", + "risk", + "allowed_write_paths", + "forbidden_paths", + "must_run", + "success_criteria", + "context_budget", + ): + if field not in contract: + diagnostics.append({"level": "error", "code": "MISSING_CONTRACT_FIELD", "field": field, "message": f"Task contract missing {field}."}) + if not isinstance(actions, list) or not actions: + diagnostics.append({"level": "error", "code": "MISSING_ACTIONS", "message": "Compiled task needs at least one typed action."}) + actions = [] + action_ids = {str(action.get("action_id")) for action in actions if isinstance(action, dict) and action.get("action_id")} + for index, action in enumerate(actions): + if not isinstance(action, dict): + diagnostics.append({"level": "error", "code": "INVALID_ACTION", "index": index, "message": "Action must be an object."}) + continue + action_type = action.get("type") + if action_type not in ACTION_TYPES: + diagnostics.append({"level": "error", "code": "UNKNOWN_ACTION_TYPE", "index": index, "message": f"Unknown action type {action_type!r}."}) + if not action.get("action_id"): + diagnostics.append({"level": "warning", "code": "MISSING_ACTION_ID", "index": index, "message": "Action has no stable action_id; import will lower legacy actions when possible."}) + if not action.get("bytecode"): + diagnostics.append({"level": "warning", "code": "MISSING_ACTION_BYTECODE", "index": index, "message": "Action has no bytecode metadata; runtime enforcement will fall back to structural matching."}) + for dep in action.get("requires") or []: + if not isinstance(dep, str) or not dep.startswith("act_"): + diagnostics.append({"level": "error", "code": "INVALID_ACTION_DEPENDENCY", "index": index, "dependency": dep, "message": "Hard dependency must reference a stable action_id."}) + elif dep not in action_ids: + diagnostics.append({"level": "error", "code": "UNKNOWN_ACTION_DEPENDENCY", "index": index, "dependency": dep, "message": "Hard dependency does not reference a compiled action."}) + for field in ("precondition", "execution", "observation", "postcondition", "memory_update"): + if field not in action: + diagnostics.append({"level": "error", "code": "MISSING_ACTION_LIFECYCLE", "index": index, "field": field, "message": f"Action missing {field}."}) + if action_type == "RUN_TEST": + if not action.get("command") or not action.get("timeout_sec"): + diagnostics.append({"level": "error", "code": "INVALID_RUN_TEST_ACTION", "index": index, "message": "RUN_TEST requires command and timeout_sec."}) + if action_type == "READ_FILE" and not action.get("path"): + diagnostics.append({"level": "error", "code": "INVALID_READ_FILE_ACTION", "index": index, "message": "READ_FILE requires path."}) + if action_type == "SEARCH_CODE" and not action.get("query"): + diagnostics.append({"level": "error", "code": "INVALID_SEARCH_CODE_ACTION", "index": index, "message": "SEARCH_CODE requires query."}) + return { + "ok": not any(item.get("level") == "error" for item in diagnostics), + "diagnostics": diagnostics, + "action_count": len(actions), + } + + +def render_task_contract(compiled: Dict[str, Any]) -> str: + contract = compiled.get("contract") or {} + lines = [ + f"# Task Contract: {contract.get('task_id') or '(unknown)'}", + "", + f"- Goal: {contract.get('goal') or ''}", + f"- Repo: `{contract.get('repo') or ''}`", + f"- Mode: `{contract.get('mode') or ''}`", + f"- Risk: `{contract.get('risk') or ''}`", + f"- Allowed writes: {', '.join(f'`{item}`' for item in contract.get('allowed_write_paths') or []) or '(none)'}", + f"- Forbidden paths: {', '.join(f'`{item}`' for item in contract.get('forbidden_paths') or []) or '(none)'}", + "", + "## Must Run", + ] + lines.extend(f"- `{cmd}`" for cmd in contract.get("must_run") or ["pytest"]) + compiler = compiled.get("compiler") or {} + if compiler: + lines.extend([ + "", + "## Compiler", + f"- Schema: `{compiler.get('schema_version') or ''}`", + f"- Target runtime: `{compiler.get('target_runtime') or ''}`", + f"- Artifact hash: `{compiler.get('artifact_hash') or ''}`", + ]) + lines.extend(["", "## Typed Actions"]) + for index, action in enumerate(compiled.get("actions") or [], start=1): + subject = _action_target(action) + requires = ", ".join(f"`{item}`" for item in action.get("requires") or []) or "(none)" + lines.append(f"{index}. `{action.get('type')}` `{action.get('action_id') or ''}` {subject}") + lines.append(f" - Phase: {action.get('phase') or 'unknown'}") + lines.append(f" - Requires: {requires}") + lines.append(f" - Precondition: {action.get('precondition')}") + lines.append(f" - Observation: {action.get('observation')}") + lines.append(f" - Postcondition: {action.get('postcondition')}") + return "\n".join(lines).strip() + "\n" diff --git a/dhee/temporal_scenes.py b/dhee/temporal_scenes.py new file mode 100644 index 0000000..1545f94 --- /dev/null +++ b/dhee/temporal_scenes.py @@ -0,0 +1,1018 @@ +"""Deterministic temporal scene cards and bounded context packs. + +Temporal scenes are Dhee's compact layer over noisy evidence. They keep +provenance and searchable derivatives close at hand, while raw screenshots, +transcripts, media, and long memory bodies stay behind pointers. +""" + +from __future__ import annotations + +import dataclasses +import hashlib +import json +import os +import re +from collections import Counter +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Sequence + + +SCENE_SCHEMA_VERSION = 1 +_MAX_SNIPPET_CHARS = 280 +_TOKEN_RE = re.compile(r"[A-Za-z][A-Za-z0-9_+-]{2,}") +_NOISE_WORDS = { + "about", "after", "agent", "also", "and", "are", "because", "been", + "before", "being", "build", "can", "codex", "context", "dhee", "for", + "from", "has", "have", "into", "its", "memory", "more", "not", "now", + "only", "repo", "should", "that", "the", "their", "then", "there", + "this", "use", "used", "user", "was", "when", "with", "work", "will", + "you", "your", +} +_GEM_TERMS = { + "adapter", "api", "baseline", "behavior", "bug", "capsule", "change", + "compatibility", "decision", "diff", "failure", "fix", "interface", + "lesson", "migration", "privacy", "regression", "reproduce", "risk", + "scene", "secret", "test", "token", "update", +} + + +def _now_iso() -> str: + return datetime.now(timezone.utc).replace(microsecond=0).isoformat() + + +def _clip(text: Any, limit: int) -> str: + value = str(text or "").strip() + value = re.sub(r"\s+", " ", value) + if len(value) <= limit: + return value + return value[: max(0, limit - 1)].rstrip() + "…" + + +def _stable_hash(payload: Any, length: int = 16) -> str: + raw = json.dumps(payload, sort_keys=True, default=str, separators=(",", ":")) + return hashlib.sha256(raw.encode("utf-8")).hexdigest()[:length] + + +def _safe_user_key(user_id: str) -> str: + return re.sub(r"[^A-Za-z0-9_.-]+", "_", user_id or "default")[:80] or "default" + + +def _estimate_tokens(text: str) -> int: + return max(1, (len(text or "") + 3) // 4) + + +def _tokens(text: str) -> List[str]: + out: List[str] = [] + for match in _TOKEN_RE.findall(text or ""): + token = match.lower() + if token not in _NOISE_WORDS and len(token) > 2: + out.append(token) + return out + + +def _first_text(raw: Dict[str, Any]) -> str: + for key in ( + "memory", "content", "body", "text", "summary", "digest", + "observation", "title", "message", + ): + value = raw.get(key) + if isinstance(value, str) and value.strip(): + return value + metadata = raw.get("metadata") or raw.get("meta") or {} + if isinstance(metadata, dict): + for key in ("text", "summary", "title", "url", "path"): + value = metadata.get(key) + if isinstance(value, str) and value.strip(): + return value + return "" + + +def _infer_modality(raw: Dict[str, Any], text: str) -> str: + metadata = raw.get("metadata") or raw.get("meta") or {} + if isinstance(metadata, dict): + for key in ("modality", "media_type", "source_type"): + value = metadata.get(key) + if value: + return str(value) + source_type = str(raw.get("source_type") or raw.get("memory_type") or raw.get("kind") or "").lower() + for candidate in ("video", "audio", "image", "ocr", "dom", "web", "screen", "artifact", "transcript"): + if candidate in source_type: + return candidate + if " List[str]: + values: List[str] = [] + for key in ("categories", "tags"): + item = raw.get(key) + if isinstance(item, list): + values.extend(str(v) for v in item if v) + elif isinstance(item, str): + values.extend(part.strip() for part in item.split(",") if part.strip()) + metadata = raw.get("metadata") or raw.get("meta") or {} + if isinstance(metadata, dict): + item = metadata.get("categories") or metadata.get("tags") + if isinstance(item, list): + values.extend(str(v) for v in item if v) + elif isinstance(item, str): + values.extend(part.strip() for part in item.split(",") if part.strip()) + return values + + +@dataclass +class EvidencePointer: + """Compact pointer to evidence plus a small searchable derivative.""" + + kind: str + ref: str + label: str = "" + modality: str = "text" + user_id: str = "default" + agent_id: str = "" + source_app: str = "" + source_event_id: str = "" + run_id: str = "" + memory_type: str = "" + confidentiality_scope: str = "personal" + uri: str = "" + snippet: str = "" + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self, include_snippet: bool = True, include_private_uri: bool = True) -> Dict[str, Any]: + data = { + "kind": self.kind, + "ref": self.ref, + "label": self.label, + "modality": self.modality, + "user_id": self.user_id, + "agent_id": self.agent_id, + "source_app": self.source_app, + "source_event_id": self.source_event_id, + "run_id": self.run_id, + "memory_type": self.memory_type, + "confidentiality_scope": self.confidentiality_scope, + "metadata": dict(self.metadata or {}), + } + if include_private_uri and self.uri: + data["uri"] = self.uri + if include_snippet and self.snippet: + data["snippet"] = self.snippet + return data + + @classmethod + def from_dict(cls, raw: Dict[str, Any]) -> "EvidencePointer": + return cls( + kind=str(raw.get("kind") or raw.get("source_type") or "evidence"), + ref=str(raw.get("ref") or raw.get("id") or raw.get("memory_id") or _stable_hash(raw)), + label=str(raw.get("label") or raw.get("title") or ""), + modality=str(raw.get("modality") or "text"), + user_id=str(raw.get("user_id") or "default"), + agent_id=str(raw.get("agent_id") or ""), + source_app=str(raw.get("source_app") or ""), + source_event_id=str(raw.get("source_event_id") or ""), + run_id=str(raw.get("run_id") or ""), + memory_type=str(raw.get("memory_type") or ""), + confidentiality_scope=str(raw.get("confidentiality_scope") or "personal"), + uri=str(raw.get("uri") or ""), + snippet=str(raw.get("snippet") or ""), + metadata=dict(raw.get("metadata") or {}), + ) + + +@dataclass +class TemporalScene: + """A private compact scene card compiled from many noisy evidence events.""" + + id: str + title: str + summary: str + topic: str = "" + user_goal: str = "" + action: str = "" + outcome: str = "" + lesson: str = "" + entities: List[str] = field(default_factory=list) + tags: List[str] = field(default_factory=list) + modalities: List[str] = field(default_factory=list) + repo_refs: List[str] = field(default_factory=list) + evidence: List[EvidencePointer] = field(default_factory=list) + provenance: Dict[str, Any] = field(default_factory=dict) + privacy_scope: str = "personal" + confidence: float = 0.5 + score: float = 0.0 + tier: str = "warm" + created_at: str = field(default_factory=_now_iso) + start_time: Optional[str] = None + end_time: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self, include_evidence_snippets: bool = True) -> Dict[str, Any]: + return { + "schema_version": SCENE_SCHEMA_VERSION, + "id": self.id, + "title": self.title, + "summary": self.summary, + "topic": self.topic, + "user_goal": self.user_goal, + "action": self.action, + "outcome": self.outcome, + "lesson": self.lesson, + "entities": list(self.entities), + "tags": list(self.tags), + "modalities": list(self.modalities), + "repo_refs": list(self.repo_refs), + "evidence": [ + pointer.to_dict(include_snippet=include_evidence_snippets) + for pointer in self.evidence + ], + "provenance": dict(self.provenance or {}), + "privacy_scope": self.privacy_scope, + "confidence": round(float(self.confidence), 4), + "score": round(float(self.score), 4), + "tier": self.tier, + "created_at": self.created_at, + "start_time": self.start_time, + "end_time": self.end_time, + "metadata": dict(self.metadata or {}), + } + + @classmethod + def from_dict(cls, raw: Dict[str, Any]) -> "TemporalScene": + evidence = [ + pointer if isinstance(pointer, EvidencePointer) else EvidencePointer.from_dict(pointer) + for pointer in (raw.get("evidence") or []) + if isinstance(pointer, (dict, EvidencePointer)) + ] + return cls( + id=str(raw.get("id") or _stable_hash(raw)), + title=str(raw.get("title") or ""), + summary=str(raw.get("summary") or ""), + topic=str(raw.get("topic") or ""), + user_goal=str(raw.get("user_goal") or ""), + action=str(raw.get("action") or ""), + outcome=str(raw.get("outcome") or ""), + lesson=str(raw.get("lesson") or ""), + entities=[str(v) for v in raw.get("entities") or []], + tags=[str(v) for v in raw.get("tags") or []], + modalities=[str(v) for v in raw.get("modalities") or []], + repo_refs=[str(v) for v in raw.get("repo_refs") or []], + evidence=evidence, + provenance=dict(raw.get("provenance") or {}), + privacy_scope=str(raw.get("privacy_scope") or "personal"), + confidence=float(raw.get("confidence") or 0.5), + score=float(raw.get("score") or 0.0), + tier=str(raw.get("tier") or "warm"), + created_at=str(raw.get("created_at") or _now_iso()), + start_time=raw.get("start_time") or None, + end_time=raw.get("end_time") or None, + metadata=dict(raw.get("metadata") or {}), + ) + + def to_card(self, max_chars: int = 900) -> Dict[str, Any]: + """Return a prompt-safe card with no raw evidence bodies.""" + + evidence_refs = [ + { + "kind": pointer.kind, + "ref": pointer.ref, + "label": pointer.label, + "modality": pointer.modality, + "source_app": pointer.source_app, + "agent_id": pointer.agent_id, + "confidentiality_scope": pointer.confidentiality_scope, + } + for pointer in self.evidence[:8] + ] + card = { + "id": self.id, + "title": self.title, + "summary": _clip(self.summary, max_chars), + "topic": self.topic, + "lesson": _clip(self.lesson, 360), + "tags": list(self.tags[:12]), + "entities": list(self.entities[:12]), + "repo_refs": list(self.repo_refs[:8]), + "tier": self.tier, + "score": round(float(self.score), 4), + "confidence": round(float(self.confidence), 4), + "evidence_refs": evidence_refs, + } + return card + + +class GemScorer: + """Deterministic scorer that decides whether noisy evidence is a gem.""" + + def score(self, text: str, pointers: Sequence[EvidencePointer]) -> float: + terms = set(_tokens(text)) + score = 0.18 + score += min(0.22, len(terms & _GEM_TERMS) * 0.045) + score += min(0.18, len(pointers) * 0.035) + score += min(0.12, len({p.agent_id for p in pointers if p.agent_id}) * 0.04) + score += min(0.12, len({p.source_app for p in pointers if p.source_app}) * 0.04) + if any(p.modality not in ("text", "") for p in pointers): + score += 0.06 + if any(p.confidentiality_scope in ("public", "repo", "shareable") for p in pointers): + score += 0.04 + if len(text) > 500: + score += 0.05 + if any(p.confidentiality_scope in ("secret", "restricted") for p in pointers): + score -= 0.08 + return max(0.0, min(1.0, score)) + + def tier(self, score: float) -> str: + if score >= 0.72: + return "hot" + if score >= 0.42: + return "warm" + return "cold" + + +class SceneCompiler: + """Compile compact scenes from memory rows, artifacts, browser captures, or agent outputs.""" + + def __init__(self, scorer: Optional[GemScorer] = None) -> None: + self.scorer = scorer or GemScorer() + + def _pointer_from_evidence(self, item: Any, default_user_id: str) -> EvidencePointer: + raw = dataclasses.asdict(item) if dataclasses.is_dataclass(item) else item + if not isinstance(raw, dict): + raw = {"content": str(raw)} + metadata = raw.get("metadata") or raw.get("meta") or {} + if not isinstance(metadata, dict): + metadata = {"raw_metadata": metadata} + text = _first_text(raw) + ref = ( + raw.get("ref") or raw.get("id") or raw.get("memory_id") or + raw.get("source_event_id") or metadata.get("id") or metadata.get("source_event_id") or + _stable_hash({"text": text, "metadata": metadata}) + ) + kind = str(raw.get("kind") or raw.get("memory_type") or raw.get("source_type") or "evidence") + label = str(raw.get("title") or metadata.get("title") or _clip(text, 80)) + uri = str(raw.get("uri") or raw.get("path") or raw.get("url") or metadata.get("uri") or metadata.get("path") or metadata.get("url") or "") + return EvidencePointer( + kind=kind, + ref=str(ref), + label=label, + modality=_infer_modality(raw, text), + user_id=str(raw.get("user_id") or metadata.get("user_id") or default_user_id), + agent_id=str(raw.get("agent_id") or metadata.get("agent_id") or ""), + source_app=str(raw.get("source_app") or metadata.get("source_app") or ""), + source_event_id=str(raw.get("source_event_id") or metadata.get("source_event_id") or ""), + run_id=str(raw.get("run_id") or metadata.get("run_id") or ""), + memory_type=str(raw.get("memory_type") or metadata.get("memory_type") or kind), + confidentiality_scope=str(raw.get("confidentiality_scope") or metadata.get("confidentiality_scope") or "personal"), + uri=uri, + snippet=_clip(text, _MAX_SNIPPET_CHARS), + metadata={ + key: value + for key, value in metadata.items() + if key not in {"text", "body", "content", "memory", "transcript", "ocr"} + }, + ) + + def compile_scene( + self, + evidence_items: Iterable[Any], + *, + user_id: str = "default", + repo: Optional[str] = None, + task: str = "", + privacy_scope: str = "personal", + title: Optional[str] = None, + ) -> TemporalScene: + items = list(evidence_items) + pointers = [self._pointer_from_evidence(item, user_id) for item in items] + if not pointers: + raise ValueError("at least one evidence item is required to compile a scene") + + combined = " ".join(pointer.snippet for pointer in pointers if pointer.snippet) + categories: List[str] = [] + for item in items: + if isinstance(item, dict): + categories.extend(_extract_categories(item)) + token_counts = Counter(_tokens(" ".join([combined, task, " ".join(categories)]))) + tags = [] + for value in categories: + value = value.strip().lower() + if value and value not in tags: + tags.append(value) + for token, _count in token_counts.most_common(16): + if token not in tags: + tags.append(token) + if len(tags) >= 16: + break + entities = [tag for tag in tags if tag[:1].isupper()] + if not entities: + entities = [tag for tag in tags[:8] if tag not in _NOISE_WORDS] + scene_title = title or _clip(task, 90) or _clip(pointers[0].label or combined, 90) or "Temporal scene" + topic = _clip(" ".join(tags[:5]), 120) or scene_title + modalities = sorted({pointer.modality for pointer in pointers if pointer.modality}) + repo_refs = [] + if repo: + repo_refs.append(str(repo)) + for pointer in pointers: + path = pointer.metadata.get("path") or pointer.metadata.get("file_path") + if path and str(path) not in repo_refs: + repo_refs.append(str(path)) + source_apps = sorted({p.source_app for p in pointers if p.source_app}) + agent_ids = sorted({p.agent_id for p in pointers if p.agent_id}) + source_event_ids = sorted({p.source_event_id for p in pointers if p.source_event_id}) + run_ids = sorted({p.run_id for p in pointers if p.run_id}) + memory_types = sorted({p.memory_type for p in pointers if p.memory_type}) + score = self.scorer.score(" ".join([scene_title, task, combined]), pointers) + confidence = min(0.95, 0.45 + min(0.25, len(pointers) * 0.05) + min(0.15, len(source_apps) * 0.05)) + payload_for_id = { + "title": scene_title, + "task": task, + "refs": [p.ref for p in pointers], + "repo": repo, + } + scene = TemporalScene( + id="scene_" + _stable_hash(payload_for_id, 18), + title=scene_title, + summary=_clip(combined, 850), + topic=topic, + user_goal=_clip(task, 280), + action=_clip(task, 360) if task else "", + outcome="Compiled reusable context from admitted evidence.", + lesson=_clip( + "Relevant scene for future agents: " + (task or scene_title) + + ". Use the card first; expand evidence only by pointer when needed.", + 420, + ), + entities=entities[:12], + tags=tags[:16], + modalities=modalities or ["text"], + repo_refs=repo_refs[:12], + evidence=pointers, + provenance={ + "user_id": user_id, + "agent_ids": agent_ids, + "source_apps": source_apps, + "source_event_ids": source_event_ids, + "run_ids": run_ids, + "memory_types": memory_types, + "evidence_count": len(pointers), + }, + privacy_scope=privacy_scope, + confidence=confidence, + score=score, + tier=self.scorer.tier(score), + metadata={ + "source_evidence_hash": _stable_hash([p.to_dict() for p in pointers], 24), + "task": task, + "repo": repo or "", + "storage_policy": "scene_card_plus_pointer_derivatives", + }, + ) + return scene + + +def _normalize_sources(sources: Optional[Iterable[str]]) -> set[str]: + if sources is None: + return {"evidence"} + return {str(source).strip().lower() for source in sources if str(source).strip()} + + +def _memory_rows(memory: Any, *, query: str, user_id: str, limit: int) -> List[Dict[str, Any]]: + if memory is None: + return [] + try: + if query: + result = memory.search(query=query, user_id=user_id, limit=limit) + else: + result = memory.get_all(user_id=user_id, limit=limit) + except TypeError: + try: + result = memory.search(query, user_id=user_id, limit=limit) + except Exception: + return [] + except Exception: + return [] + if isinstance(result, dict): + rows = result.get("results") or result.get("memories") or [] + return [row for row in rows if isinstance(row, dict)] + if isinstance(result, list): + return [row for row in result if isinstance(row, dict)] + return [] + + +def _repo_context_rows(repo: Optional[str | os.PathLike[str]], limit: int) -> List[Dict[str, Any]]: + if not repo: + return [] + try: + from dhee import repo_link + + repo_root = repo_link._resolve_repo(repo) or Path(repo).expanduser().resolve() + entries = repo_link.list_entries(repo_root) + except Exception: + return [] + rows: List[Dict[str, Any]] = [] + for entry in entries[-max(1, int(limit)):]: + rows.append({ + "id": entry.id, + "kind": f"repo_context:{entry.kind}", + "title": entry.title, + "content": entry.content, + "source_app": "dhee-repo-context", + "source_event_id": entry.id, + "agent_id": entry.created_by, + "memory_type": entry.kind, + "confidentiality_scope": "repo", + "metadata": { + "entry_id": entry.id, + "repo": str(repo_root), + "content_hash": entry.content_hash, + "created_at": entry.created_at, + "updated_at": entry.updated_at, + "kind": entry.kind, + "meta": entry.meta, + }, + }) + return rows + + +def _session_rows(session: Any) -> List[Dict[str, Any]]: + if not isinstance(session, dict): + return [] + content_parts: List[str] = [] + for key in ("task_summary", "summary", "title", "status"): + value = session.get(key) + if value: + content_parts.append(f"{key}: {value}") + for key in ("decisions", "decisions_made", "files_touched", "todos", "todos_remaining", "blockers", "test_results"): + value = session.get(key) + if isinstance(value, list) and value: + content_parts.append(f"{key}: " + "; ".join(str(item) for item in value[:12])) + content = "\n".join(content_parts).strip() + if not content: + return [] + return [{ + "id": session.get("id") or session.get("session_id") or _stable_hash(session, 12), + "kind": "session_digest", + "title": session.get("task_summary") or session.get("title") or "Session digest", + "content": content, + "source_app": "dhee-session", + "source_event_id": session.get("id") or session.get("session_id") or "", + "agent_id": session.get("agent_id") or "", + "run_id": session.get("id") or session.get("session_id") or "", + "memory_type": "session_digest", + "confidentiality_scope": "personal", + "metadata": {k: v for k, v in session.items() if k not in {"messages"}}, + }] + + +def _shared_task_rows(results: Any) -> List[Dict[str, Any]]: + if isinstance(results, dict): + rows = results.get("results") or [] + else: + rows = results or [] + out: List[Dict[str, Any]] = [] + for row in rows: + if not isinstance(row, dict): + continue + digest = row.get("digest") or row.get("summary") or row.get("content") or "" + if not digest: + continue + out.append({ + "id": row.get("id") or _stable_hash(row, 12), + "kind": row.get("packet_kind") or "shared_task_result", + "title": row.get("tool_name") or row.get("packet_kind") or "Shared task result", + "content": digest, + "source_app": row.get("harness") or "dhee-shared-task", + "source_event_id": row.get("id") or "", + "agent_id": row.get("agent_id") or "", + "run_id": row.get("shared_task_id") or "", + "memory_type": "shared_task_result", + "confidentiality_scope": "personal", + "metadata": dict(row.get("metadata") or {}), + }) + return out + + +def _artifact_rows(artifacts: Any) -> List[Dict[str, Any]]: + rows = artifacts.get("results") if isinstance(artifacts, dict) else artifacts + out: List[Dict[str, Any]] = [] + for row in rows or []: + if not isinstance(row, dict): + continue + content = row.get("summary") or row.get("text") or row.get("filename") or row.get("source_path") or "" + if not content: + continue + out.append({ + "id": row.get("artifact_id") or row.get("id") or _stable_hash(row, 12), + "kind": "artifact", + "title": row.get("filename") or row.get("title") or "Artifact", + "content": content, + "source_app": "dhee-artifact", + "source_event_id": row.get("artifact_id") or row.get("id") or "", + "memory_type": "artifact", + "confidentiality_scope": str(row.get("confidentiality_scope") or "personal"), + "metadata": dict(row), + }) + return out + + +def collect_scene_evidence( + *, + evidence: Optional[Iterable[Any]] = None, + memory: Any = None, + query: str = "", + user_id: str = "default", + repo: Optional[str | os.PathLike[str]] = None, + session: Optional[Dict[str, Any]] = None, + shared_task_results: Any = None, + artifacts: Any = None, + sources: Optional[Iterable[str]] = None, + limit: int = 20, +) -> List[Any]: + """Collect compact evidence derivatives from Dhee's existing surfaces. + + This is intentionally pointer/card oriented: it pulls summaries, digests, + repo entries, and metadata identities, not raw media or unbounded logs. + """ + + selected = _normalize_sources(sources) + rows: List[Any] = [] + if evidence and ("evidence" in selected or not selected): + rows.extend(list(evidence)) + if "memory" in selected: + rows.extend(_memory_rows(memory, query=query, user_id=user_id, limit=limit)) + if "repo_context" in selected or "repo" in selected: + rows.extend(_repo_context_rows(repo, limit=limit)) + if "session" in selected or "session_digest" in selected: + rows.extend(_session_rows(session)) + if "shared_task_results" in selected or "shared_task" in selected: + rows.extend(_shared_task_rows(shared_task_results)) + if "artifacts" in selected or "artifact" in selected: + rows.extend(_artifact_rows(artifacts)) + + deduped: List[Any] = [] + seen: set[str] = set() + for row in rows: + if isinstance(row, dict): + key = str(row.get("id") or row.get("ref") or _stable_hash(row, 12)) + else: + key = _stable_hash(row, 12) + if key in seen: + continue + seen.add(key) + deduped.append(row) + if len(deduped) >= max(1, int(limit)): + break + return deduped + + +def collect_live_scene_sources( + *, + db: Any = None, + repo: Optional[str | os.PathLike[str]] = None, + user_id: str = "default", + agent_id: str = "codex", + limit: int = 10, + include_session: bool = True, + include_shared_task_results: bool = True, + include_artifacts: bool = False, +) -> Dict[str, Any]: + """Fetch compact live Dhee surfaces for scene compilation. + + The returned payload is shaped for :func:`collect_scene_evidence`. + All reads are best-effort and bounded; failures return missing/empty + fields instead of raising into MCP handlers. + """ + + out: Dict[str, Any] = {} + repo_str = str(repo) if repo else None + if include_session: + session = None + try: + from dhee.core.kernel import get_last_session + + candidate_agents = [] + for candidate in (agent_id, "codex", "claude-code", "mcp-server"): + if candidate and candidate not in candidate_agents: + candidate_agents.append(candidate) + for candidate in candidate_agents: + session = get_last_session( + agent_id=candidate, + repo=repo_str, + user_id=user_id, + requester_agent_id=agent_id or "codex", + fallback_log_recovery=True, + ) + if session: + break + except Exception: + session = None + if session: + out["session"] = session + + if db is not None and include_shared_task_results: + try: + from dhee.core.shared_tasks import shared_task_snapshot + + out["shared_task_results"] = shared_task_snapshot( + db, + user_id=user_id, + repo=repo_str, + workspace_id=repo_str, + limit=max(1, int(limit)), + ) + except Exception: + out["shared_task_results"] = {"task": None, "results": []} + + if db is not None and include_artifacts and hasattr(db, "list_artifacts"): + try: + out["artifacts"] = db.list_artifacts( + user_id=user_id, + workspace_id=repo_str, + limit=max(1, int(limit)), + ) + except Exception: + out["artifacts"] = [] + return out + + +class SceneStore: + """Append-only JSONL store for private scene cards.""" + + def __init__(self, root: Optional[str | os.PathLike[str]] = None) -> None: + base = ( + root or os.environ.get("DHEE_TEMPORAL_SCENE_DIR") or + (Path(os.environ["DHEE_DATA_DIR"]) / "temporal_scenes" if os.environ.get("DHEE_DATA_DIR") else None) or + (Path.home() / ".dhee" / "temporal_scenes") + ) + self.root = Path(base).expanduser().resolve() + + def _path(self, user_id: str) -> Path: + return self.root / f"{_safe_user_key(user_id)}.jsonl" + + def save(self, scene: TemporalScene) -> TemporalScene: + self.root.mkdir(parents=True, exist_ok=True) + path = self._path(str(scene.provenance.get("user_id") or "default")) + with path.open("a", encoding="utf-8") as fh: + fh.write(json.dumps(scene.to_dict(), sort_keys=True, default=str) + "\n") + return scene + + def list( + self, + *, + user_id: str = "default", + limit: int = 50, + include_cold: bool = True, + ) -> List[TemporalScene]: + path = self._path(user_id) + if not path.exists(): + return [] + by_id: Dict[str, TemporalScene] = {} + with path.open("r", encoding="utf-8", errors="replace") as fh: + for line in fh: + line = line.strip() + if not line: + continue + try: + scene = TemporalScene.from_dict(json.loads(line)) + except Exception: + continue + if include_cold or scene.tier != "cold": + by_id[scene.id] = scene + scenes = sorted(by_id.values(), key=lambda scene: (scene.score, scene.created_at), reverse=True) + return scenes[: max(0, int(limit))] + + def search( + self, + query: str, + *, + user_id: str = "default", + repo: Optional[str] = None, + limit: int = 5, + include_personal: bool = True, + ) -> List[TemporalScene]: + query_terms = set(_tokens(query)) + scenes = self.list(user_id=user_id, limit=500, include_cold=True) + ranked: List[tuple[float, TemporalScene]] = [] + repo_norm = str(repo or "") + for scene in scenes: + if not include_personal and scene.privacy_scope == "personal": + continue + if repo_norm and not any(repo_norm in ref or ref in repo_norm for ref in scene.repo_refs): + continue + haystack = " ".join([ + scene.title, scene.summary, scene.topic, scene.lesson, + " ".join(scene.tags), " ".join(scene.repo_refs), + ]) + terms = set(_tokens(haystack)) + overlap = len(query_terms & terms) + if query_terms and overlap == 0: + continue + rank = float(scene.score) + overlap * 0.12 + (0.05 if scene.tier == "hot" else 0) + ranked.append((rank, scene)) + ranked.sort(key=lambda item: item[0], reverse=True) + return [scene for _rank, scene in ranked[: max(0, int(limit))]] + + +class ContextPackCompiler: + """Build hard-budget context packs from scene cards.""" + + def __init__(self, store: Optional[SceneStore] = None) -> None: + self.store = store or SceneStore() + + def build( + self, + query: str, + *, + user_id: str = "default", + repo: Optional[str] = None, + token_budget: int = 1200, + limit: int = 5, + include_personal: bool = True, + ) -> Dict[str, Any]: + cards: List[Dict[str, Any]] = [] + used_tokens = 0 + for scene in self.store.search( + query, + user_id=user_id, + repo=repo, + limit=limit * 3, + include_personal=include_personal, + ): + remaining = max(1, int(token_budget) - used_tokens) + card = scene.to_card(max_chars=max(80, min(900, remaining * 4))) + card_tokens = _estimate_tokens(json.dumps(card, sort_keys=True, default=str)) + if card_tokens > remaining: + card["summary"] = _clip(card.get("summary") or "", max(40, remaining * 2)) + card["lesson"] = _clip(card.get("lesson") or "", max(40, remaining)) + card["evidence_refs"] = list(card.get("evidence_refs") or [])[:3] + card_tokens = _estimate_tokens(json.dumps(card, sort_keys=True, default=str)) + if card_tokens > remaining: + card = { + "id": scene.id, + "title": _clip(scene.title, 120), + "summary": _clip(scene.summary, max(40, remaining * 2)), + "tags": scene.tags[:6], + "tier": scene.tier, + "evidence_refs": [ + { + "kind": pointer.kind, + "ref": pointer.ref, + "modality": pointer.modality, + } + for pointer in scene.evidence[:2] + ], + } + card_tokens = _estimate_tokens(json.dumps(card, sort_keys=True, default=str)) + if cards and used_tokens + card_tokens > token_budget: + continue + if card_tokens > token_budget: + continue + cards.append(card) + used_tokens += card_tokens + if len(cards) >= limit: + break + return { + "format": "dhee_context_pack.v1", + "query": query, + "user_id": user_id, + "repo": repo, + "token_budget": int(token_budget), + "estimated_tokens": used_tokens, + "scene_cards": cards, + "evidence_policy": "summaries_only_raw_evidence_by_pointer", + "raw_media_included": False, + "full_diffs_included": False, + } + + +class PromotionGate: + """Privacy boundary between personal scenes and shareable repo capsules.""" + + _LOCAL_PATH_RE = re.compile(r"(/Users/[^\s\"']+|/home/[^\s\"']+|[A-Za-z]:\\\\[^\s\"']+)") + + def sanitize_scene(self, scene: TemporalScene, *, share_scope: str = "repo") -> Dict[str, Any]: + data = scene.to_card() + safe_refs: List[Dict[str, Any]] = [] + for pointer in scene.evidence: + if pointer.confidentiality_scope in {"secret", "restricted"}: + continue + safe = pointer.to_dict(include_snippet=False, include_private_uri=False) + safe["confidentiality_scope"] = "redacted" if pointer.confidentiality_scope == "personal" else pointer.confidentiality_scope + safe["label"] = self._redact_text(safe.get("label") or "") + safe_refs.append(safe) + data["evidence_refs"] = safe_refs[:8] + data["privacy_scope"] = share_scope + data["personal_context_used"] = scene.privacy_scope == "personal" + data["summary"] = self._redact_text(data.get("summary") or "") + data["lesson"] = self._redact_text(data.get("lesson") or "") + return data + + def _redact_text(self, text: str) -> str: + return self._LOCAL_PATH_RE.sub("", text or "") + + +def compile_scene( + evidence: Iterable[Any], + *, + user_id: str = "default", + repo: Optional[str] = None, + task: str = "", + privacy_scope: str = "personal", + title: Optional[str] = None, + store_dir: Optional[str | os.PathLike[str]] = None, + save: bool = True, +) -> TemporalScene: + scene = SceneCompiler().compile_scene( + evidence, + user_id=user_id, + repo=repo, + task=task, + privacy_scope=privacy_scope, + title=title, + ) + if save: + SceneStore(store_dir).save(scene) + return scene + + +def compile_scene_from_sources( + *, + evidence: Optional[Iterable[Any]] = None, + memory: Any = None, + query: str = "", + user_id: str = "default", + repo: Optional[str | os.PathLike[str]] = None, + session: Optional[Dict[str, Any]] = None, + shared_task_results: Any = None, + artifacts: Any = None, + sources: Optional[Iterable[str]] = None, + limit: int = 20, + task: str = "", + privacy_scope: str = "personal", + title: Optional[str] = None, + store_dir: Optional[str | os.PathLike[str]] = None, + save: bool = True, +) -> TemporalScene: + collected = collect_scene_evidence( + evidence=evidence, + memory=memory, + query=query, + user_id=user_id, + repo=repo, + session=session, + shared_task_results=shared_task_results, + artifacts=artifacts, + sources=sources, + limit=limit, + ) + return compile_scene( + collected, + user_id=user_id, + repo=str(repo) if repo else None, + task=task or query, + privacy_scope=privacy_scope, + title=title, + store_dir=store_dir, + save=save, + ) + + +def search_scenes( + query: str, + *, + user_id: str = "default", + repo: Optional[str] = None, + limit: int = 5, + store_dir: Optional[str | os.PathLike[str]] = None, + include_personal: bool = True, +) -> List[TemporalScene]: + return SceneStore(store_dir).search( + query, + user_id=user_id, + repo=repo, + limit=limit, + include_personal=include_personal, + ) + + +def build_context_pack( + query: str, + *, + user_id: str = "default", + repo: Optional[str] = None, + token_budget: int = 1200, + limit: int = 5, + store_dir: Optional[str | os.PathLike[str]] = None, + include_personal: bool = True, +) -> Dict[str, Any]: + return ContextPackCompiler(SceneStore(store_dir)).build( + query, + user_id=user_id, + repo=repo, + token_budget=token_budget, + limit=limit, + include_personal=include_personal, + ) diff --git a/dhee/update_capsules.py b/dhee/update_capsules.py new file mode 100644 index 0000000..0905dfb --- /dev/null +++ b/dhee/update_capsules.py @@ -0,0 +1,714 @@ +"""Repo-shareable update capsules. + +An update capsule is a sanitized recipe, not an auto-applied patch. It gives +another agent the before/after story, changed interfaces, compact hunks, +hashes, commands, and evidence pointers needed to recreate behavior with +normal editing and verification tools. +""" + +from __future__ import annotations + +import hashlib +import json +import os +import re +import subprocess +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Tuple + +from dhee.context_ir import build_context_ir, interpret_context_ir, render_context_ir, validate_context_ir +from dhee import repo_link + + +CAPSULE_SCHEMA_VERSION = 1 +CAPSULE_KIND = "update_capsule" +MAX_DIFF_CHARS_PER_FILE = 18_000 +MAX_MD_DIFF_CHARS = 4_000 + +_SECRET_PATTERNS = [ + re.compile(r"sk-[A-Za-z0-9_-]{20,}"), + re.compile(r"gh[pousr]_[A-Za-z0-9_]{20,}"), + re.compile(r"xox[baprs]-[A-Za-z0-9-]{20,}"), + re.compile(r"(?i)\b(api[_-]?key|token|secret|password|passwd)\b\s*[:=]\s*['\"]?[^'\"\s]{8,}"), + re.compile(r"eyJ[A-Za-z0-9_-]{20,}\.[A-Za-z0-9_-]{10,}\.[A-Za-z0-9_-]{10,}"), +] +_LOCAL_PATH_RE = re.compile(r"(/Users/[^\s\"']+|/home/[^\s\"']+|[A-Za-z]:\\\\[^\s\"']+)") + + +def _now_iso() -> str: + return datetime.now(timezone.utc).replace(microsecond=0).isoformat() + + +def _json_dumps(data: Any) -> str: + return json.dumps(data, indent=2, sort_keys=True, default=str) + + +def _stable_hash(data: Any, length: int = 18) -> str: + raw = json.dumps(data, sort_keys=True, default=str, separators=(",", ":")) + return hashlib.sha256(raw.encode("utf-8")).hexdigest()[:length] + + +def _sanitize_text(text: str) -> str: + value = str(text or "") + home = str(Path.home()) + if home: + value = value.replace(home, "$HOME") + value = _LOCAL_PATH_RE.sub("", value) + for pattern in _SECRET_PATTERNS: + value = pattern.sub("", value) + return value + + +def _sanitize_obj(value: Any) -> Any: + if isinstance(value, str): + return _sanitize_text(value) + if isinstance(value, list): + return [_sanitize_obj(item) for item in value] + if isinstance(value, tuple): + return [_sanitize_obj(item) for item in value] + if isinstance(value, dict): + return {str(key): _sanitize_obj(item) for key, item in value.items()} + return value + + +def _run_git(repo_root: Path, args: List[str], *, check: bool = False) -> subprocess.CompletedProcess[str]: + proc = subprocess.run( + ["git", "-C", str(repo_root), *args], + text=True, + capture_output=True, + check=False, + ) + if check and proc.returncode != 0: + raise ValueError((proc.stderr or proc.stdout or "git command failed").strip()) + return proc + + +def _git_out(repo_root: Path, args: List[str], default: str = "") -> str: + proc = _run_git(repo_root, args) + if proc.returncode != 0: + return default + return proc.stdout.strip() + + +def _resolve_repo_root(repo: str | os.PathLike[str] | None) -> Path: + base = Path(repo or os.getcwd()).expanduser().resolve() + proc = subprocess.run( + ["git", "-C", str(base), "rev-parse", "--show-toplevel"], + text=True, + capture_output=True, + check=False, + ) + if proc.returncode == 0 and proc.stdout.strip(): + return Path(proc.stdout.strip()).resolve() + return base + + +def _is_valid_ref(repo_root: Path, ref: str) -> bool: + if not ref: + return False + return _run_git(repo_root, ["rev-parse", "--verify", f"{ref}^{{commit}}"]).returncode == 0 + + +def _is_dhee_generated_context_path(rel_path: str) -> bool: + path = str(rel_path or "").replace("\\", "/") + while path.startswith("./"): + path = path[2:] + return path == ".dhee" or path.startswith(".dhee/") + + +def _changed_paths(repo_root: Path, since: Optional[str]) -> List[Dict[str, Any]]: + by_path: Dict[str, Dict[str, Any]] = {} + if since and _is_valid_ref(repo_root, since): + for path in _git_out(repo_root, ["diff", "--name-only", since, "--"]).splitlines(): + if path.strip(): + by_path[path.strip()] = {"path": path.strip(), "status": "modified"} + else: + for path in _git_out(repo_root, ["diff", "--name-only", "--"]).splitlines(): + if path.strip(): + by_path[path.strip()] = {"path": path.strip(), "status": "modified"} + + status = _git_out(repo_root, ["status", "--porcelain=v1", "--untracked-files=all"]) + for line in status.splitlines(): + if not line: + continue + code = line[:2] + raw_path = (line[3:] if len(line) > 2 and line[2] == " " else line[2:]).strip() + if " -> " in raw_path: + _old, raw_path = raw_path.split(" -> ", 1) + status_name = _status_name(code) + by_path.setdefault(raw_path, {"path": raw_path, "status": status_name}) + by_path[raw_path]["status"] = status_name + return [ + by_path[key] + for key in sorted(by_path) + if not _is_dhee_generated_context_path(key) + ] + + +def _status_name(code: str) -> str: + code = code or "" + if "?" in code: + return "untracked" + if "D" in code: + return "deleted" + if "A" in code: + return "added" + if "R" in code: + return "renamed" + if "M" in code: + return "modified" + return "changed" + + +def _file_hash(repo_root: Path, rel_path: str) -> Optional[str]: + path = (repo_root / rel_path).resolve() + try: + if not path.exists() or not path.is_file(): + return None + if not str(path).startswith(str(repo_root.resolve())): + return None + h = hashlib.sha256() + with path.open("rb") as fh: + for chunk in iter(lambda: fh.read(1024 * 1024), b""): + h.update(chunk) + return h.hexdigest() + except OSError: + return None + + +def _file_size(repo_root: Path, rel_path: str) -> Optional[int]: + try: + path = repo_root / rel_path + if path.exists() and path.is_file(): + return path.stat().st_size + except OSError: + return None + return None + + +def _git_blob_hash(repo_root: Path, ref: str, rel_path: str) -> Optional[str]: + if not ref or not rel_path or not _is_valid_ref(repo_root, ref): + return None + proc = subprocess.run( + ["git", "-C", str(repo_root), "show", f"{ref}:{rel_path}"], + capture_output=True, + check=False, + ) + if proc.returncode != 0: + return None + return hashlib.sha256(proc.stdout).hexdigest() + + +def _diff_for_path(repo_root: Path, rel_path: str, since: Optional[str]) -> Tuple[str, bool]: + args = ["diff", "--no-ext-diff", "--unified=3"] + if since and _is_valid_ref(repo_root, since): + args.append(since) + args.extend(["--", rel_path]) + diff = _git_out(repo_root, args) + truncated = len(diff) > MAX_DIFF_CHARS_PER_FILE + if truncated: + diff = diff[:MAX_DIFF_CHARS_PER_FILE].rstrip() + "\n[diff truncated]" + return _sanitize_text(diff), truncated + + +def _status_summary(repo_root: Path) -> Dict[str, Any]: + status = _git_out(repo_root, ["status", "--porcelain=v1", "--untracked-files=all"]) + lines = [line for line in status.splitlines() if line.strip()] + return { + "dirty": bool(lines), + "porcelain": [_sanitize_text(line) for line in lines[:200]], + "untracked_count": sum(1 for line in lines if line.startswith("??")), + "changed_count": len(lines), + } + + +def _compact_evidence_pointers(evidence: Optional[Iterable[Any]]) -> Tuple[List[Dict[str, Any]], bool]: + pointers: List[Dict[str, Any]] = [] + personal_used = False + for item in evidence or []: + if not isinstance(item, dict): + item = {"ref": str(item), "kind": "evidence"} + scope = str(item.get("confidentiality_scope") or item.get("privacy_scope") or "personal") + if scope == "personal": + personal_used = True + if scope in {"secret", "restricted"}: + personal_used = True + continue + safe = { + "kind": str(item.get("kind") or item.get("memory_type") or "evidence"), + "ref": str(item.get("ref") or item.get("id") or item.get("memory_id") or _stable_hash(item, 12)), + "label": _sanitize_text(str(item.get("label") or item.get("title") or ""))[:160], + "source_app": str(item.get("source_app") or ""), + "agent_id": str(item.get("agent_id") or ""), + "source_event_id": str(item.get("source_event_id") or ""), + "run_id": str(item.get("run_id") or ""), + "modality": str(item.get("modality") or "text"), + "confidentiality_scope": "redacted" if scope == "personal" else scope, + } + pointers.append(safe) + return pointers, personal_used + + +@dataclass +class UpdateCapsule: + id: str + title: str + summary: str + repo_root: str + repo_id: str + base_ref: str + base_commit: str + head_commit: str + created_at: str + changed_paths: List[Dict[str, Any]] = field(default_factory=list) + base_file_hashes: Dict[str, str] = field(default_factory=dict) + file_hashes: Dict[str, str] = field(default_factory=dict) + compact_hunks: List[Dict[str, Any]] = field(default_factory=list) + commands: List[str] = field(default_factory=list) + evidence_pointers: List[Dict[str, Any]] = field(default_factory=list) + compatibility_notes: List[str] = field(default_factory=list) + personal_context_used: bool = False + privacy: Dict[str, Any] = field(default_factory=dict) + context_ir: Dict[str, Any] = field(default_factory=dict) + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return { + "schema_version": CAPSULE_SCHEMA_VERSION, + "kind": CAPSULE_KIND, + "id": self.id, + "title": self.title, + "summary": self.summary, + "repo": { + "root_name": Path(self.repo_root).name, + "repo_id": self.repo_id, + "base_ref": self.base_ref, + "base_commit": self.base_commit, + "head_commit": self.head_commit, + }, + "base_ref": self.base_ref, + "base_commit": self.base_commit, + "head_commit": self.head_commit, + "created_at": self.created_at, + "changed_paths": self.changed_paths, + "base_file_hashes": self.base_file_hashes, + "file_hashes": self.file_hashes, + "compact_hunks": self.compact_hunks, + "commands": self.commands, + "evidence_pointers": self.evidence_pointers, + "compatibility_notes": self.compatibility_notes, + "personal_context_used": self.personal_context_used, + "privacy": self.privacy, + "context_ir": self.context_ir, + "metadata": self.metadata, + } + + @classmethod + def from_dict(cls, raw: Dict[str, Any]) -> "UpdateCapsule": + repo = raw.get("repo") or {} + return cls( + id=str(raw.get("id") or _stable_hash(raw)), + title=str(raw.get("title") or ""), + summary=str(raw.get("summary") or ""), + repo_root=str(raw.get("repo_root") or repo.get("root_name") or ""), + repo_id=str(raw.get("repo_id") or repo.get("repo_id") or ""), + base_ref=str(raw.get("base_ref") or repo.get("base_ref") or ""), + base_commit=str(raw.get("base_commit") or repo.get("base_commit") or ""), + head_commit=str(raw.get("head_commit") or repo.get("head_commit") or ""), + created_at=str(raw.get("created_at") or _now_iso()), + changed_paths=list(raw.get("changed_paths") or []), + base_file_hashes=dict(raw.get("base_file_hashes") or {}), + file_hashes=dict(raw.get("file_hashes") or {}), + compact_hunks=list(raw.get("compact_hunks") or []), + commands=list(raw.get("commands") or []), + evidence_pointers=list(raw.get("evidence_pointers") or []), + compatibility_notes=list(raw.get("compatibility_notes") or []), + personal_context_used=bool(raw.get("personal_context_used") or False), + privacy=dict(raw.get("privacy") or {}), + context_ir=dict(raw.get("context_ir") or {}), + metadata=dict(raw.get("metadata") or {}), + ) + + +def render_capsule_markdown(capsule: UpdateCapsule) -> str: + changed = capsule.changed_paths + paths = "\n".join( + f"- `{item.get('path')}` ({item.get('status') or 'changed'})" + for item in changed + ) or "- No changed paths detected." + hunks: List[str] = [] + for hunk in capsule.compact_hunks: + diff = str(hunk.get("diff") or "") + if len(diff) > MAX_MD_DIFF_CHARS: + diff = diff[:MAX_MD_DIFF_CHARS].rstrip() + "\n[diff clipped in markdown; see capsule.json]" + if diff: + hunks.append(f"### {hunk.get('path')}\n\n```diff\n{diff}\n```") + else: + hunks.append(f"### {hunk.get('path')}\n\nNo compact diff available, usually because the file is untracked or binary.") + hunk_text = "\n\n".join(hunks) or "No compact hunks captured." + commands = "\n".join(f"- `{cmd}`" for cmd in capsule.commands) or "- No test command was recorded." + evidence = "\n".join( + f"- `{ptr.get('kind')}` `{ptr.get('ref')}`" + + (f" from {ptr.get('source_app')}" if ptr.get("source_app") else "") + + (f" via {ptr.get('agent_id')}" if ptr.get("agent_id") else "") + for ptr in capsule.evidence_pointers + ) or "- No shareable evidence pointers were attached." + compatibility = "\n".join(f"- {note}" for note in capsule.compatibility_notes) or "- No compatibility notes." + ir_summary = render_context_ir(capsule.context_ir) if capsule.context_ir else "- No Context IR compiled." + md = f"""# {capsule.title} + +## Intent +{capsule.summary} + +## Before +Base ref: `{capsule.base_ref or '(unspecified)'}` +Base commit: `{capsule.base_commit or '(unknown)'}` + +## After +Head commit at capture: `{capsule.head_commit or '(unknown)'}` +Changed paths captured: {len(changed)} + +## Touched Interfaces +{paths} + +## Compact Hunks +{hunk_text} + +## Reproduction Guide +1. Read `capsule.json` for exact paths, hashes, and compact hunks. +2. Recreate the behavior with normal editing tools; this capsule is context, not an auto-apply patch. +3. Run the recorded commands or the nearest repo test suite. +4. Compare final file hashes or behavior against the capsule notes when useful. + +## Context IR +{ir_summary} +- Interpreter policy: validate schema, resolve file symbols on the target repo, produce an execution plan, never auto-apply. + +## Tests And Commands +{commands} + +## Evidence Pointers +{evidence} + +## Privacy And Sharing +- Raw personal memories, screenshots, transcripts, media, local paths, and secrets are not included. +- `personal_context_used`: `{str(capsule.personal_context_used).lower()}` +- Share scope: `{capsule.privacy.get('share_scope') or 'repo'}` + +## Compatibility Notes +{compatibility} +""" + return _sanitize_text(md).strip() + "\n" + + +def _capsule_root(repo_root: Path) -> Path: + return repo_link.repo_context_dir(repo_root) / "capsules" + + +def _write_capsule(capsule: UpdateCapsule, capsule_dir: Path) -> Dict[str, str]: + capsule_dir.mkdir(parents=True, exist_ok=True) + data = _sanitize_obj(capsule.to_dict()) + md = _sanitize_text(render_capsule_markdown(capsule)) + json_path = capsule_dir / "capsule.json" + md_path = capsule_dir / "capsule.md" + json_path.write_text(_json_dumps(data) + "\n", encoding="utf-8") + md_path.write_text(md, encoding="utf-8") + return {"json": str(json_path), "markdown": str(md_path), "dir": str(capsule_dir)} + + +def create_update_capsule( + *, + repo: str | os.PathLike[str] | None = None, + since: Optional[str] = None, + task_id: Optional[str] = None, + out: Optional[str | os.PathLike[str]] = None, + title: Optional[str] = None, + summary: Optional[str] = None, + commands: Optional[List[str]] = None, + evidence: Optional[Iterable[Any]] = None, +) -> Dict[str, Any]: + repo_root = _resolve_repo_root(repo) + repo_id = repo_link._ensure_repo_skeleton(repo_root) + base_ref = since or "HEAD" + base_commit = _git_out(repo_root, ["rev-parse", "--verify", f"{base_ref}^{{commit}}"], default="") + if not base_commit and since: + raise ValueError(f"Base ref {since!r} is not a valid commit") + head_commit = _git_out(repo_root, ["rev-parse", "--verify", "HEAD^{commit}"], default="") + status = _status_summary(repo_root) + changed_paths = _changed_paths(repo_root, since) + base_file_hashes: Dict[str, str] = {} + file_hashes: Dict[str, str] = {} + compact_hunks: List[Dict[str, Any]] = [] + for item in changed_paths: + rel_path = str(item.get("path") or "") + base_digest = _git_blob_hash(repo_root, base_ref, rel_path) + if base_digest: + base_file_hashes[rel_path] = base_digest + item["base_sha256"] = base_digest + digest = _file_hash(repo_root, rel_path) + if digest: + file_hashes[rel_path] = digest + item["sha256"] = digest + size = _file_size(repo_root, rel_path) + if size is not None: + item["size"] = size + diff, truncated = _diff_for_path(repo_root, rel_path, since) + compact_hunks.append({ + "path": rel_path, + "status": item.get("status") or "changed", + "diff": diff, + "truncated": truncated, + }) + evidence_pointers, evidence_personal = _compact_evidence_pointers(evidence) + personal_context_used = bool(evidence_personal) + capsule_title = title or f"Update capsule {task_id or (base_ref + ' -> worktree')}" + capsule_summary = summary or ( + f"Captured {len(changed_paths)} changed path(s) from {base_ref} to the current worktree." + ) + capsule_commands = _sanitize_obj(commands or [ + f"git diff --stat {base_ref}", + "git status --short", + ]) + capsule_privacy = { + "share_scope": "repo", + "raw_personal_memory_included": False, + "raw_media_included": False, + "screenshots_included": False, + "transcripts_included": False, + "local_paths_redacted": True, + "secrets_redacted": True, + "redaction_applied": True, + "promotion_required_for_personal_lessons": True, + } + compatibility_notes = [ + "Capsule is a hybrid recipe; V1 does not auto-apply patches.", + "Whole-file snapshots are not stored by default.", + "Context IR is interpreted on the receiving machine before any edits are attempted.", + ] + if status["dirty"]: + compatibility_notes.append("Worktree was dirty at capture time; verify staged and unstaged edits explicitly.") + if any(item.get("status") == "untracked" for item in changed_paths): + compatibility_notes.append("Untracked files are listed with hashes but no git diff body.") + payload_for_id = { + "repo_id": repo_id, + "base_commit": base_commit, + "head_commit": head_commit, + "changed_paths": changed_paths, + "compact_hunks_hash": _stable_hash(compact_hunks, 24), + "task_id": task_id or "", + } + capsule_id = "ucap_" + _stable_hash(payload_for_id, 20) + context_ir = build_context_ir( + capsule_id=capsule_id, + title=_sanitize_text(capsule_title), + summary=_sanitize_text(capsule_summary), + repo_id=repo_id, + base_ref=base_ref, + base_commit=base_commit, + head_commit=head_commit, + changed_paths=_sanitize_obj(changed_paths), + compact_hunks=_sanitize_obj(compact_hunks), + commands=capsule_commands, + evidence_pointers=evidence_pointers, + base_file_hashes=_sanitize_obj(base_file_hashes), + file_hashes=_sanitize_obj(file_hashes), + privacy=capsule_privacy, + ) + capsule = UpdateCapsule( + id=capsule_id, + title=_sanitize_text(capsule_title), + summary=_sanitize_text(capsule_summary), + repo_root=str(repo_root), + repo_id=repo_id, + base_ref=base_ref, + base_commit=base_commit, + head_commit=head_commit, + created_at=_now_iso(), + changed_paths=_sanitize_obj(changed_paths), + base_file_hashes=_sanitize_obj(base_file_hashes), + file_hashes=_sanitize_obj(file_hashes), + compact_hunks=_sanitize_obj(compact_hunks), + commands=capsule_commands, + evidence_pointers=evidence_pointers, + compatibility_notes=compatibility_notes, + personal_context_used=personal_context_used, + privacy=capsule_privacy, + context_ir=context_ir, + metadata={ + "task_id": task_id or "", + "status": status, + "capsule_payload": "hybrid_recipe_not_patch_only", + "compiler": "dhee-context-compiler", + }, + ) + capsule_dir = Path(out).expanduser().resolve() if out else _capsule_root(repo_root) / capsule.id + paths = _write_capsule(capsule, capsule_dir) + md = Path(paths["markdown"]).read_text(encoding="utf-8") + rel_dir = os.path.relpath(capsule_dir, repo_root) if str(capsule_dir).startswith(str(repo_root)) else str(capsule_dir) + entry = repo_link.add_entry( + repo_root, + kind=CAPSULE_KIND, + title=capsule.title, + content=md, + meta={ + "capsule_id": capsule.id, + "capsule_dir": rel_dir, + "base_commit": capsule.base_commit, + "head_commit": capsule.head_commit, + "changed_paths": [item.get("path") for item in capsule.changed_paths], + "personal_context_used": capsule.personal_context_used, + "privacy": capsule.privacy, + }, + ) + return { + "format": "dhee_update_capsule_create.v1", + "capsule": capsule.to_dict(), + "paths": paths, + "entry": entry.to_json(), + } + + +def list_update_capsules(*, repo: str | os.PathLike[str] | None = None) -> List[Dict[str, Any]]: + repo_root = _resolve_repo_root(repo) + root = _capsule_root(repo_root) + if not root.exists(): + return [] + capsules: List[Dict[str, Any]] = [] + for json_path in sorted(root.glob("*/capsule.json")): + try: + data = json.loads(json_path.read_text(encoding="utf-8")) + except Exception: + continue + capsules.append({ + "id": data.get("id"), + "title": data.get("title"), + "created_at": data.get("created_at"), + "changed_paths": data.get("changed_paths") or [], + "path": str(json_path.parent), + "personal_context_used": bool(data.get("personal_context_used")), + }) + return capsules + + +def get_update_capsule( + capsule_id: str, + *, + repo: str | os.PathLike[str] | None = None, +) -> Dict[str, Any]: + repo_root = _resolve_repo_root(repo) + root = _capsule_root(repo_root) + matches = [path for path in root.glob("*/capsule.json") if path.parent.name == capsule_id or path.parent.name.startswith(capsule_id)] + if not matches: + raise FileNotFoundError(f"Update capsule {capsule_id!r} not found") + json_path = matches[0] + md_path = json_path.with_name("capsule.md") + data = json.loads(json_path.read_text(encoding="utf-8")) + return { + "format": "dhee_update_capsule_get.v1", + "capsule": data, + "markdown": md_path.read_text(encoding="utf-8") if md_path.exists() else "", + "paths": {"json": str(json_path), "markdown": str(md_path), "dir": str(json_path.parent)}, + } + + +def _load_capsule_data( + capsule: str | os.PathLike[str] | Dict[str, Any], + *, + repo: str | os.PathLike[str] | None = None, +) -> Dict[str, Any]: + if isinstance(capsule, dict): + return capsule + value = str(capsule) + source = Path(value).expanduser() + if source.exists(): + data, _md, _source_dir = _read_import_source(source) + return data + return get_update_capsule(value, repo=repo)["capsule"] + + +def interpret_update_capsule( + capsule: str | os.PathLike[str] | Dict[str, Any], + *, + repo: str | os.PathLike[str] | None = None, + strict: bool = False, +) -> Dict[str, Any]: + """Interpret a compiled capsule on the target repo without applying it.""" + + data = _load_capsule_data(capsule, repo=repo) + return interpret_context_ir(repo=repo, capsule_or_ir=data, strict=strict) + + +def _read_import_source(path: str | os.PathLike[str]) -> Tuple[Dict[str, Any], str, Path]: + source = Path(path).expanduser().resolve() + if source.is_dir(): + json_path = source / "capsule.json" + md_path = source / "capsule.md" + source_dir = source + elif source.suffix == ".json": + json_path = source + md_path = source.with_name("capsule.md") + source_dir = source.parent + elif source.suffix == ".md": + json_path = source.with_name("capsule.json") + md_path = source + source_dir = source.parent + else: + raise ValueError("Import path must be a capsule directory, capsule.json, or capsule.md") + if not json_path.exists(): + raise FileNotFoundError(f"Missing capsule.json near {source}") + data = json.loads(json_path.read_text(encoding="utf-8")) + md = md_path.read_text(encoding="utf-8") if md_path.exists() else render_capsule_markdown(UpdateCapsule.from_dict(data)) + return data, md, source_dir + + +def import_update_capsule( + path: str | os.PathLike[str], + *, + repo: str | os.PathLike[str] | None = None, + allow_private: bool = False, +) -> Dict[str, Any]: + repo_root = _resolve_repo_root(repo) + repo_link._ensure_repo_skeleton(repo_root) + data, md, _source_dir = _read_import_source(path) + privacy = data.get("privacy") or {} + if privacy.get("raw_personal_memory_included") and not allow_private: + raise ValueError("Capsule import rejected: raw personal memory is marked as included") + data = _sanitize_obj(data) + if isinstance(data.get("context_ir"), dict) and data["context_ir"]: + validation = validate_context_ir(data["context_ir"], strict=True) + if not validation["ok"]: + codes = ", ".join( + str(item.get("code")) + for item in validation["diagnostics"] + if item.get("level") == "error" + ) + raise ValueError(f"Capsule import rejected: invalid context_ir ({codes or 'validation failed'})") + md = _sanitize_text(md) + capsule_id = str(data.get("id") or ("ucap_" + _stable_hash(data, 20))) + data["id"] = capsule_id + data["kind"] = CAPSULE_KIND + dest = _capsule_root(repo_root) / capsule_id + dest.mkdir(parents=True, exist_ok=True) + (dest / "capsule.json").write_text(_json_dumps(data) + "\n", encoding="utf-8") + (dest / "capsule.md").write_text(md, encoding="utf-8") + entry = repo_link.add_entry( + repo_root, + kind=CAPSULE_KIND, + title=str(data.get("title") or capsule_id), + content=md, + meta={ + "capsule_id": capsule_id, + "capsule_dir": os.path.relpath(dest, repo_root), + "imported": True, + "personal_context_used": bool(data.get("personal_context_used")), + "privacy": data.get("privacy") or {}, + }, + ) + return { + "format": "dhee_update_capsule_import.v1", + "capsule": data, + "paths": {"dir": str(dest), "json": str(dest / "capsule.json"), "markdown": str(dest / "capsule.md")}, + "entry": entry.to_json(), + } diff --git a/docs/dhee-flow.svg b/docs/dhee-flow.svg new file mode 100644 index 0000000..d562902 --- /dev/null +++ b/docs/dhee-flow.svg @@ -0,0 +1,83 @@ + + Dhee flow chart + Dhee compiles evidence, supervises actions, verifies outcomes, and stores surviving lessons. + + + + + + + + + + + + + + + + + + + + + + + + How Dhee turns work into action context + + + + + INPUTS + Messy reality + files, tests, chats, + screens, agents + + + + + COMPILE + Task contract + goal, files, budget, + tests, constraints + + + + + SUPERVISE + Allowed actions + read, search, edit, + test, submit + + + + + VERIFY + Proof bundle + tests, diffs, trace, + contamination + + + + + REMEMBER + Scene cards + lessons survive, + noise expires + + + + + + + + + + + + + The agent sees a small, current, auditable packet. + Raw evidence stays behind pointers until a contract or verifier needs it. + + diff --git a/docs/dhee-hero.svg b/docs/dhee-hero.svg new file mode 100644 index 0000000..fdce260 --- /dev/null +++ b/docs/dhee-hero.svg @@ -0,0 +1,98 @@ + + Dhee hero banner + Dhee turns noisy agent context into protected action context. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + NOISY REALITY + + + + + + + + + + + + + + + + + + + PROOF + + + + + + + + + + + + + + + + + + + + Dhee + CONTEXT COMPILER + Messy work in. + Deterministic action out. + + + compile - supervise - verify - remember + + + + + + LOCAL FIRST + + PROOF BUNDLES + + UPDATE CAPSULES + + GIT SHARED + + diff --git a/pyproject.toml b/pyproject.toml index 30d7882..b42deb2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=61.0", "wheel"] +requires = ["setuptools>=77.0", "wheel"] build-backend = "setuptools.build_meta" [project] @@ -8,7 +8,8 @@ version = "6.2.0" description = "Dhee Developer Brain — local memory, handoff, and git-backed context for AI coding agents" readme = "README.md" requires-python = ">=3.9" -license = {text = "MIT"} +license = "MIT" +license-files = ["LICENSE"] authors = [ {name = "Sankhya AI Labs"} ] @@ -16,7 +17,6 @@ keywords = ["context-firewall", "developer-brain", "coding-agents", "mcp", "clau classifiers = [ "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", @@ -48,7 +48,8 @@ sqlite_vec = ["sqlite-vec>=0.1.1"] # Local Qwen stack (CPU-native, zero API cost) local = ["llama-cpp-python>=0.3", "sentence-transformers>=3.0"] # Integrations -mcp = ["mcp>=1.0.0"] +# `mcp>=1.0.0` requires Python 3.10+, while Dhee's core package supports 3.9. +mcp = ["mcp>=1.0.0; python_version >= '3.10'"] api = ["fastapi>=0.100.0", "uvicorn>=0.20.0", "python-multipart>=0.0.9", "httpx>=0.25.0", "websockets>=12.0"] # Handoff bus ships inside Dhee's wheel. Keep the extra as a no-op # compatibility alias so `pip install dhee[bus]` remains valid. @@ -67,7 +68,7 @@ all = [ "google-genai>=1.0.0", "openai>=1.0.0", "ollama>=0.4.0", - "mcp>=1.0.0", + "mcp>=1.0.0; python_version >= '3.10'", "fastapi>=0.100.0", "uvicorn>=0.20.0", "python-multipart>=0.0.9", diff --git a/tests/test_mcp_tools_slim.py b/tests/test_mcp_tools_slim.py index fa0dc34..889a376 100644 --- a/tests/test_mcp_tools_slim.py +++ b/tests/test_mcp_tools_slim.py @@ -37,6 +37,30 @@ "dhee_context_checkpoint", "dhee_context_rollover", "dhee_context_provision", + "dhee_scene_world_route", + "dhee_scene_compile", + "dhee_scene_search", + "dhee_context_pack", + "dhee_task_contract_compile", + "dhee_task_contract_create", + "dhee_task_contract_list", + "dhee_task_contract_get", + "dhee_task_contract_import", + "dhee_task_contract_interpret", + "dhee_contract_supervise_action", + "dhee_contract_record_observation", + "dhee_contract_proof_bundle", + "dhee_contract_runtime_activate", + "dhee_contract_runtime_status", + "dhee_contract_runtime_deactivate", + "dhee_contract_enforcement_set", + "dhee_contract_enforcement_status", + "dhee_contract_runtime_doctor", + "dhee_update_capsule_create", + "dhee_update_capsule_list", + "dhee_update_capsule_get", + "dhee_update_capsule_import", + "dhee_update_capsule_interpret", "dhee_tools_list", "dhee_shell", "dhee_list_assets", diff --git a/tests/test_memory_admission.py b/tests/test_memory_admission.py new file mode 100644 index 0000000..6647107 --- /dev/null +++ b/tests/test_memory_admission.py @@ -0,0 +1,197 @@ +from __future__ import annotations + +from datetime import date + +from dhee import Dhee, Engram, evaluate_memory_candidate + + +def _screen_metadata(**overrides): + metadata = { + "source": "chotu_screen_memory", + "type": "screen_activity", + "source_app": "chotu", + "confidence": 0.91, + "retention_policy": "durable", + "evidence": { + "kind": "screen_context", + "app": "Google Chrome", + "bundle_id": "com.google.Chrome", + "title": "YouTube - DreamerV3 world models tutorial", + "dwell_seconds": 95, + }, + } + evidence_overrides = overrides.pop("evidence", None) + if evidence_overrides: + metadata["evidence"].update(evidence_overrides) + metadata.update(overrides) + return metadata + + +def test_admission_rejects_noisy_passive_screen_observation(): + content = "\n".join( + [ + "Chotu observed visible screen activity.", + "App: Codex", + "Title: Codex", + "Visible text: iIDJiaiiru /MVVLlkllLIVII)/knIIVLU&PVV VIVV< IfVUI Illf LUI IfLL alValfU", + ] + ) + metadata = _screen_metadata( + evidence={ + "app": "Codex", + "bundle_id": "com.openai.codex", + "title": "Codex", + "dwell_seconds": 0, + } + ) + + decision = evaluate_memory_candidate(content, metadata) + + assert decision.applies is True + assert decision.should_store is False + assert decision.retention_policy == "ephemeral" + assert decision.skip_reason in {"low_quality_signal", "low_ocr_quality"} + + +def test_admission_marks_useful_long_dwell_screen_memory_durable(): + content = "\n".join( + [ + "Chotu observed visible screen activity.", + "App: Google Chrome", + "Title: YouTube - DreamerV3 world models tutorial", + "Visible text: The user is watching a tutorial about DreamerV3, reinforcement learning, world models, and model based agents.", + ] + ) + + decision = evaluate_memory_candidate(content, _screen_metadata()) + + assert decision.applies is True + assert decision.should_store is True + assert decision.retention_policy == "durable" + assert "interest_signal" in decision.reasons + + +def test_dhee_remember_returns_not_stored_for_rejected_passive_observation(tmp_path): + dhee = Dhee(provider="mock", in_memory=True, data_dir=str(tmp_path)) + result = dhee.remember( + "Chotu observed visible screen activity.\nApp: Codex\nTitle: Codex", + metadata=_screen_metadata( + evidence={ + "app": "Codex", + "bundle_id": "com.openai.codex", + "title": "Codex", + "dwell_seconds": 0, + } + ), + ) + + assert result["stored"] is False + assert result["event"] == "SKIP" + assert result["admission"]["should_store"] is False + assert dhee.recall("Codex") == [] + + +def test_engram_add_applies_admission_metadata_and_session_expiration(tmp_path): + memory = Engram(provider="mock", in_memory=True, data_dir=str(tmp_path)) + result = memory.add( + "Chotu observed visible screen activity.\n" + "App: Google Chrome\n" + "Title: Google Search - learn trading\n" + "Visible text: Learning to trade effectively requires mastering market mechanics and risk management.", + user_id="default", + agent_id="chotu", + source_app="chotu", + metadata=_screen_metadata( + evidence={ + "title": "Google Search - learn trading", + "dwell_seconds": 35, + } + ), + infer=False, + ) + + stored = result["results"][0] + memory_id = stored["id"] + loaded = memory.get(memory_id) + + assert stored["event"] == "ADD" + assert stored["admission"]["should_store"] is True + assert loaded["agent_id"] == "chotu" + assert loaded["source_app"] == "chotu" + assert loaded["metadata"]["dhee_admission"]["retention_policy"] == "session" + assert loaded["metadata"]["retention_policy"] == "session" + assert loaded["metadata"]["dhee_lite_path"] is True + assert loaded["metadata"]["enrichment_status"] == "pending" + assert stored["echo_depth"] is None + assert loaded["expiration_date"] is not None + assert date.fromisoformat(loaded["expiration_date"]) >= date.today() + + +def test_admission_strips_raw_ocr_when_vision_summary_is_available(tmp_path): + memory = Engram(provider="mock", in_memory=True, data_dir=str(tmp_path)) + result = memory.add( + "Chotu observed useful visible screen activity.\n" + "App: Google Chrome\n" + "Title: Home / X\n" + "Visible text excerpt:\n" + "X Hcth-iX x.comlhome ForN 8uild in Ptsblit Q Search poweror VUVAurve\n" + "Visual summary:\n" + "User is browsing the X home feed in Google Chrome and viewing posts about India's offshore exploration mission.", + user_id="default", + agent_id="chotu", + source_app="chotu", + metadata=_screen_metadata( + evidence={ + "title": "Home / X", + "dwell_seconds": 45, + "vision_summary_sha256": "abc123", + } + ), + infer=False, + ) + + loaded = memory.get(result["results"][0]["id"]) + + assert "Visual summary:" in loaded["memory"] + assert "Visible text excerpt:" not in loaded["memory"] + assert "Hcth-iX" not in loaded["memory"] + assert loaded["metadata"]["dhee_admission"]["include_ocr_excerpt"] is False + assert loaded["metadata"]["dhee_lite_path"] is True + + +def test_dhee_sweep_admission_flags_and_deletes_legacy_noise(tmp_path): + dhee = Dhee(provider="mock", in_memory=True, data_dir=str(tmp_path)) + memories = [ + { + "id": "legacy-noise", + "memory": "Chotu observed visible screen activity.\n" + "App: Codex\n" + "Title: Codex\n" + "Visible text: iIDJiaiiru /MVVLlkllLIVII)/knIIVLU&PVV", + "metadata": _screen_metadata( + evidence={ + "app": "Codex", + "bundle_id": "com.openai.codex", + "title": "Codex", + "dwell_seconds": 0, + } + ), + }, + { + "id": "useful", + "memory": "User prefers Python for backend services.", + "metadata": {"source": "manual_note"}, + }, + ] + deleted = [] + dhee._engram.get_all = lambda **_: memories + dhee._engram.delete = lambda memory_id: deleted.append(memory_id) + + dry_run = dhee.sweep_admission(dry_run=True) + applied = dhee.sweep_admission(dry_run=False) + + assert dry_run["candidate_count"] == 1 + assert dry_run["deleted_count"] == 0 + assert dry_run["candidates"][0]["id"] == "legacy-noise" + assert applied["deleted_count"] == 1 + assert deleted == ["legacy-noise"] diff --git a/tests/test_packaging.py b/tests/test_packaging.py index f59b7f5..25df39d 100644 --- a/tests/test_packaging.py +++ b/tests/test_packaging.py @@ -27,6 +27,25 @@ def test_handoff_bus_is_bundled_not_external_dependency(): assert (ROOT / "engram-bus" / "engram_bus" / "bus.py").exists() +def test_project_metadata_is_release_clean(): + pyproject = (ROOT / "pyproject.toml").read_text(encoding="utf-8") + + assert 'requires = ["setuptools>=77.0", "wheel"]' in pyproject + assert 'license = "MIT"' in pyproject + assert 'license-files = ["LICENSE"]' in pyproject + assert "license = {text" not in pyproject + assert "License :: OSI Approved :: MIT License" not in pyproject + + +def test_mcp_extra_is_python_version_honest(): + pyproject = (ROOT / "pyproject.toml").read_text(encoding="utf-8") + + assert 'mcp = ["mcp>=1.0.0; python_version >= \'3.10\'"]' in pyproject + assert '"mcp>=1.0.0; python_version >= \'3.10\'"' in pyproject + assert 'mcp = ["mcp>=1.0.0"]' not in pyproject + assert '"mcp>=1.0.0",' not in pyproject + + def test_curl_installer_verifies_handoff_bus(): installer = (ROOT / "install.sh").read_text(encoding="utf-8") diff --git a/tests/test_release_hygiene.py b/tests/test_release_hygiene.py new file mode 100644 index 0000000..162aab9 --- /dev/null +++ b/tests/test_release_hygiene.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +import json +import subprocess +import sys +from pathlib import Path + +from dhee.release_hygiene import ( + RELEASE_INTENT_REL_PATH, + load_release_intent, + release_check, + write_release_intent, +) + + +def _git(repo: Path, *args: str) -> subprocess.CompletedProcess: + return subprocess.run( + ["git", "-C", str(repo), *args], + capture_output=True, + text=True, + check=True, + ) + + +def _init_repo(tmp_path: Path) -> Path: + repo = tmp_path / "repo" + repo.mkdir() + subprocess.run(["git", "init", "-q", str(repo)], check=True) + _git(repo, "config", "user.email", "test@example.com") + _git(repo, "config", "user.name", "Test User") + (repo / "README.md").write_text("# test\n", encoding="utf-8") + _git(repo, "add", "README.md") + _git(repo, "commit", "-q", "-m", "init") + return repo + + +def _codes(report: dict) -> set[str]: + return {str(item.get("code")) for item in report.get("release_blockers") or []} + + +def test_release_check_allows_clean_repo(tmp_path: Path) -> None: + repo = _init_repo(tmp_path) + + report = release_check(repo) + + assert report["status"] == "ready" + assert report["release_allowed"] is True + assert report["git"]["clean"] is True + assert report["release_blockers"] == [] + + +def test_release_check_blocks_dirty_repo_without_intent(tmp_path: Path) -> None: + repo = _init_repo(tmp_path) + (repo / "README.md").write_text("# changed\n", encoding="utf-8") + + report = release_check(repo) + + assert report["status"] == "blocked" + assert report["release_allowed"] is False + assert "GIT_WORKTREE_DIRTY" in _codes(report) + assert "UNEXPECTED_DIRTY_PATHS" in _codes(report) + assert report["unexpected_dirty_paths"] == ["README.md"] + assert (report["warnings"] or [])[0]["code"] == "RELEASE_INTENT_MISSING" + + +def test_release_intent_documents_scope_but_does_not_allow_dirty_release(tmp_path: Path) -> None: + repo = _init_repo(tmp_path) + result = write_release_intent(repo, ["README.md"], reason="docs hardening") + assert result["ok"] is True + _git(repo, "add", RELEASE_INTENT_REL_PATH) + _git(repo, "commit", "-q", "-m", "add release intent") + + (repo / "README.md").write_text("# changed\n", encoding="utf-8") + report = release_check(repo) + + assert report["status"] == "blocked" + assert report["release_allowed"] is False + assert "GIT_WORKTREE_DIRTY" in _codes(report) + assert "UNEXPECTED_DIRTY_PATHS" not in _codes(report) + assert report["unexpected_dirty_paths"] == [] + assert report["intent"]["combined_intended_paths"] == ["README.md"] + + +def test_release_check_still_blocks_unexpected_paths_when_clean_requirement_relaxed(tmp_path: Path) -> None: + repo = _init_repo(tmp_path) + result = write_release_intent(repo, ["README.md"]) + assert result["ok"] is True + _git(repo, "add", RELEASE_INTENT_REL_PATH) + _git(repo, "commit", "-q", "-m", "add release intent") + + (repo / "README.md").write_text("# changed\n", encoding="utf-8") + (repo / "surprise.py").write_text("print('surprise')\n", encoding="utf-8") + report = release_check(repo, require_clean=False) + + assert report["status"] == "blocked" + assert "GIT_WORKTREE_DIRTY" not in _codes(report) + assert "UNEXPECTED_DIRTY_PATHS" in _codes(report) + assert report["unexpected_dirty_paths"] == ["surprise.py"] + + +def test_corrupt_release_intent_is_a_blocker(tmp_path: Path) -> None: + repo = _init_repo(tmp_path) + intent_path = repo / RELEASE_INTENT_REL_PATH + intent_path.parent.mkdir(parents=True, exist_ok=True) + intent_path.write_text("{bad json", encoding="utf-8") + + intent = load_release_intent(repo) + report = release_check(repo) + + assert intent["ok"] is False + assert "RELEASE_INTENT_UNREADABLE" in _codes(report) + assert any( + diagnostic.get("code") == "RUNTIME_JSON_CORRUPT" + for diagnostic in report["intent"]["diagnostics"] + ) + + +def test_release_cli_json_blocks_dirty_tree(tmp_path: Path) -> None: + repo = _init_repo(tmp_path) + (repo / "README.md").write_text("# changed\n", encoding="utf-8") + + proc = subprocess.run( + [ + sys.executable, + "-m", + "dhee.cli", + "release", + "check", + "--repo", + str(repo), + "--json", + ], + capture_output=True, + text=True, + check=False, + ) + + assert proc.returncode == 1 + payload = json.loads(proc.stdout) + assert payload["status"] == "blocked" + assert "GIT_WORKTREE_DIRTY" in _codes(payload) diff --git a/tests/test_runtime_hardening.py b/tests/test_runtime_hardening.py new file mode 100644 index 0000000..c799a4b --- /dev/null +++ b/tests/test_runtime_hardening.py @@ -0,0 +1,222 @@ +import subprocess +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path + +import pytest + +from dhee import repo_link +from dhee.contract_runtime import ( + ACTIVE_CONTRACT_SCHEMA, + CONTRACT_SUPERVISOR_UNAVAILABLE, + contract_enforcement_status, + contract_runtime_doctor, + contract_runtime_status, + guard_router_call, + set_contract_enforcement, +) +from dhee.mcp_registry import CONTEXT_COMPILER_TOOL_NAMES, TOOL_SPECS +from dhee.runtime_io import append_jsonl_locked, read_json_checked, read_jsonl_checked, write_json_atomic + + +def _init_repo(path: Path) -> Path: + path.mkdir(parents=True, exist_ok=True) + subprocess.run(["git", "init"], cwd=path, check=True, capture_output=True) + subprocess.run(["git", "config", "user.email", "test@example.com"], cwd=path, check=True) + subprocess.run(["git", "config", "user.name", "Test User"], cwd=path, check=True) + (path / "README.md").write_text("# test\n", encoding="utf-8") + subprocess.run(["git", "add", "README.md"], cwd=path, check=True) + subprocess.run(["git", "commit", "-m", "init"], cwd=path, check=True, capture_output=True) + return path + + +def _active_path(repo: Path) -> Path: + return repo / ".dhee" / "context" / "task_runs" / "active_contract.json" + + +def _write_minimal_active_runtime(repo: Path, task_id: str = "task_hardened") -> None: + repo_link._ensure_repo_skeleton(repo) + result = write_json_atomic( + _active_path(repo), + { + "format": ACTIVE_CONTRACT_SCHEMA, + "schema_version": ACTIVE_CONTRACT_SCHEMA, + "active": True, + "status": "active", + "task_id": task_id, + "contract_ref": task_id, + "repo": str(repo), + "strict": False, + "contract_hash": "test", + }, + ) + assert result["ok"] + + +def test_runtime_io_atomic_write_and_corrupt_quarantine(tmp_path): + path = tmp_path / "state.json" + assert write_json_atomic(path, {"schema_version": "x", "value": 1})["ok"] + assert write_json_atomic(path, {"schema_version": "x", "value": 2})["ok"] + checked = read_json_checked(path, expected_schema="x") + assert checked["ok"] + assert checked["data"]["value"] == 2 + + path.write_text("{broken", encoding="utf-8") + corrupt = read_json_checked(path, quarantine=True) + assert not corrupt["ok"] + assert corrupt["quarantine"]["ok"] + assert not path.exists() + assert Path(corrupt["quarantine"]["quarantine_path"]).exists() + + +def test_runtime_io_locked_jsonl_concurrent_appends(tmp_path): + path = tmp_path / "events.jsonl" + + def append_one(i: int) -> None: + result = append_jsonl_locked(path, {"i": i}) + assert result["ok"] + + with ThreadPoolExecutor(max_workers=8) as pool: + list(pool.map(append_one, range(64))) + + checked = read_jsonl_checked(path) + assert checked["ok"] + assert len(checked["records"]) == 64 + assert sorted(record["i"] for record in checked["records"]) == list(range(64)) + + +def test_corrupt_active_contract_is_quarantined_and_surfaced(tmp_path): + repo = _init_repo(tmp_path / "repo") + repo_link._ensure_repo_skeleton(repo) + path = _active_path(repo) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text("{broken", encoding="utf-8") + + status = contract_runtime_status(repo=repo) + assert status["status"] == "corrupt" + assert status["error"] == "ACTIVE_CONTRACT_CORRUPT" + assert any(diag["code"] == "RUNTIME_JSON_CORRUPT" for diag in status["diagnostics"]) + assert status["quarantine"]["ok"] + assert not path.exists() + + +def test_enforcement_deny_blocks_without_active_contract(tmp_path): + repo = _init_repo(tmp_path / "repo") + set_contract_enforcement("deny", repo=repo, agent_id="test") + + guard = guard_router_call("dhee_read", {"repo": str(repo), "file_path": str(repo / "README.md")}) + assert not guard["allowed"] + assert guard["error"] == "ACTIVE_CONTRACT_REQUIRED" + assert guard["enforcement"]["mode"] == "deny" + + +def test_corrupt_enforcement_policy_fails_closed(tmp_path): + repo = _init_repo(tmp_path / "repo") + policy_path = repo / ".dhee" / "context" / "task_runs" / "enforcement.json" + policy_path.parent.mkdir(parents=True, exist_ok=True) + policy_path.write_text("{broken", encoding="utf-8") + + status = contract_enforcement_status(repo=repo) + assert status["mode"] == "deny" + assert status["policy_corrupt"] is True + guard = guard_router_call("dhee_read", {"repo": str(repo), "file_path": str(repo / "README.md")}) + assert not guard["allowed"] + assert guard["error"] == "ACTIVE_CONTRACT_REQUIRED" + + +def test_enforcement_warn_allows_and_records_warning(tmp_path): + repo = _init_repo(tmp_path / "repo") + set_contract_enforcement("warn", repo=repo, agent_id="test") + + guard = guard_router_call("dhee_read", {"repo": str(repo), "file_path": str(repo / "README.md")}) + assert guard["allowed"] + assert guard["enforcement"]["mode"] == "warn" + + events = read_jsonl_checked(repo / ".dhee" / "context" / "task_runs" / "enforcement" / "runtime_events.jsonl") + assert any(record.get("event") == "enforcement_warning" for record in events["records"]) + + +def test_enforcement_off_preserves_compatibility_without_active_contract(tmp_path): + repo = _init_repo(tmp_path / "repo") + set_contract_enforcement("off", repo=repo, agent_id="test") + + guard = guard_router_call("dhee_read", {"repo": str(repo), "file_path": str(repo / "README.md")}) + assert guard["allowed"] + assert guard["error"] is None + assert guard["enforcement"]["mode"] == "off" + + +def test_env_forces_deny_even_when_policy_off(tmp_path, monkeypatch): + repo = _init_repo(tmp_path / "repo") + set_contract_enforcement("off", repo=repo, agent_id="test") + monkeypatch.setenv("DHEE_REQUIRE_ACTIVE_CONTRACT", "1") + + assert contract_enforcement_status(repo=repo)["mode"] == "deny" + guard = guard_router_call("dhee_read", {"repo": str(repo), "file_path": str(repo / "README.md")}) + assert not guard["allowed"] + assert guard["error"] == "ACTIVE_CONTRACT_REQUIRED" + + +def test_deny_blocks_when_supervisor_unavailable(tmp_path, monkeypatch): + repo = _init_repo(tmp_path / "repo") + _write_minimal_active_runtime(repo) + set_contract_enforcement("deny", repo=repo, agent_id="test") + + import dhee.contract_supervisor as contract_supervisor + + def explode(*_args, **_kwargs): + raise RuntimeError("simulated supervisor failure") + + monkeypatch.setattr(contract_supervisor, "supervise_action", explode) + guard = guard_router_call("dhee_read", {"repo": str(repo), "file_path": str(repo / "README.md")}) + assert not guard["allowed"] + assert guard["error"] == CONTRACT_SUPERVISOR_UNAVAILABLE + + +def test_contract_runtime_doctor_reports_unprotected_and_protected(tmp_path, monkeypatch): + repo = _init_repo(tmp_path / "repo") + assert contract_runtime_doctor(repo=repo)["status"] == "unprotected" + + _write_minimal_active_runtime(repo) + set_contract_enforcement("deny", repo=repo, agent_id="test") + + class RouterState: + enabled = True + managed = True + env_flag = True + allowed_tools = ["Read", "Grep", "Bash"] + settings_path = repo / "settings.json" + + from dhee.router import install as router_install + + monkeypatch.setattr(router_install, "status", lambda: RouterState()) + protected = contract_runtime_doctor(repo=repo) + assert protected["status"] == "protected" + assert protected["protected"] is True + + class MissingRouterState(RouterState): + enabled = False + + monkeypatch.setattr(router_install, "status", lambda: MissingRouterState()) + partial = contract_runtime_doctor(repo=repo) + assert partial["status"] == "partially_protected" + assert "native_hook_or_router_not_enabled" in partial["bypass_risks"] + + +def test_mcp_registry_slim_parity_for_compiler_runtime_tools(): + import dhee.mcp_slim as slim + + tools = {tool.name: tool for tool in slim.TOOLS} + for name in CONTEXT_COMPILER_TOOL_NAMES: + assert name in tools + assert name in slim.HANDLERS + assert tools[name].inputSchema == TOOL_SPECS[name]["inputSchema"] + + +def test_mcp_registry_full_parity_when_mcp_installed(): + mcp_server = pytest.importorskip("dhee.mcp_server", reason="mcp package not installed") + + tools = {tool.name: tool for tool in mcp_server.TOOLS} + for name in CONTEXT_COMPILER_TOOL_NAMES: + assert name in tools + assert name in mcp_server.HANDLERS + assert tools[name].inputSchema == TOOL_SPECS[name]["inputSchema"] diff --git a/tests/test_scene_world_hook.py b/tests/test_scene_world_hook.py new file mode 100644 index 0000000..b5de230 --- /dev/null +++ b/tests/test_scene_world_hook.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +import sys +import types + +from dhee.hooks.claude_code.renderer import render_context +from dhee.hooks.scene_world import route_task + + +def test_scene_world_route_disabled_by_default(monkeypatch): + monkeypatch.delenv("DHEE_SCENE_WORLD_ENABLED", raising=False) + monkeypatch.delenv("DHEE_SCENE_WORLD", raising=False) + + result = route_task("fix failing pytest", repo="/tmp/example") + + assert result == {"enabled": False, "status": "disabled"} + + +def test_scene_world_route_uses_lazy_adapter(monkeypatch): + monkeypatch.setenv("DHEE_SCENE_WORLD_ENABLED", "1") + monkeypatch.delenv("DHEE_SCENE_WORLD_MODEL", raising=False) + monkeypatch.delenv("SCENE_WORLD_MODEL_PATH", raising=False) + + package = types.ModuleType("sankhya_wm") + adapter = types.ModuleType("sankhya_wm.dhee_scene_world_adapter") + calls = {} + + def fake_predict_next_action(task, **kwargs): + calls["task"] = task + calls["kwargs"] = kwargs + return { + "route_id": "route-1", + "task": task, + "source": "dhee", + "era": "repo_work", + "active_project": "Dhee", + "best_action": { + "action": "inspect_repo", + "expected_reward": 0.75, + "confidence": 0.8, + "predicted_next_scene": "Agent reads source before editing.", + "likely_user_reaction": "Good. This is grounded.", + "risks": ["extra_tool_cost"], + }, + "ranked_actions": [], + "warnings": [], + } + + adapter.predict_next_action = fake_predict_next_action + monkeypatch.setitem(sys.modules, "sankhya_wm", package) + monkeypatch.setitem(sys.modules, "sankhya_wm.dhee_scene_world_adapter", adapter) + + result = route_task("fix failing pytest", repo="/tmp/example", user_id="default", top_k=2, record=True) + + assert result["status"] == "ok" + assert result["route"]["best_action"]["action"] == "inspect_repo" + assert calls["task"] == "fix failing pytest" + assert calls["kwargs"]["top_k"] == 2 + assert calls["kwargs"]["record"] is True + + +def test_renderer_includes_scene_world_block(): + xml = render_context( + {}, + task_description="fix tests", + scene_world={ + "route_id": "route-1", + "source": "dhee", + "era": "repo_work", + "active_project": "Dhee", + "_scene_world": {"harness": "claude-code"}, + "best_action": { + "action": "inspect_repo", + "expected_reward": 0.72, + "confidence": 0.81, + "predicted_next_scene": "Agent reads source before editing.", + "likely_user_reaction": "User sees grounded work, not generic talk.", + "risks": ["tool_cost"], + }, + "ranked_actions": [ + {"action": "inspect_repo", "expected_reward": 0.72, "confidence": 0.81, "risks": ["tool_cost"]}, + {"action": "answer_directly", "expected_reward": -0.2, "confidence": 0.5, "risks": ["generic"]}, + ], + "warnings": ["forecast only"], + }, + ) + + assert " Path: + path.mkdir() + _run(["git", "init"], path) + _run(["git", "config", "user.email", "dhee-test@example.com"], path) + _run(["git", "config", "user.name", "Dhee Test"], path) + (path / "dhee").mkdir() + (path / "tests").mkdir() + (path / "dhee" / "__init__.py").write_text("", encoding="utf-8") + (path / "dhee" / "context_firewall.py").write_text( + "def allow_path(path):\n return not path.startswith('.env')\n", + encoding="utf-8", + ) + (path / "tests" / "test_context_firewall.py").write_text( + "from dhee.context_firewall import allow_path\n\n" + "def test_env_is_blocked():\n" + " assert allow_path('.env') is False\n", + encoding="utf-8", + ) + _run(["git", "add", "."], path) + _run(["git", "commit", "-m", "initial"], path) + return path + + +def _edit_proof(test_command: str = "pytest tests/test_context_firewall.py") -> dict: + return { + "edit_span": {"path": "dhee/context_firewall.py", "start_line": 1, "end_line": 2}, + "invariant": "context firewall must reject .env paths", + "related_tests": [test_command], + "rollback_point": "HEAD", + } + + +def test_compile_task_contract_builds_deterministic_actionables(tmp_path): + repo = _init_repo(tmp_path / "repo") + + compiled = compile_task_contract("Fix failing context firewall tests in Dhee", repo=repo) + contract = compiled["contract"] + actions = compiled["actions"] + + assert compiled["format"] == "dhee_task_contract_compile.v1" + assert contract["schema_version"] == "dhee.task_contract.v1" + assert contract["goal"] == "Fix failing context firewall tests in Dhee" + assert contract["mode"] == "patch" + assert contract["repo"] == "repo" + assert "dhee/context_firewall.py" in contract["relevant_files"] + assert "tests/test_context_firewall.py" in contract["relevant_files"] + assert "pytest tests/test_context_firewall.py" in contract["must_run"] + assert "dhee/" in contract["allowed_write_paths"] + assert "tests/" in contract["allowed_write_paths"] + assert ".env" in contract["forbidden_paths"] + assert contract["context_budget"]["repo_context_tokens"] == 6000 + assert contract["repo_intelligence"]["schema_version"] == "dhee.repo_intelligence.v1" + assert contract["compiled_context"]["schema_version"] == "dhee.context_ledger.v1" + assert contract["compiled_context"]["items"][0]["why_included"] + assert contract["verification_card"]["schema_version"] == "dhee.verification_card.v1" + assert contract["contamination_status"]["schema_version"] == "dhee.contamination_status.v1" + assert "issue_parse" in [item["name"] for item in compiled["compiler"]["passes"]] + + assert validate_task_contract(compiled)["ok"] is True + assert compiled["actions_schema"] == "dhee.chotu_action_bytecode.v1" + assert compiled["compiler"]["schema_version"] == "dhee.contract_compiler.v1" + assert {action["type"] for action in actions} <= ACTION_TYPES + assert len({action["action_id"] for action in actions}) == len(actions) + assert actions[0]["type"] == "SEARCH_CODE" + assert any(action["type"] == "RUN_TEST" for action in actions) + assert actions[-1]["type"] == "SUBMIT_PATCH" + run_ids = {action["action_id"] for action in actions if action["type"] == "RUN_TEST"} + assert set(actions[-1]["requires"]) == run_ids + for action in actions: + for field in ("precondition", "execution", "observation", "postcondition", "memory_update"): + assert field in action + assert action["bytecode"]["schema_version"] == "dhee.chotu_action_bytecode.v1" + + +def test_compile_task_contract_accepts_explicit_must_run_and_budget(tmp_path): + repo = _init_repo(tmp_path / "repo") + + compiled = compile_task_contract( + "Fix context firewall", + repo=repo, + must_run=["pytest tests/test_context_firewall.py -q"], + context_budget={"state_card_tokens": 100, "retrieved_memory_tokens": 200, "repo_context_tokens": 300, "tool_output_tokens": 400}, + ) + + assert compiled["contract"]["must_run"] == ["pytest tests/test_context_firewall.py -q"] + assert compiled["contract"]["context_budget"]["tool_output_tokens"] == 400 + run_action = next(action for action in compiled["actions"] if action["type"] == "RUN_TEST") + assert run_action["command"] == "pytest tests/test_context_firewall.py -q" + assert run_action["precondition"] == "Dependency environment exists and command is safe for the sandbox." + + +def test_create_task_contract_writes_md_json_and_indexes_repo_context(tmp_path): + repo = _init_repo(tmp_path / "repo") + + created = create_task_contract("Fix context firewall tests", repo=repo) + task_id = created["contract"]["task_id"] + + assert created["format"] == "dhee_task_contract_create.v1" + assert Path(created["paths"]["json"]).exists() + assert Path(created["paths"]["markdown"]).exists() + assert "Task Contract" in Path(created["paths"]["markdown"]).read_text(encoding="utf-8") + listed = list_task_contracts(repo=repo) + assert [item["task_id"] for item in listed] == [task_id] + fetched = get_task_contract(task_id, repo=repo) + assert fetched["contract"]["task_id"] == task_id + entries = repo_link.list_entries(repo) + assert any(entry.kind == "task_contract" and entry.meta["task_id"] == task_id for entry in entries) + + +def test_import_and_interpret_task_contract_in_target_repo(tmp_path): + source_repo = _init_repo(tmp_path / "source") + created = create_task_contract( + "Fix context firewall tests", + repo=source_repo, + must_run=["pytest tests/test_context_firewall.py"], + ) + + target_repo = _init_repo(tmp_path / "target") + imported = import_task_contract(created["paths"]["dir"], repo=target_repo) + interpreted = interpret_task_contract(imported["contract"]["task_id"], repo=target_repo) + + assert imported["format"] == "dhee_task_contract_import.v1" + assert interpreted["format"] == "dhee.task_contract_interpretation.v1" + assert interpreted["readiness"] == "ready" + assert interpreted["policy"]["auto_execute"] is False + assert any(step["type"] == "RUN_TEST" for step in interpreted["execution_plan"]) + assert any(diag["code"] == "REPO_ID_MISMATCH" for diag in interpreted["diagnostics"]) + + +def test_interpret_task_contract_blocks_missing_required_file(tmp_path): + source_repo = _init_repo(tmp_path / "source") + created = create_task_contract("Fix context firewall tests", repo=source_repo) + + target_repo = _init_repo(tmp_path / "target") + (target_repo / "dhee" / "context_firewall.py").unlink() + interpreted = interpret_task_contract(created["paths"]["dir"], repo=target_repo) + + assert interpreted["readiness"] == "blocked" + assert any(diag["code"] == "READ_PATH_MISSING" for diag in interpreted["diagnostics"]) + + +def test_mcp_slim_task_contract_compile_handler(tmp_path): + from dhee import mcp_slim + + repo = _init_repo(tmp_path / "repo") + + result = mcp_slim.HANDLERS["dhee_task_contract_compile"]( + {"repo": str(repo), "goal": "Fix context firewall tests"} + ) + + assert result["contract"]["goal"] == "Fix context firewall tests" + assert result["validation"]["ok"] is True + assert result["actions"][0]["type"] == "SEARCH_CODE" + + +def test_mcp_slim_task_contract_create_and_interpret_handlers(tmp_path): + from dhee import mcp_slim + + repo = _init_repo(tmp_path / "repo") + + created = mcp_slim.HANDLERS["dhee_task_contract_create"]( + {"repo": str(repo), "goal": "Fix context firewall tests"} + ) + interpreted = mcp_slim.HANDLERS["dhee_task_contract_interpret"]( + {"repo": str(repo), "task_id": created["contract"]["task_id"]} + ) + + assert created["format"] == "dhee_task_contract_create.v1" + assert interpreted["readiness"] == "ready" + + +def test_contract_supervisor_allows_and_denies_actions(tmp_path): + repo = _init_repo(tmp_path / "repo") + created = create_task_contract( + "Fix context firewall tests", + repo=repo, + must_run=["pytest tests/test_context_firewall.py"], + ) + task_id = created["contract"]["task_id"] + + allowed = supervise_action( + task_id, + {"type": "RUN_TEST", "command": "pytest tests/test_context_firewall.py", "timeout_sec": 120}, + repo=repo, + ) + denied = supervise_action( + task_id, + {"type": "RUN_TEST", "command": "pytest tests/test_unrelated.py", "timeout_sec": 120}, + repo=repo, + ) + forbidden_edit = supervise_action( + task_id, + {"type": "EDIT_FILE", "path": ".env", "patch": "--- a/.env\n+++ b/.env\n"}, + repo=repo, + ) + + assert allowed["decision"] == "allow" + assert denied["decision"] == "deny" + assert denied["violations"][0]["code"] == "TEST_COMMAND_OUT_OF_CONTRACT" + assert forbidden_edit["decision"] == "deny" + assert any(item["code"] == "EDIT_PATH_FORBIDDEN" for item in forbidden_edit["violations"]) + + +def test_contract_supervisor_enforces_observation_graph(tmp_path): + repo = _init_repo(tmp_path / "repo") + created = create_task_contract( + "Fix context firewall tests", + repo=repo, + must_run=["pytest tests/test_context_firewall.py"], + ) + task_id = created["contract"]["task_id"] + + edit_before_read = supervise_action( + task_id, + { + "type": "EDIT_FILE", + "path": "dhee/context_firewall.py", + "patch": "--- a/dhee/context_firewall.py\n+++ b/dhee/context_firewall.py\n", + "proof": _edit_proof(), + }, + repo=repo, + ) + submit_before_tests = supervise_action( + task_id, + {"type": "SUBMIT_PATCH", "summary": "Fix context firewall", "tests": ["pytest tests/test_context_firewall.py"]}, + repo=repo, + ) + + assert edit_before_read["decision"] == "deny" + assert any(item["code"] == "EDIT_REQUIRES_READ_OBSERVATION" for item in edit_before_read["violations"]) + assert submit_before_tests["decision"] == "deny" + assert any(item["code"] == "SUBMIT_REQUIRES_PASSING_TESTS" for item in submit_before_tests["violations"]) + + record_observation_transition( + task_id, + {"type": "READ_FILE", "path": "dhee/context_firewall.py"}, + "Read allow_path implementation", + repo=repo, + outcome="observed", + ) + edit_after_read = supervise_action( + task_id, + { + "type": "EDIT_FILE", + "path": "dhee/context_firewall.py", + "patch": "--- a/dhee/context_firewall.py\n+++ b/dhee/context_firewall.py\n", + "proof": _edit_proof(), + }, + repo=repo, + ) + assert edit_after_read["decision"] == "allow" + + record_observation_transition( + task_id, + {"type": "RUN_TEST", "command": "pytest tests/test_context_firewall.py", "timeout_sec": 120}, + "1 passed", + repo=repo, + outcome="passed", + ) + submit_after_tests = supervise_action( + task_id, + {"type": "SUBMIT_PATCH", "summary": "Fix context firewall", "tests": ["pytest tests/test_context_firewall.py"]}, + repo=repo, + ) + + assert submit_after_tests["decision"] == "allow" + assert "pytest tests/test_context_firewall.py" in submit_after_tests["runtime_state"]["passed_tests"] + assert submit_after_tests["proof_bundle_preview"]["proof_bundle"]["verifier_result"]["status"] == "passed" + + submitted = record_observation_transition( + task_id, + {"type": "SUBMIT_PATCH", "summary": "Fix context firewall", "tests": ["pytest tests/test_context_firewall.py"]}, + "Ready to submit", + repo=repo, + outcome="submitted", + ) + proof_bundle = submitted["proof_bundle"]["proof_bundle"] + assert proof_bundle["schema_version"] == "dhee.proof_bundle.v1" + assert proof_bundle["contract_id"] == task_id + assert proof_bundle["verifier_result"]["status"] == "passed" + assert proof_bundle["tests_run"][0]["command"] == "pytest tests/test_context_firewall.py" + assert Path(submitted["proof_bundle"]["paths"]["proof_bundle"]).exists() + + +def test_contract_supervisor_requires_edit_proof_obligations(tmp_path): + repo = _init_repo(tmp_path / "repo") + created = create_task_contract( + "Fix context firewall tests", + repo=repo, + must_run=["pytest tests/test_context_firewall.py"], + ) + task_id = created["contract"]["task_id"] + record_observation_transition( + task_id, + {"type": "READ_FILE", "path": "dhee/context_firewall.py"}, + "Read failing implementation", + repo=repo, + outcome="observed", + ) + + denied = supervise_action( + task_id, + { + "type": "EDIT_FILE", + "path": "dhee/context_firewall.py", + "patch": "--- a/dhee/context_firewall.py\n+++ b/dhee/context_firewall.py\n", + }, + repo=repo, + ) + allowed = supervise_action( + task_id, + { + "type": "EDIT_FILE", + "path": "dhee/context_firewall.py", + "patch": "--- a/dhee/context_firewall.py\n+++ b/dhee/context_firewall.py\n", + "proof": _edit_proof(), + }, + repo=repo, + ) + + assert denied["decision"] == "deny" + assert any(item["code"] == "EDIT_PROOF_OBLIGATION_MISSING" for item in denied["violations"]) + assert allowed["decision"] == "allow" + + +def test_contract_supervisor_rejects_invalid_edit_span(tmp_path): + repo = _init_repo(tmp_path / "repo") + created = create_task_contract( + "Fix context firewall tests", + repo=repo, + must_run=["pytest tests/test_context_firewall.py"], + ) + task_id = created["contract"]["task_id"] + record_observation_transition( + task_id, + {"type": "READ_FILE", "path": "dhee/context_firewall.py"}, + "Read failing implementation", + repo=repo, + outcome="observed", + ) + bad_proof = _edit_proof() + bad_proof["edit_span"] = {"path": "tests/test_context_firewall.py", "start_line": 1, "end_line": 2} + + denied = supervise_action( + task_id, + { + "type": "EDIT_FILE", + "path": "dhee/context_firewall.py", + "patch": "--- a/dhee/context_firewall.py\n+++ b/dhee/context_firewall.py\n", + "proof": bad_proof, + }, + repo=repo, + ) + + assert denied["decision"] == "deny" + assert any(item["code"] == "EDIT_SPAN_INVALID" for item in denied["violations"]) + + +def test_contract_supervisor_blocks_submit_with_unrelated_changed_file(tmp_path): + repo = _init_repo(tmp_path / "repo") + created = create_task_contract( + "Fix context firewall tests", + repo=repo, + must_run=["pytest tests/test_context_firewall.py"], + ) + task_id = created["contract"]["task_id"] + record_observation_transition( + task_id, + {"type": "RUN_TEST", "command": "pytest tests/test_context_firewall.py", "timeout_sec": 120}, + "1 passed", + repo=repo, + outcome="passed", + ) + (repo / "README.md").write_text("unrelated change\n", encoding="utf-8") + + denied = supervise_action( + task_id, + {"type": "SUBMIT_PATCH", "summary": "Fix context firewall", "tests": ["pytest tests/test_context_firewall.py"]}, + repo=repo, + ) + + assert denied["decision"] == "deny" + assert any(item["code"] == "SUBMIT_CHANGED_PATH_OUT_OF_CONTRACT" for item in denied["violations"]) + assert denied["proof_bundle_preview"]["proof_bundle"]["verifier_result"]["status"] == "blocked" + + +def test_contract_supervisor_records_observation_transition(tmp_path): + repo = _init_repo(tmp_path / "repo") + created = create_task_contract( + "Fix context firewall tests", + repo=repo, + must_run=["pytest tests/test_context_firewall.py"], + ) + + result = record_observation_transition( + created["contract"]["task_id"], + {"type": "RUN_TEST", "command": "pytest tests/test_context_firewall.py", "timeout_sec": 120}, + "1 failed: allow_path returned true for .env", + repo=repo, + outcome="failed", + next_action={"type": "READ_FILE", "path": "dhee/context_firewall.py", "reason": "Inspect failing implementation"}, + ) + + assert result["format"] == "dhee_contract_observation_record.v1" + assert result["decision"]["decision"] == "allow" + assert result["next_decision"]["decision"] == "allow" + assert result["checkpoint"]["checkpoint"]["stage"] == "after_failing_test" + events_path = Path(result["paths"]["events"]) + assert events_path.exists() + assert "allow_path returned true" in events_path.read_text(encoding="utf-8") + + +def test_active_contract_runtime_refuses_router_calls_and_records_observations(tmp_path): + from dhee.router.handlers import handle_dhee_bash, handle_dhee_grep, handle_dhee_read + + repo = _init_repo(tmp_path / "repo") + created = create_task_contract( + "Fix context firewall tests", + repo=repo, + must_run=["python3 -m pytest tests/test_context_firewall.py"], + ) + task_id = created["contract"]["task_id"] + + activated = activate_contract_runtime(task_id, repo=repo, agent_id="test", harness="pytest") + assert activated["active"] is True + assert contract_runtime_status(repo=repo)["task_id"] == task_id + + allowed_read = handle_dhee_read({"file_path": str(repo / "dhee" / "context_firewall.py")}) + denied_read = handle_dhee_read({"file_path": str(repo / "README.md")}) + denied_search = handle_dhee_grep({"repo": str(repo), "path": str(repo), "pattern": "totally_unrelated_symbol"}) + denied_bash = handle_dhee_bash({"cwd": str(repo), "command": "python3 -m pytest tests/test_unrelated.py", "timeout": 30}) + + assert allowed_read["contract_runtime"]["decision"] == "allow" + assert allowed_read["contract_runtime"]["observation"]["outcome"] == "observed" + assert denied_read["format"] == "dhee.contract_tool_refusal.v1" + assert denied_read["will_execute"] is False + assert "READ_PATH_OUT_OF_CONTRACT" in denied_read["violation_codes"] + assert denied_search["format"] == "dhee.contract_tool_refusal.v1" + assert "SEARCH_QUERY_OUT_OF_CONTRACT" in denied_search["violation_codes"] + assert denied_bash["format"] == "dhee.contract_tool_refusal.v1" + assert "TEST_COMMAND_OUT_OF_CONTRACT" in denied_bash["violation_codes"] + + edit_after_router_read = supervise_action( + task_id, + { + "type": "EDIT_FILE", + "path": "dhee/context_firewall.py", + "patch": "--- a/dhee/context_firewall.py\n+++ b/dhee/context_firewall.py\n", + "proof": _edit_proof("python3 -m pytest tests/test_context_firewall.py"), + }, + repo=repo, + ) + assert edit_after_router_read["decision"] == "allow" + + passed_test = handle_dhee_bash({"cwd": str(repo), "command": "python3 -m pytest tests/test_context_firewall.py", "timeout": 60}) + assert passed_test["exit_code"] == 0 + assert passed_test["contract_runtime"]["observation"]["outcome"] == "passed" + + submit = supervise_action( + task_id, + {"type": "SUBMIT_PATCH", "summary": "Fix context firewall", "tests": ["python3 -m pytest tests/test_context_firewall.py"]}, + repo=repo, + ) + assert submit["decision"] == "allow" + + deactivated = deactivate_contract_runtime(repo=repo, agent_id="test", reason="test complete") + assert deactivated["active"] is False + + +def test_pre_tool_gate_blocks_native_edit_under_active_contract(tmp_path): + from dhee.router.pre_tool_gate import evaluate + + repo = _init_repo(tmp_path / "repo") + created = create_task_contract( + "Fix context firewall tests", + repo=repo, + must_run=["pytest tests/test_context_firewall.py"], + ) + activate_contract_runtime(created["contract"]["task_id"], repo=repo, agent_id="test") + + denied = evaluate( + { + "tool_name": "Edit", + "tool_input": { + "file_path": str(repo / "dhee" / "context_firewall.py"), + "old_string": "return True", + "new_string": "return not path.startswith('.env')", + }, + } + ) + + assert denied["permissionDecision"] == "deny" + assert "EDIT_PROOF_OBLIGATION_MISSING" in denied["additionalContext"] + + +def test_router_can_require_active_contract_for_coding_actions(tmp_path, monkeypatch): + from dhee.router.handlers import handle_dhee_read + + repo = _init_repo(tmp_path / "repo") + monkeypatch.setenv("DHEE_REQUIRE_ACTIVE_CONTRACT", "1") + + denied = handle_dhee_read({"file_path": str(repo / "dhee" / "context_firewall.py")}) + + assert denied["format"] == "dhee.contract_tool_refusal.v1" + assert "ACTIVE_CONTRACT_REQUIRED" in denied["violation_codes"] + + +def test_mcp_slim_contract_supervisor_handlers(tmp_path): + from dhee import mcp_slim + + repo = _init_repo(tmp_path / "repo") + created = create_task_contract( + "Fix context firewall tests", + repo=repo, + must_run=["pytest tests/test_context_firewall.py"], + ) + + decision = mcp_slim.HANDLERS["dhee_contract_supervise_action"]( + { + "repo": str(repo), + "task_id": created["contract"]["task_id"], + "action": {"type": "RUN_TEST", "command": "pytest tests/test_context_firewall.py", "timeout_sec": 120}, + } + ) + event = mcp_slim.HANDLERS["dhee_contract_record_observation"]( + { + "repo": str(repo), + "task_id": created["contract"]["task_id"], + "action": {"type": "RUN_TEST", "command": "pytest tests/test_context_firewall.py", "timeout_sec": 120}, + "observation": "passed", + "outcome": "passed", + } + ) + proof = mcp_slim.HANDLERS["dhee_contract_proof_bundle"]( + {"repo": str(repo), "task_id": created["contract"]["task_id"], "persist": False} + ) + runtime = mcp_slim.HANDLERS["dhee_contract_runtime_activate"]( + {"repo": str(repo), "task_id": created["contract"]["task_id"], "agent_id": "pytest"} + ) + runtime_status = mcp_slim.HANDLERS["dhee_contract_runtime_status"]({"repo": str(repo)}) + runtime_deactivated = mcp_slim.HANDLERS["dhee_contract_runtime_deactivate"]( + {"repo": str(repo), "reason": "test complete"} + ) + + assert decision["decision"] == "allow" + assert event["event"]["outcome"] == "passed" + assert proof["proof_bundle"]["verifier_result"]["status"] == "passed" + assert proof["paths"] == {} + assert runtime["active"] is True + assert runtime_status["task_id"] == created["contract"]["task_id"] + assert runtime_deactivated["active"] is False + + +def test_cli_context_task_parser_accepts_compile_goal(): + from dhee.cli import build_parser + + args = build_parser().parse_args( + ["context", "task", "compile", "Fix context firewall", "--repo", ".", "--must-run", "pytest tests/test_context_firewall.py", "--json"] + ) + + assert args.context_action == "task" + assert args.entry_id == "compile" + assert args.context_args == ["Fix context firewall"] + assert args.must_run == ["pytest tests/test_context_firewall.py"] + + +def test_cli_context_task_parser_accepts_interpret(): + from dhee.cli import build_parser + + args = build_parser().parse_args( + ["context", "task", "interpret", "task_123", "--repo", ".", "--strict", "--json"] + ) + + assert args.context_action == "task" + assert args.entry_id == "interpret" + assert args.context_args == ["task_123"] + assert args.strict is True + + +def test_cli_context_task_parser_accepts_supervise_action_json(): + from dhee.cli import build_parser + + args = build_parser().parse_args( + [ + "context", + "task", + "supervise", + "task_123", + "--repo", + ".", + "--action-json", + '{"type":"RUN_TEST","command":"pytest tests/test_context_firewall.py","timeout_sec":120}', + "--json", + ] + ) + + assert args.context_action == "task" + assert args.entry_id == "supervise" + assert args.context_args == ["task_123"] + assert "RUN_TEST" in args.action_json + + +def test_cli_context_task_parser_accepts_runtime_activation(): + from dhee.cli import build_parser + + args = build_parser().parse_args( + ["context", "task", "activate", "task_123", "--repo", ".", "--strict", "--force", "--json"] + ) + + assert args.context_action == "task" + assert args.entry_id == "activate" + assert args.context_args == ["task_123"] + assert args.strict is True + assert args.force is True + + +def test_cli_context_task_parser_accepts_proof_bundle(): + from dhee.cli import build_parser + + args = build_parser().parse_args( + ["context", "task", "proof", "task_123", "--repo", ".", "--dry-run", "--json"] + ) + + assert args.context_action == "task" + assert args.entry_id == "proof" + assert args.context_args == ["task_123"] + assert args.dry_run is True + + +def test_cli_context_task_proof_dispatches_to_proof_bundle(monkeypatch, capsys): + from dhee.cli import build_parser, cmd_context + import dhee.contract_supervisor as contract_supervisor + + called = {} + + def fake_build_proof_bundle(task_contract, *, repo=None, strict=False, persist=True): + called["task_contract"] = task_contract + called["repo"] = repo + called["persist"] = persist + return { + "format": "dhee_contract_proof_bundle.v1", + "proof_bundle": { + "contract_id": task_contract, + "verifier_result": { + "status": "passed", + "passed_tests": [], + "required_tests": [], + "out_of_contract_changed_paths": [], + "forbidden_changed_paths": [], + }, + }, + "paths": {}, + } + + monkeypatch.setattr(contract_supervisor, "build_proof_bundle", fake_build_proof_bundle) + args = build_parser().parse_args( + ["context", "task", "proof", "task_123", "--repo", ".", "--dry-run"] + ) + cmd_context(args) + + assert called == {"task_contract": "task_123", "repo": ".", "persist": False} + assert "Proof bundle task_123" in capsys.readouterr().out diff --git a/tests/test_temporal_scenes.py b/tests/test_temporal_scenes.py new file mode 100644 index 0000000..5358fc5 --- /dev/null +++ b/tests/test_temporal_scenes.py @@ -0,0 +1,324 @@ +from dhee.temporal_scenes import ( + PromotionGate, + build_context_pack, + collect_live_scene_sources, + collect_scene_evidence, + compile_scene, + compile_scene_from_sources, + search_scenes, +) + + +def test_scene_compile_preserves_multi_agent_provenance_and_cards_are_pointer_safe(tmp_path): + evidence = [ + { + "id": "mem-chotu-1", + "memory": "Watched a Kimi CLI agent walkthrough about compact agent adapters and replayable update recipes.", + "user_id": "u1", + "agent_id": "chotu", + "source_app": "chotu-browser", + "source_event_id": "video-42", + "memory_type": "world_memory", + "confidentiality_scope": "personal", + "metadata": {"url": "https://example.test/kimi"}, + "categories": ["kimi", "agent-adapter"], + }, + { + "id": "codex-run-7", + "content": "Codex found that repo updates should travel as capsules with before and after behavior, hashes, and tests.", + "user_id": "u1", + "agent_id": "codex", + "source_app": "codex", + "source_event_id": "run-7", + "run_id": "run-7", + "memory_type": "session_digest", + "confidentiality_scope": "personal", + }, + ] + + scene = compile_scene( + evidence, + user_id="u1", + repo="/tmp/repo", + task="Implement Kimi CLI adapter update capsules", + store_dir=tmp_path, + ) + + assert scene.provenance["agent_ids"] == ["chotu", "codex"] + assert scene.provenance["source_apps"] == ["chotu-browser", "codex"] + assert "text" in scene.modalities + assert scene.tier in {"hot", "warm", "cold"} + + card = scene.to_card() + assert "evidence_refs" in card + assert "snippet" not in card["evidence_refs"][0] + assert "url" not in card["evidence_refs"][0] + + hits = search_scenes("Kimi adapter capsules", user_id="u1", repo="/tmp/repo", store_dir=tmp_path) + assert hits and hits[0].id == scene.id + + +def test_context_pack_obeys_budget_and_never_includes_raw_evidence_fields(tmp_path): + compile_scene( + [ + { + "id": "mem-1", + "memory": "A long transcript derivative about token budgets, scene cards, and pointer-only media expansion.", + "agent_id": "chotu", + "source_app": "wearable-transcript", + "source_event_id": "audio-1", + "memory_type": "transcript_chunk", + "modality": "audio", + "confidentiality_scope": "personal", + } + ], + user_id="u1", + task="Use wearable transcript memory for repo task context", + store_dir=tmp_path, + ) + + pack = build_context_pack( + "token budgets transcript scene cards", + user_id="u1", + token_budget=300, + store_dir=tmp_path, + ) + + assert pack["estimated_tokens"] <= 300 + assert pack["raw_media_included"] is False + assert pack["full_diffs_included"] is False + serialized = str(pack) + assert "snippet" not in serialized + assert "transcript_chunk" in serialized + + +def test_promotion_gate_redacts_personal_scene_evidence(tmp_path): + scene = compile_scene( + [ + { + "id": "private-1", + "memory": "Personal browsing note from /Users/alice/private/topic.txt about adapter design.", + "agent_id": "chotu", + "source_app": "browser", + "source_event_id": "evt-1", + "confidentiality_scope": "personal", + "uri": "/Users/alice/private/topic.txt", + } + ], + user_id="u1", + task="Adapter design", + store_dir=tmp_path, + ) + + safe = PromotionGate().sanitize_scene(scene) + assert safe["personal_context_used"] is True + assert "" in safe["summary"] + assert "snippet" not in str(safe["evidence_refs"]) + assert "/Users/alice" not in str(safe) + + +def test_mcp_slim_scene_handlers_compile_search_and_pack(tmp_path): + from dhee import mcp_slim + + compile_result = mcp_slim.HANDLERS["dhee_scene_compile"]( + { + "evidence": [ + { + "id": "mem-mcp", + "memory": "MCP scene handler captures a compact adapter lesson for future coding agents.", + "agent_id": "codex", + "source_app": "codex", + "confidentiality_scope": "personal", + } + ], + "task": "adapter lesson", + "user_id": "u1", + "store_dir": str(tmp_path), + } + ) + assert compile_result["format"] == "dhee_scene_compile.v1" + + search_result = mcp_slim.HANDLERS["dhee_scene_search"]( + {"query": "adapter lesson", "user_id": "u1", "store_dir": str(tmp_path)} + ) + assert search_result["results"] + + pack = mcp_slim.HANDLERS["dhee_context_pack"]( + {"query": "adapter lesson", "user_id": "u1", "store_dir": str(tmp_path), "token_budget": 300} + ) + assert pack["format"] == "dhee_context_pack.v1" + assert pack["estimated_tokens"] <= 300 + + +def test_collect_scene_evidence_from_repo_context_session_shared_task_and_artifacts(tmp_path): + from dhee import repo_link + + repo = tmp_path / "repo" + repo.mkdir() + repo_link._ensure_repo_skeleton(repo) + repo_link.add_entry( + repo, + kind="decision", + title="Use update capsules", + content="Share adapter updates as sanitized capsule recipes with hashes and tests.", + ) + session = { + "id": "sess-1", + "agent_id": "codex", + "task_summary": "Implemented temporal scenes", + "decisions_made": ["Use pointer-backed evidence cards"], + "files_touched": ["dhee/temporal_scenes.py"], + } + shared_results = { + "results": [ + { + "id": "packet-1", + "packet_kind": "native_bash", + "tool_name": "pytest", + "digest": "pytest passed for scene packs", + "agent_id": "codex", + "harness": "codex", + } + ] + } + artifacts = [ + { + "artifact_id": "artifact-1", + "filename": "README.md", + "summary": "Documentation explains capsule import workflow.", + } + ] + + evidence = collect_scene_evidence( + repo=repo, + session=session, + shared_task_results=shared_results, + artifacts=artifacts, + sources=["repo_context", "session", "shared_task_results", "artifacts"], + limit=10, + ) + + assert {row["kind"] for row in evidence} >= {"repo_context:decision", "session_digest", "native_bash", "artifact"} + scene = compile_scene_from_sources( + repo=repo, + session=session, + shared_task_results=shared_results, + artifacts=artifacts, + sources=["repo_context", "session", "shared_task_results", "artifacts"], + user_id="u1", + query="capsule adapter scenes", + store_dir=tmp_path / "scenes", + ) + assert "dhee-repo-context" in scene.provenance["source_apps"] + assert "codex" in scene.provenance["agent_ids"] + + +def test_mcp_slim_scene_compile_can_collect_repo_context(tmp_path): + from dhee import mcp_slim, repo_link + + repo = tmp_path / "repo" + repo.mkdir() + repo_link._ensure_repo_skeleton(repo) + repo_link.add_entry( + repo, + kind="learning", + title="Personal memory bridge", + content="Coding agents should receive compact scene cards derived from relevant personal observations.", + ) + + result = mcp_slim.HANDLERS["dhee_scene_compile"]( + { + "repo": str(repo), + "query": "personal observations coding agents", + "include_repo_context": True, + "user_id": "u1", + "store_dir": str(tmp_path / "scenes"), + } + ) + + assert result["format"] == "dhee_scene_compile.v1" + assert result["scene"]["provenance"]["source_apps"] == ["dhee-repo-context"] + + +class _FakeLiveSceneDB: + def __init__(self, repo): + self.repo = str(repo) + + def list_shared_tasks(self, user_id="default", status="active", repo=None, limit=50): + return [ + { + "id": "task-1", + "repo": self.repo, + "workspace_id": self.repo, + "folder_path": ".", + "title": "live task", + "status": "active", + "metadata": {}, + } + ] + + def list_shared_task_results(self, shared_task_id, limit=5, **_kwargs): + return [ + { + "id": "packet-1", + "packet_kind": "native_bash", + "tool_name": "pytest", + "digest": "Live shared task result says scene packs passed.", + "harness": "codex", + "agent_id": "codex", + "metadata": {"command": "pytest"}, + } + ][:limit] + + def list_artifacts(self, **_kwargs): + return [ + { + "artifact_id": "artifact-1", + "filename": "capsule-notes.md", + "source_path": f"{self.repo}/capsule-notes.md", + "lifecycle_state": "attached", + } + ] + + +def test_collect_live_scene_sources_reads_bounded_shared_task_and_artifacts(tmp_path): + repo = tmp_path / "repo" + repo.mkdir() + live = collect_live_scene_sources( + db=_FakeLiveSceneDB(repo), + repo=repo, + user_id="u1", + include_session=False, + include_shared_task_results=True, + include_artifacts=True, + limit=5, + ) + + evidence = collect_scene_evidence( + repo=repo, + shared_task_results=live["shared_task_results"], + artifacts=live["artifacts"], + sources=["shared_task_results", "artifacts"], + ) + assert {row["kind"] for row in evidence} == {"native_bash", "artifact"} + + +def test_mcp_slim_scene_compile_can_collect_live_sources(tmp_path, monkeypatch): + from dhee import mcp_slim + + repo = tmp_path / "repo" + repo.mkdir() + monkeypatch.setattr(mcp_slim, "_get_db", lambda: _FakeLiveSceneDB(repo)) + + result = mcp_slim.HANDLERS["dhee_scene_compile"]( + { + "repo": str(repo), + "query": "live shared task artifact scene", + "include_live_sources": True, + "user_id": "u1", + "store_dir": str(tmp_path / "scenes"), + } + ) + + assert result["format"] == "dhee_scene_compile.v1" + assert {"codex", "dhee-artifact"}.issubset(set(result["scene"]["provenance"]["source_apps"])) diff --git a/tests/test_update_capsules.py b/tests/test_update_capsules.py new file mode 100644 index 0000000..65aedc2 --- /dev/null +++ b/tests/test_update_capsules.py @@ -0,0 +1,235 @@ +import json +import subprocess +from pathlib import Path + +import pytest + +from dhee import repo_link +from dhee.update_capsules import ( + create_update_capsule, + get_update_capsule, + import_update_capsule, + interpret_update_capsule, + list_update_capsules, +) + + +def _run(args, cwd): + subprocess.run(args, cwd=cwd, check=True, text=True, capture_output=True) + + +def _init_repo(path: Path) -> Path: + path.mkdir() + _run(["git", "init"], path) + _run(["git", "config", "user.email", "dhee-test@example.com"], path) + _run(["git", "config", "user.name", "Dhee Test"], path) + (path / "app.py").write_text("def feature():\n return 'before'\n", encoding="utf-8") + _run(["git", "add", "app.py"], path) + _run(["git", "commit", "-m", "initial"], path) + return path + + +def test_create_update_capsule_writes_md_json_indexes_repo_context_and_redacts(tmp_path): + repo = _init_repo(tmp_path / "repo") + secret = "sk-" + ("a" * 32) + (repo / "app.py").write_text( + "def feature():\n" + " token = '" + secret + "'\n" + " return 'after'\n", + encoding="utf-8", + ) + (repo / "adapter.md").write_text("new adapter behavior\n", encoding="utf-8") + + result = create_update_capsule( + repo=repo, + since="HEAD", + task_id="task-123", + evidence=[ + { + "kind": "temporal_scene", + "ref": "scene_abc", + "label": "Private browsing context from /Users/alice/notes.txt", + "agent_id": "chotu", + "source_app": "chotu", + "confidentiality_scope": "personal", + } + ], + ) + + capsule = result["capsule"] + paths = result["paths"] + assert capsule["kind"] == "update_capsule" + assert capsule["personal_context_used"] is True + assert capsule["privacy"]["raw_personal_memory_included"] is False + assert Path(paths["json"]).exists() + assert Path(paths["markdown"]).exists() + + data = json.loads(Path(paths["json"]).read_text(encoding="utf-8")) + assert data["context_ir"]["schema_version"] == "dhee.context_ir.v1" + assert data["context_ir"]["symbol_table"]["files"] + assert data["context_ir"]["operations"] + assert data["base_file_hashes"]["app.py"] + changed_paths = {item["path"] for item in data["changed_paths"]} + assert {"app.py", "adapter.md"}.issubset(changed_paths) + assert data["file_hashes"]["app.py"] + serialized = json.dumps(data) + markdown = Path(paths["markdown"]).read_text(encoding="utf-8") + assert secret not in serialized + assert secret not in markdown + assert "/Users/alice" not in serialized + assert "diff --git" in serialized + + entries = repo_link.list_entries(repo) + capsule_entries = [entry for entry in entries if entry.kind == "update_capsule"] + assert capsule_entries + assert capsule_entries[-1].meta["capsule_id"] == capsule["id"] + + +def test_capsule_import_into_clean_repo_lists_and_gets_capsule(tmp_path): + source_repo = _init_repo(tmp_path / "source") + (source_repo / "app.py").write_text("def feature():\n return 'after'\n", encoding="utf-8") + created = create_update_capsule(repo=source_repo, since="HEAD", task_id="task-import") + source_dir = Path(created["paths"]["dir"]) + + target_repo = _init_repo(tmp_path / "target") + imported = import_update_capsule(source_dir, repo=target_repo) + + assert imported["capsule"]["id"] == created["capsule"]["id"] + listed = list_update_capsules(repo=target_repo) + assert [item["id"] for item in listed] == [created["capsule"]["id"]] + fetched = get_update_capsule(created["capsule"]["id"], repo=target_repo) + assert fetched["capsule"]["id"] == created["capsule"]["id"] + assert "Reproduction Guide" in fetched["markdown"] + assert "Context IR" in fetched["markdown"] + + +def test_capsule_import_rejects_raw_private_memory_marker(tmp_path): + repo = _init_repo(tmp_path / "repo") + capsule_dir = tmp_path / "bad_capsule" + capsule_dir.mkdir() + (capsule_dir / "capsule.json").write_text( + json.dumps( + { + "id": "ucap_private", + "title": "bad", + "privacy": {"raw_personal_memory_included": True}, + } + ), + encoding="utf-8", + ) + (capsule_dir / "capsule.md").write_text("private body", encoding="utf-8") + + with pytest.raises(ValueError): + import_update_capsule(capsule_dir, repo=repo) + + +def test_update_capsule_interpreter_reports_ready_applied_and_conflict(tmp_path): + source_repo = _init_repo(tmp_path / "source") + after = "def feature():\n return 'after'\n" + (source_repo / "app.py").write_text(after, encoding="utf-8") + created = create_update_capsule(repo=source_repo, since="HEAD", task_id="interp") + source_dir = Path(created["paths"]["dir"]) + + target_repo = _init_repo(tmp_path / "target") + ready = interpret_update_capsule(source_dir, repo=target_repo) + assert ready["format"] == "dhee.context_interpretation.v1" + assert ready["readiness"] == "ready" + assert ready["execution_plan"][0]["action"] == "modify_file" + assert ready["policy"]["auto_apply"] is False + + (target_repo / "app.py").write_text(after, encoding="utf-8") + applied = interpret_update_capsule(source_dir, repo=target_repo) + assert applied["readiness"] == "already_applied" + + (target_repo / "app.py").write_text("def feature():\n return 'other'\n", encoding="utf-8") + conflict = interpret_update_capsule(source_dir, repo=target_repo) + assert conflict["readiness"] == "conflict" + assert any(diag["code"] == "PRECONDITION_MISMATCH" for diag in conflict["diagnostics"]) + + +def test_update_capsule_interpreter_resolves_moved_target_by_hash(tmp_path): + source_repo = _init_repo(tmp_path / "source") + (source_repo / "app.py").write_text("def feature():\n return 'after'\n", encoding="utf-8") + created = create_update_capsule(repo=source_repo, since="HEAD", task_id="moved-target") + source_dir = Path(created["paths"]["dir"]) + + target_repo = _init_repo(tmp_path / "target") + (target_repo / "src").mkdir() + (target_repo / "src" / "app.py").write_text( + "def feature():\n return 'before'\n", + encoding="utf-8", + ) + (target_repo / "app.py").unlink() + + interpreted = interpret_update_capsule(source_dir, repo=target_repo) + + assert interpreted["readiness"] == "ready" + state = interpreted["operation_states"][0] + assert state["path"] == "app.py" + assert state["resolved_path"] == "src/app.py" + assert state["resolution"] == "moved_before_hash_match" + assert interpreted["execution_plan"][0]["resolved_path"] == "src/app.py" + + +def test_capsule_import_rejects_invalid_context_ir(tmp_path): + repo = _init_repo(tmp_path / "repo") + capsule_dir = tmp_path / "bad_ir_capsule" + capsule_dir.mkdir() + (capsule_dir / "capsule.json").write_text( + json.dumps( + { + "id": "ucap_bad_ir", + "title": "bad ir", + "privacy": {"raw_personal_memory_included": False}, + "context_ir": { + "schema_version": "dhee.context_ir.v1", + "symbol_table": {"files": []}, + "operations": [], + }, + } + ), + encoding="utf-8", + ) + (capsule_dir / "capsule.md").write_text("bad ir body", encoding="utf-8") + + with pytest.raises(ValueError, match="invalid context_ir"): + import_update_capsule(capsule_dir, repo=repo) + + +def test_mcp_slim_capsule_handlers_create_list_get(tmp_path): + from dhee import mcp_slim + + repo = _init_repo(tmp_path / "repo") + (repo / "app.py").write_text("def feature():\n return 'after via mcp'\n", encoding="utf-8") + + created = mcp_slim.HANDLERS["dhee_update_capsule_create"]( + {"repo": str(repo), "since": "HEAD", "task_id": "mcp-task"} + ) + capsule_id = created["capsule"]["id"] + + listed = mcp_slim.HANDLERS["dhee_update_capsule_list"]({"repo": str(repo)}) + assert [item["id"] for item in listed["results"]] == [capsule_id] + + fetched = mcp_slim.HANDLERS["dhee_update_capsule_get"]( + {"repo": str(repo), "capsule_id": capsule_id} + ) + assert fetched["capsule"]["id"] == capsule_id + + interpreted = mcp_slim.HANDLERS["dhee_update_capsule_interpret"]( + {"repo": str(repo), "capsule_id": capsule_id} + ) + assert interpreted["format"] == "dhee.context_interpretation.v1" + assert interpreted["readiness"] in {"ready", "already_applied", "conflict"} + + +def test_cli_context_capsule_parser_accepts_nested_subcommands(): + from dhee.cli import build_parser + + args = build_parser().parse_args( + ["context", "capsule", "interpret", "ucap_123", "--repo", ".", "--strict", "--json"] + ) + assert args.context_action == "capsule" + assert args.entry_id == "interpret" + assert args.context_args == ["ucap_123"] + assert args.repo == "." + assert args.strict is True