From 0576923cf8a529c54fdb41d30749b6c054bd58dd Mon Sep 17 00:00:00 2001 From: Jayson Grace Date: Wed, 29 Apr 2026 14:39:29 -0600 Subject: [PATCH 1/5] feat: migrate task and tool queues from redis to nats jetstream broker **Added:** - Introduced NATS JetStream as the primary broker for all task, tool, and result queues across orchestrator and worker agents - Implemented new `ares-core/src/nats.rs` with subject taxonomy, stream definitions, and broker abstraction for JetStream - Added Ansible `nats` role to provision, configure, and manage NATS JetStream server (with docs and systemd integration) - Updated EC2 setup scripts to install, configure, and manage `nats-server` alongside Redis - Added NATS connection URLs to environment, templates, and container docs - Documented NATS deployment, configuration, and usage in infra and agent docs **Changed:** - All core task, tool, and blue-team queues now use NATS JetStream subjects, replacing Redis List/BRPOP patterns for durable work queues - Orchestrator, workers, and blue agents now require and connect to both Redis (state) and NATS (queues) - Tool dispatch and result collection now use NATS request/reply with inbox subjects, removing the need for dedicated TCP connections for blocking calls - Blue investigation queue and results moved to NATS JetStream subjects - Task status, heartbeats, operation locks, and persistent state remain on Redis - Updated orchestrator, worker, and tool-executor modules to poll and publish via JetStream consumers/producers with explicit acks and bounded redelivery - Refactored orchestrator config, state, and queue code to thread NATS broker handles throughout and ensure streams on startup - Updated all container and agent documentation to mention NATS as required infra - Updated Ansible playbooks and role templates to deploy NATS and wire up environment variables for all agents - Updated diagrams, markdown, and infrastructure docs to show NATS as the broker - Updated Cargo manifests to include 
`async-nats`, `futures`, and `bytes` dependencies in all crates **Removed:** - Redis-backed work queue code paths, including BRPOP/LPUSH for tasks and tools - Obsolete Redis-only queue length and result-polling implementations - Legacy Redis-only tool dispatcher and result handler logic - All Redis pubsub notification usage for state updates (now NATS core pub) - Unused Redis key prefix constants and result queue definitions in code and docs --- .taskfiles/ec2/Taskfile.yaml | 3 + .../ec2/scripts/launch-orchestrator.sh.tmpl | 1 + .taskfiles/ec2/scripts/setup.sh | 83 ++- .taskfiles/ec2/scripts/status.sh | 8 + AGENTS.md | 4 +- Cargo.lock | 287 +++++++- Cargo.toml | 3 + ansible/README.md | 7 +- ansible/playbooks/ares/goad_attack_box.yml | 7 +- ansible/roles/nats/README.md | 80 +++ ansible/roles/nats/defaults/main.yml | 25 + ansible/roles/nats/handlers/main.yml | 11 + ansible/roles/nats/meta/main.yml | 24 + ansible/roles/nats/tasks/linux.yml | 116 +++ ansible/roles/nats/tasks/main.yml | 4 + .../roles/nats/templates/nats-server.conf.j2 | 25 + .../nats/templates/nats-server.service.j2 | 20 + .../roles/redis/templates/ares@.service.j2 | 5 +- ares-cli/Cargo.toml | 3 + ares-cli/src/blue/delete.rs | 28 +- ares-cli/src/blue/submit.rs | 28 +- ares-cli/src/orchestrator/blue/auto_submit.rs | 15 +- ares-cli/src/orchestrator/blue/runner.rs | 17 +- ares-cli/src/orchestrator/completion.rs | 34 +- ares-cli/src/orchestrator/config.rs | 7 + ares-cli/src/orchestrator/deferred.rs | 6 + ares-cli/src/orchestrator/mod.rs | 25 +- ares-cli/src/orchestrator/monitoring.rs | 2 +- ares-cli/src/orchestrator/recovery/manager.rs | 10 +- ares-cli/src/orchestrator/task_queue.rs | 666 ++++++------------ ares-cli/src/orchestrator/throttling.rs | 1 + .../src/orchestrator/tool_dispatcher/mod.rs | 23 +- .../tool_dispatcher/redis_dispatcher.rs | 160 +++-- ares-cli/src/worker/blue_task_loop.rs | 5 +- ares-cli/src/worker/config.rs | 6 + ares-cli/src/worker/mod.rs | 23 +- ares-cli/src/worker/task_loop/mod.rs | 
160 +++-- .../src/worker/task_loop/result_handler.rs | 57 +- ares-cli/src/worker/tool_executor.rs | 347 +++------ ares-core/Cargo.toml | 3 + ares-core/src/lib.rs | 1 + ares-core/src/nats.rs | 369 ++++++++++ ares-core/src/state/blue_task_queue.rs | 308 +++++--- ares-core/src/state/operations.rs | 33 +- docs/blue.md | 2 +- docs/infrastructure.md | 14 +- docs/red.md | 37 +- .../templates/ares-blue-agent/README.md | 1 + .../ares-blue-lateral-analyst-agent/README.md | 1 + .../ares-blue-threat-hunter-agent/README.md | 1 + .../ares-blue-triage-agent/README.md | 1 + .../templates/ares-orchestrator/README.md | 22 +- .../templates/ares-worker/README.md | 21 +- 53 files changed, 2062 insertions(+), 1088 deletions(-) create mode 100644 ansible/roles/nats/README.md create mode 100644 ansible/roles/nats/defaults/main.yml create mode 100644 ansible/roles/nats/handlers/main.yml create mode 100644 ansible/roles/nats/meta/main.yml create mode 100644 ansible/roles/nats/tasks/linux.yml create mode 100644 ansible/roles/nats/tasks/main.yml create mode 100644 ansible/roles/nats/templates/nats-server.conf.j2 create mode 100644 ansible/roles/nats/templates/nats-server.service.j2 create mode 100644 ares-core/src/nats.rs diff --git a/.taskfiles/ec2/Taskfile.yaml b/.taskfiles/ec2/Taskfile.yaml index bbe3514b..c33dffb5 100644 --- a/.taskfiles/ec2/Taskfile.yaml +++ b/.taskfiles/ec2/Taskfile.yaml @@ -552,6 +552,7 @@ tasks: PARAMS_FILE=$(mktemp) trap "rm -f $PARAMS_FILE" EXIT START_CMD="systemctl start redis-server 2>/dev/null || systemctl start redis; sleep 1; redis-cli ping; " + START_CMD+="systemctl start nats-server; sleep 1; curl -fsS http://127.0.0.1:8222/varz >/dev/null && echo 'NATS OK' || echo 'NATS NOT RUNNING'; " START_CMD+="systemctl start $WORKER_UNITS; sleep 2; echo Worker status:; " START_CMD+='for role in recon credential_access cracker acl privesc lateral coercion; do ' START_CMD+='st=$(systemctl is-active ares@${role} 2>/dev/null || echo dead); ' @@ -1036,6 +1037,7 @@ tasks: fi 
fi ENV_FILE_CMD="$ENV_FILE_CMD; echo 'ARES_DEPLOYMENT={{.EC2_DEPLOYMENT}}' >> /etc/ares/env" + ENV_FILE_CMD="$ENV_FILE_CMD; echo 'NATS_URL=nats://127.0.0.1:4222' >> /etc/ares/env" # OTEL: send traces to Alloy OTLP gateway → Tempo via HTTP/protobuf ENV_FILE_CMD="$ENV_FILE_CMD; echo 'OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=${OTEL_TRACES_ENDPOINT}' >> /etc/ares/env" ENV_FILE_CMD="$ENV_FILE_CMD; echo 'OTEL_EXPORTER_OTLP_PROTOCOL=http/protobuf' >> /etc/ares/env" @@ -1054,6 +1056,7 @@ tasks: export GRAFANA_URL='${GRAFANA_URL_VAL}' export GRAFANA_SERVICE_ACCOUNT_TOKEN='${GRAFANA_TOKEN_VAL}' export ARES_REDIS_URL=redis://127.0.0.1:6379 + export NATS_URL=nats://127.0.0.1:4222 {{- if .LLM_MODEL}} export ARES_LLM_MODEL='{{.LLM_MODEL}}' {{- end}} diff --git a/.taskfiles/ec2/scripts/launch-orchestrator.sh.tmpl b/.taskfiles/ec2/scripts/launch-orchestrator.sh.tmpl index 619a4bc2..3b98544a 100755 --- a/.taskfiles/ec2/scripts/launch-orchestrator.sh.tmpl +++ b/.taskfiles/ec2/scripts/launch-orchestrator.sh.tmpl @@ -2,6 +2,7 @@ # Launch ares orchestrator with environment variables # Placeholders are substituted by the calling task via envsubst/sed export ARES_REDIS_URL=redis://127.0.0.1:6379 +export NATS_URL=nats://127.0.0.1:4222 export RUST_LOG=info export ARES_OPERATION_ID='__ARES_PAYLOAD__' export OPENAI_API_KEY='__OPENAI_API_KEY__' diff --git a/.taskfiles/ec2/scripts/setup.sh b/.taskfiles/ec2/scripts/setup.sh index f073ecfd..a21305f8 100755 --- a/.taskfiles/ec2/scripts/setup.sh +++ b/.taskfiles/ec2/scripts/setup.sh @@ -1,7 +1,9 @@ #!/bin/bash -# One-time ares EC2 setup: Redis, log dirs, systemd worker template +# One-time ares EC2 setup: Redis, NATS JetStream, log dirs, systemd worker template set -euo pipefail +NATS_VERSION="${NATS_VERSION:-2.10.22}" + echo "=== Installing Redis ===" if command -v redis-server >/dev/null 2>&1; then redis-server --version @@ -18,6 +20,73 @@ else fi fi +echo "=== Installing NATS JetStream server ===" +if command -v nats-server >/dev/null 2>&1 && 
nats-server --version | grep -q "${NATS_VERSION}"; then + nats-server --version +else + arch="$(uname -m)" + case "${arch}" in + x86_64) nats_arch="amd64" ;; + aarch64) nats_arch="arm64" ;; + armv7l) nats_arch="arm7" ;; + *) + echo "ERROR: Unsupported arch: ${arch}" + exit 1 + ;; + esac + tarball="nats-server-v${NATS_VERSION}-linux-${nats_arch}.tar.gz" + curl -fsSL -o "/tmp/${tarball}" \ + "https://github.com/nats-io/nats-server/releases/download/v${NATS_VERSION}/${tarball}" + tar -xzf "/tmp/${tarball}" -C /tmp + install -m 0755 "/tmp/nats-server-v${NATS_VERSION}-linux-${nats_arch}/nats-server" /usr/local/bin/nats-server + rm -rf "/tmp/${tarball}" "/tmp/nats-server-v${NATS_VERSION}-linux-${nats_arch}" +fi + +echo "=== Configuring NATS ===" +getent group nats >/dev/null || groupadd --system nats +getent passwd nats >/dev/null || useradd --system --no-create-home --shell /usr/sbin/nologin --gid nats nats +mkdir -p /etc/nats /var/lib/nats/jetstream /var/log/nats +chown -R nats:nats /var/lib/nats /var/log/nats +chmod 0750 /var/lib/nats/jetstream + +cat >/etc/nats/nats-server.conf <<'NATS_EOF' +host: "127.0.0.1" +port: 4222 +http: "127.0.0.1:8222" +server_name: "ares-nats" +log_file: "/var/log/nats/nats-server.log" +logtime: true +jetstream { + store_dir: "/var/lib/nats/jetstream" + max_memory_store: 512MB + max_file_store: 4GB +} +NATS_EOF +chown nats:nats /etc/nats/nats-server.conf +chmod 0640 /etc/nats/nats-server.conf + +cat >/etc/systemd/system/nats-server.service <<'NATS_UNIT_EOF' +[Unit] +Description=NATS Server (Ares broker) +After=network-online.target +Wants=network-online.target + +[Service] +Type=simple +User=nats +Group=nats +ExecStart=/usr/local/bin/nats-server -c /etc/nats/nats-server.conf +ExecReload=/bin/kill -HUP $MAINPID +LimitNOFILE=65536 +Restart=on-failure +RestartSec=5 +StandardOutput=append:/var/log/nats/nats-server.stdout.log +StandardError=append:/var/log/nats/nats-server.stderr.log + +[Install] +WantedBy=multi-user.target +NATS_UNIT_EOF + echo 
"=== Creating directories ===" mkdir -p /var/log/ares /etc/ares @@ -25,14 +94,15 @@ echo "=== Creating systemd worker template unit ===" cat >/etc/systemd/system/ares@.service <<'UNIT_EOF' [Unit] Description=Ares Worker (%i) -After=redis.service -Wants=redis.service +After=redis.service nats-server.service +Wants=redis.service nats-server.service [Service] Type=simple ExecStart=/usr/local/bin/ares worker EnvironmentFile=-/etc/ares/env Environment=ARES_REDIS_URL=redis://127.0.0.1:6379 +Environment=NATS_URL=nats://127.0.0.1:4222 Environment=ARES_WORKER_ROLE=%i Environment=ARES_WORKER_MODE=tool_exec Environment=RUST_LOG=info @@ -63,10 +133,13 @@ if [ -d /usr/local/lib/python3.13/dist-packages/impacket ]; then echo "Removed pip impacket shadow — using system package" fi -echo "=== Enabling Redis ===" +echo "=== Enabling services ===" +systemctl daemon-reload systemctl enable redis-server 2>/dev/null || systemctl enable redis 2>/dev/null || true systemctl start redis-server 2>/dev/null || systemctl start redis 2>/dev/null || true -systemctl daemon-reload +systemctl enable nats-server +systemctl restart nats-server echo "=== Setup complete ===" redis-cli ping 2>/dev/null || echo "Redis not responding" +curl -fsS http://127.0.0.1:8222/varz >/dev/null 2>&1 && echo "NATS responding" || echo "NATS not responding" diff --git a/.taskfiles/ec2/scripts/status.sh b/.taskfiles/ec2/scripts/status.sh index 5ca3682e..150a7131 100755 --- a/.taskfiles/ec2/scripts/status.sh +++ b/.taskfiles/ec2/scripts/status.sh @@ -5,6 +5,14 @@ echo "=== Redis ===" redis-cli ping 2>/dev/null && redis-cli info server 2>/dev/null | grep -E "redis_version|uptime_in_seconds|connected_clients" || echo "Redis not running" echo "" +echo "=== NATS ===" +if curl -fsS http://127.0.0.1:8222/varz 2>/dev/null | grep -E '"version"|"now"|"connections"' | head -3; then + curl -fsS http://127.0.0.1:8222/jsz 2>/dev/null | grep -E '"streams"|"messages"|"bytes"' | head -3 || true +else + echo "NATS not running" +fi +echo 
"" + echo "=== Workers ===" for role in recon credential_access cracker acl privesc lateral coercion; do st=$(systemctl is-active ares@${role} 2>/dev/null || echo dead) diff --git a/AGENTS.md b/AGENTS.md index d9b85cdd..c2dc0abd 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -10,7 +10,8 @@ Local (this machine) Remote (K8s or EC2) ares --k8s / --ec2 → ares orchestrator (LLM coordination loop) or `task` commands ares worker x7 (recon, credential_access, cracker, acl, privesc, lateral, coercion) - Redis (state store + message broker) + NATS JetStream (task/RPC broker) + Redis (durable state store) ``` The orchestrator and workers are autonomous LLM agents. You do not control them directly. Submit operations, monitor state, inject data when stuck, and debug failures. @@ -34,6 +35,7 @@ The orchestrator and workers are autonomous LLM agents. You do not control them --secrets-from 1password # Fetch API keys/secrets from 1Password CLI (op) --env-file # Load environment variables from a specific file --redis-url # Override the default Redis connection +# NATS connection comes from $NATS_URL (e.g. 
nats://nats:4222) ``` ## Development Workflow diff --git a/Cargo.lock b/Cargo.lock index 9ae0705a..3d5fac28 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -114,10 +114,13 @@ dependencies = [ "ares-core", "ares-llm", "ares-tools", + "async-nats", "async-trait", + "bytes", "chrono", "clap", "dotenvy", + "futures", "redis", "regex", "rstest", @@ -138,8 +141,11 @@ version = "0.2.0" dependencies = [ "anyhow", "approx", + "async-nats", "base64", + "bytes", "chrono", + "futures", "md-5 0.11.0", "opentelemetry", "opentelemetry-otlp", @@ -213,6 +219,42 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "async-nats" +version = "0.47.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07d6f157065c3461096d51aacde0c326fa49f3f6e0199e204c566842cdaa5299" +dependencies = [ + "base64", + "bytes", + "futures-util", + "memchr", + "nkeys", + "nuid", + "pin-project", + "portable-atomic", + "rand 0.8.6", + "regex", + "ring", + "rustls-native-certs", + "rustls-pki-types", + "rustls-webpki", + "serde", + "serde_json", + "serde_nanos", + "serde_repr", + "thiserror 1.0.69", + "time", + "tokio", + "tokio-rustls", + "tokio-stream", + "tokio-util", + "tokio-websockets", + "tracing", + "tryhard", + "url", +] + [[package]] name = "async-trait" version = "0.1.89" @@ -342,6 +384,9 @@ name = "bytes" version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" +dependencies = [ + "serde", +] [[package]] name = "cc" @@ -601,6 +646,38 @@ dependencies = [ "hybrid-array", ] +[[package]] +name = "curve25519-dalek" +version = "4.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be" +dependencies = [ + "cfg-if", + "cpufeatures 0.2.17", + "curve25519-dalek-derive", + "digest 0.10.7", + "fiat-crypto", + "rustc_version", + "subtle", +] + +[[package]] +name = 
"curve25519-dalek-derive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "data-encoding" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4ae5f15dda3c708c0ade84bfee31ccab44a3da4f88015ed22f63732abe300c8" + [[package]] name = "der" version = "0.7.10" @@ -612,6 +689,16 @@ dependencies = [ "zeroize", ] +[[package]] +name = "deranged" +version = "0.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cd812cc2bc1d69d4764bd80df88b4317eaef9e773c75226407d9bc0876b211c" +dependencies = [ + "powerfmt", + "serde_core", +] + [[package]] name = "deunicode" version = "1.6.2" @@ -664,6 +751,28 @@ version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" +[[package]] +name = "ed25519" +version = "2.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "115531babc129696a58c64a4fef0a8bf9e9698629fb97e9e40767d235cfbcd53" +dependencies = [ + "signature", +] + +[[package]] +name = "ed25519-dalek" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70e796c081cee67dc755e1a36a0a172b897fab85fc3f6bc48307991f64e4eca9" +dependencies = [ + "curve25519-dalek", + "ed25519", + "sha2 0.10.9", + "signature", + "subtle", +] + [[package]] name = "either" version = "1.15.0" @@ -727,6 +836,12 @@ version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6" +[[package]] +name = "fiat-crypto" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d" + 
[[package]] name = "find-msvc-tools" version = "0.1.9" @@ -771,6 +886,21 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" +[[package]] +name = "futures" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.32" @@ -850,6 +980,7 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" dependencies = [ + "futures-channel", "futures-core", "futures-io", "futures-macro", @@ -1546,6 +1677,21 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "nkeys" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "879011babc47a1c7fdf5a935ae3cfe94f34645ca0cac1c7f6424b36fc743d1bf" +dependencies = [ + "data-encoding", + "ed25519", + "ed25519-dalek", + "getrandom 0.2.17", + "log", + "rand 0.8.6", + "signatory", +] + [[package]] name = "nu-ansi-term" version = "0.50.3" @@ -1555,6 +1701,15 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "nuid" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc895af95856f929163a0aa20c26a78d26bfdc839f51b9d5aa7a5b79e52b7e83" +dependencies = [ + "rand 0.8.6", +] + [[package]] name = "num-bigint" version = "0.4.6" @@ -1581,6 +1736,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "num-conv" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6673768db2d862beb9b39a78fdcb1a69439615d5794a1be50caa9bc92c81967" + [[package]] name = "num-integer" version = 
"0.1.46" @@ -1896,6 +2057,12 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" +[[package]] +name = "portable-atomic" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + [[package]] name = "potential_utf" version = "0.1.5" @@ -1905,6 +2072,12 @@ dependencies = [ "zerovec", ] +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "ppv-lite86" version = "0.2.21" @@ -2535,6 +2708,26 @@ dependencies = [ "zmij", ] +[[package]] +name = "serde_nanos" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a93142f0367a4cc53ae0fead1bcda39e85beccfad3dcd717656cacab94b12985" +dependencies = [ + "serde", +] + +[[package]] +name = "serde_repr" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -2624,6 +2817,18 @@ dependencies = [ "libc", ] +[[package]] +name = "signatory" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1e303f8205714074f6068773f0e29527e0453937fe837c9717d066635b65f31" +dependencies = [ + "pkcs8", + "rand_core 0.6.4", + "signature", + "zeroize", +] + [[package]] name = "signature" version = "2.2.0" @@ -3034,6 +3239,37 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "time" +version = "0.3.47" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" +dependencies = [ + 
"deranged", + "itoa", + "num-conv", + "powerfmt", + "serde_core", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" + +[[package]] +name = "time-macros" +version = "0.2.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e70e4c5a0e0a8a4823ad65dfe1a6930e4f4d756dcd9dd7939022b5e8c501215" +dependencies = [ + "num-conv", + "time-core", +] + [[package]] name = "tinystr" version = "0.8.3" @@ -3121,6 +3357,27 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-websockets" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f591660438b3038dd04d16c938271c79e7e06260ad2ea2885a4861bfb238605d" +dependencies = [ + "base64", + "bytes", + "futures-core", + "futures-sink", + "http", + "httparse", + "rand 0.8.6", + "ring", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tokio-util", + "webpki-roots 0.26.11", +] + [[package]] name = "toml_datetime" version = "1.1.1+spec-1.1.0" @@ -3321,6 +3578,16 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "tryhard" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fe58ebd5edd976e0fe0f8a14d2a04b7c81ef153ea9a54eebc42e67c2c23b4e5" +dependencies = [ + "pin-project-lite", + "tokio", +] + [[package]] name = "typenum" version = "1.20.0" @@ -3604,6 +3871,24 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "webpki-roots" +version = "0.26.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" +dependencies = [ + "webpki-roots 1.0.7", +] + +[[package]] +name = "webpki-roots" +version = "1.0.7" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52f5ee44c96cf55f1b349600768e3ece3a8f26010c05265ab73f945bb1a2eb9d" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "whoami" version = "1.6.1" @@ -3620,7 +3905,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.61.2", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 3404af61..80cdbaa4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,9 @@ serde = { version = "1", features = ["derive"] } serde_json = "1" tokio = { version = "1", features = ["full"] } redis = { version = "1.0", features = ["tokio-comp", "connection-manager"] } +async-nats = "0.47" +futures = "0.3" +bytes = "1" chrono = { version = "0.4", features = ["serde"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] } diff --git a/ansible/README.md b/ansible/README.md index fc823c55..90e6a6b4 100644 --- a/ansible/README.md +++ b/ansible/README.md @@ -29,9 +29,10 @@ graph TD Roles --> R9[fluent_bit] Roles --> R10[lateral_movement_tools *] Roles --> R11[mythic *] - Roles --> R12[privesc_tools *] - Roles --> R13[recon_tools *] - Roles --> R14[redis] + Roles --> R12[nats] + Roles --> R13[privesc_tools *] + Roles --> R14[recon_tools *] + Roles --> R15[redis] Collection --> Playbooks[Playbooks] Playbooks --> PB0[ares] Playbooks --> PB1[linux] diff --git a/ansible/playbooks/ares/goad_attack_box.yml b/ansible/playbooks/ares/goad_attack_box.yml index 2fdfdba6..3542c96c 100644 --- a/ansible/playbooks/ares/goad_attack_box.yml +++ b/ansible/playbooks/ares/goad_attack_box.yml @@ -134,11 +134,16 @@ vars: privesc_tools_verify_install: true - # Redis server for Ares worker message broker + # Redis server for Ares durable state store - role: dreadnode.nimbus_range.redis vars: redis_verify_install: true + # NATS 
JetStream server for Ares task + RPC broker + - role: dreadnode.nimbus_range.nats + vars: + nats_verify_install: true + # Grafana Alloy for log shipping to Loki - name: Install and configure Grafana Alloy role: grafana.grafana.alloy diff --git a/ansible/roles/nats/README.md b/ansible/roles/nats/README.md new file mode 100644 index 00000000..7d01533b --- /dev/null +++ b/ansible/roles/nats/README.md @@ -0,0 +1,80 @@ + + +# nats + +## Description + +NATS JetStream server for Ares task and RPC broker + +## Requirements + +- Ansible >= 2.18.4 + +## Role Variables + +### Default Variables (main.yml) + +| Variable | Type | Default | Description | +| -------- | ---- | ------- | ----------- | +| `nats_version` | str | 2.10.22 | No description | +| `nats_install_dir` | str | /usr/local/bin | No description | +| `nats_user` | str | nats | No description | +| `nats_group` | str | nats | No description | +| `nats_bind_address` | str | 127.0.0.1 | No description | +| `nats_port` | int | 4222 | No description | +| `nats_http_port` | int | 8222 | No description | +| `nats_jetstream_enabled` | bool | True | No description | +| `nats_jetstream_store_dir` | str | /var/lib/nats/jetstream | No description | +| `nats_jetstream_max_memory` | str | 512M | No description | +| `nats_jetstream_max_file` | str | 4G | No description | +| `nats_log_dir` | str | /var/log/nats | No description | +| `nats_log_file` | str | /var/log/nats/nats-server.log | No description | +| `nats_debug` | bool | False | No description | +| `nats_verify_install` | bool | False | No description | + +## Tasks + +### linux.yml + + +- **Map kernel arch to NATS release arch** (ansible.builtin.set_fact) +- **Create NATS group** (ansible.builtin.group) +- **Create NATS user** (ansible.builtin.user) +- **Create NATS directories** (ansible.builtin.file) +- **Check installed NATS version** (ansible.builtin.command) +- **Download NATS server release** (ansible.builtin.unarchive) - Conditional +- **Install NATS server binary** 
(ansible.builtin.copy) - Conditional +- **Clean up NATS release tarball directory** (ansible.builtin.file) - Conditional +- **Render NATS server config** (ansible.builtin.template) +- **Install NATS systemd unit** (ansible.builtin.template) +- **Enable and start NATS** (ansible.builtin.systemd) +- **Verify NATS is responding** (ansible.builtin.uri) - Conditional +- **Display NATS verification** (ansible.builtin.debug) - Conditional + +### main.yml + + +- **Include Linux tasks** (ansible.builtin.include_tasks) - Conditional + +## Example Playbook + +```yaml +- hosts: servers + roles: + - nats +``` + +## Author Information + +- **Author**: Dreadnode +- **Company**: dreadnode +- **License**: MIT + +## Platforms + + +- Ubuntu: all +- Debian: all +- Kali: all + + diff --git a/ansible/roles/nats/defaults/main.yml b/ansible/roles/nats/defaults/main.yml new file mode 100644 index 00000000..5abfedf5 --- /dev/null +++ b/ansible/roles/nats/defaults/main.yml @@ -0,0 +1,25 @@ +--- +# NATS server release pinned for reproducibility +nats_version: "2.10.22" +nats_install_dir: "/usr/local/bin" +nats_user: "nats" +nats_group: "nats" + +# Listen address & port +nats_bind_address: "127.0.0.1" +nats_port: 4222 +nats_http_port: 8222 + +# JetStream (durable streams) — required by Ares +nats_jetstream_enabled: true +nats_jetstream_store_dir: "/var/lib/nats/jetstream" +nats_jetstream_max_memory: "512M" +nats_jetstream_max_file: "4G" + +# Logging +nats_log_dir: "/var/log/nats" +nats_log_file: "/var/log/nats/nats-server.log" +nats_debug: false + +# Verification +nats_verify_install: false diff --git a/ansible/roles/nats/handlers/main.yml b/ansible/roles/nats/handlers/main.yml new file mode 100644 index 00000000..cb07ce39 --- /dev/null +++ b/ansible/roles/nats/handlers/main.yml @@ -0,0 +1,11 @@ +--- +- name: Restart nats + ansible.builtin.systemd: + name: nats-server + state: restarted + become: true + +- name: Reload systemd + ansible.builtin.systemd: + daemon_reload: true + become: true 
diff --git a/ansible/roles/nats/meta/main.yml b/ansible/roles/nats/meta/main.yml new file mode 100644 index 00000000..caec62d9 --- /dev/null +++ b/ansible/roles/nats/meta/main.yml @@ -0,0 +1,24 @@ +--- +galaxy_info: + author: Dreadnode + namespace: dreadnode + description: NATS JetStream server for Ares task and RPC broker + company: dreadnode + license: MIT + role_name: nats + min_ansible_version: "2.18.4" + platforms: + - name: Ubuntu + versions: + - all + - name: Debian + versions: + - all + - name: Kali + versions: + - all + galaxy_tags: + - ares + - nats + - jetstream + - broker diff --git a/ansible/roles/nats/tasks/linux.yml b/ansible/roles/nats/tasks/linux.yml new file mode 100644 index 00000000..3eda8d10 --- /dev/null +++ b/ansible/roles/nats/tasks/linux.yml @@ -0,0 +1,116 @@ +--- +- name: Map kernel arch to NATS release arch + ansible.builtin.set_fact: + nats_release_arch: >- + {{ { + 'x86_64': 'amd64', + 'aarch64': 'arm64', + 'armv7l': 'arm7', + }[ansible_architecture] }} + +- name: Create NATS group + ansible.builtin.group: + name: "{{ nats_group }}" + system: true + state: present + become: true + +- name: Create NATS user + ansible.builtin.user: + name: "{{ nats_user }}" + group: "{{ nats_group }}" + system: true + shell: /usr/sbin/nologin + home: "{{ nats_jetstream_store_dir }}" + create_home: false + state: present + become: true + +- name: Create NATS directories + ansible.builtin.file: + path: "{{ item }}" + state: directory + owner: "{{ nats_user }}" + group: "{{ nats_group }}" + mode: '0750' + loop: + - "{{ nats_jetstream_store_dir }}" + - "{{ nats_log_dir }}" + - /etc/nats + become: true + +- name: Check installed NATS version + ansible.builtin.command: + cmd: "{{ nats_install_dir }}/nats-server --version" + register: nats_installed_version + changed_when: false + failed_when: false + +- name: Download NATS server release + ansible.builtin.unarchive: + src: "https://github.com/nats-io/nats-server/releases/download/v{{ nats_version 
}}/nats-server-v{{ nats_version }}-linux-{{ nats_release_arch }}.tar.gz" + dest: /tmp + remote_src: true + creates: "/tmp/nats-server-v{{ nats_version }}-linux-{{ nats_release_arch }}/nats-server" + when: nats_version not in (nats_installed_version.stdout | default('')) + become: true + +- name: Install NATS server binary + ansible.builtin.copy: + src: "/tmp/nats-server-v{{ nats_version }}-linux-{{ nats_release_arch }}/nats-server" + dest: "{{ nats_install_dir }}/nats-server" + mode: '0755' + remote_src: true + when: nats_version not in (nats_installed_version.stdout | default('')) + become: true + notify: Restart nats + +- name: Clean up NATS release tarball directory + ansible.builtin.file: + path: "/tmp/nats-server-v{{ nats_version }}-linux-{{ nats_release_arch }}" + state: absent + when: nats_version not in (nats_installed_version.stdout | default('')) + become: true + +- name: Render NATS server config + ansible.builtin.template: + src: nats-server.conf.j2 + dest: /etc/nats/nats-server.conf + owner: "{{ nats_user }}" + group: "{{ nats_group }}" + mode: '0640' + become: true + notify: Restart nats + +- name: Install NATS systemd unit + ansible.builtin.template: + src: nats-server.service.j2 + dest: /etc/systemd/system/nats-server.service + mode: '0644' + become: true + notify: + - Reload systemd + - Restart nats + +- name: Enable and start NATS + ansible.builtin.systemd: + name: nats-server + enabled: true + state: started + daemon_reload: true + become: true + +- name: Verify NATS is responding + ansible.builtin.uri: + url: "http://{{ nats_bind_address }}:{{ nats_http_port }}/varz" + return_content: true + register: nats_varz + retries: 5 + delay: 2 + until: nats_varz.status == 200 + when: nats_verify_install + +- name: Display NATS verification + ansible.builtin.debug: + msg: "NATS server: {{ (nats_varz.json | default({})).version | default('unknown') }} (jetstream={{ nats_jetstream_enabled }})" + when: nats_verify_install diff --git 
a/ansible/roles/nats/tasks/main.yml b/ansible/roles/nats/tasks/main.yml new file mode 100644 index 00000000..0f9cb2c3 --- /dev/null +++ b/ansible/roles/nats/tasks/main.yml @@ -0,0 +1,4 @@ +--- +- name: Include Linux tasks + ansible.builtin.include_tasks: linux.yml + when: ansible_os_family != 'Windows' diff --git a/ansible/roles/nats/templates/nats-server.conf.j2 b/ansible/roles/nats/templates/nats-server.conf.j2 new file mode 100644 index 00000000..38edd879 --- /dev/null +++ b/ansible/roles/nats/templates/nats-server.conf.j2 @@ -0,0 +1,25 @@ +# {{ ansible_managed }} +# NATS server configuration for Ares (task + RPC broker) + +host: "{{ nats_bind_address }}" +port: {{ nats_port }} + +http: "{{ nats_bind_address }}:{{ nats_http_port }}" + +server_name: "ares-nats" + +{% if nats_debug %} +debug: true +trace: false +{% endif %} + +log_file: "{{ nats_log_file }}" +logtime: true + +{% if nats_jetstream_enabled %} +jetstream { + store_dir: "{{ nats_jetstream_store_dir }}" + max_memory_store: {{ nats_jetstream_max_memory | regex_replace('([0-9]+)([KMGT])$', '\\1\\2B') }} + max_file_store: {{ nats_jetstream_max_file | regex_replace('([0-9]+)([KMGT])$', '\\1\\2B') }} +} +{% endif %} diff --git a/ansible/roles/nats/templates/nats-server.service.j2 b/ansible/roles/nats/templates/nats-server.service.j2 new file mode 100644 index 00000000..1c126e15 --- /dev/null +++ b/ansible/roles/nats/templates/nats-server.service.j2 @@ -0,0 +1,20 @@ +# {{ ansible_managed }} +[Unit] +Description=NATS Server (Ares broker) +After=network-online.target +Wants=network-online.target + +[Service] +Type=simple +User={{ nats_user }} +Group={{ nats_group }} +ExecStart={{ nats_install_dir }}/nats-server -c /etc/nats/nats-server.conf +ExecReload=/bin/kill -HUP $MAINPID +LimitNOFILE=65536 +Restart=on-failure +RestartSec=5 +StandardOutput=append:{{ nats_log_dir }}/nats-server.stdout.log +StandardError=append:{{ nats_log_dir }}/nats-server.stderr.log + +[Install] +WantedBy=multi-user.target diff --git 
a/ansible/roles/redis/templates/ares@.service.j2 b/ansible/roles/redis/templates/ares@.service.j2 index bc4f23c6..dab687a3 100644 --- a/ansible/roles/redis/templates/ares@.service.j2 +++ b/ansible/roles/redis/templates/ares@.service.j2 @@ -1,12 +1,13 @@ [Unit] Description=Ares Worker (%i) -After=redis.service -Wants=redis.service +After=redis.service nats-server.service +Wants=redis.service nats-server.service [Service] Type=simple ExecStart={{ redis_ares_worker_binary }} worker Environment=ARES_REDIS_URL=redis://{{ redis_bind_address }}:{{ redis_port }} +Environment=NATS_URL=nats://{{ redis_bind_address }}:4222 Environment=ARES_WORKER_ROLE=%i Environment=ARES_WORKER_MODE=tool_exec Environment=RUST_LOG=info diff --git a/ares-cli/Cargo.toml b/ares-cli/Cargo.toml index ba2f93bf..a59f4d14 100644 --- a/ares-cli/Cargo.toml +++ b/ares-cli/Cargo.toml @@ -21,6 +21,9 @@ serde_json = { workspace = true } serde_yaml = { workspace = true } tokio = { workspace = true } redis = { workspace = true } +async-nats = { workspace = true } +futures = { workspace = true } +bytes = { workspace = true } chrono = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true } diff --git a/ares-cli/src/blue/delete.rs b/ares-cli/src/blue/delete.rs index be0aaa3b..7e443f0d 100644 --- a/ares-cli/src/blue/delete.rs +++ b/ares-cli/src/blue/delete.rs @@ -126,11 +126,23 @@ pub(crate) async fn blue_cleanup( let inv_keys = scan_redis_keys(&mut conn, "ares:blue:inv:*").await?; let op_keys = scan_redis_keys(&mut conn, "ares:blue:op:*").await?; let active_exists: bool = conn.exists("ares:blue:active_investigations").await?; - let queue_len: i64 = conn.llen("ares:blue:investigations").await?; + + // Inspect NATS investigation stream depth (best-effort). 
+ let queue_len: i64 = match ares_core::nats::NatsBroker::connect_from_env().await { + Ok(nats) => match nats + .jetstream() + .get_stream(ares_core::nats::BLUE_TASKS_STREAM) + .await + { + Ok(stream) => stream.cached_info().state.messages as i64, + Err(_) => 0, + }, + Err(_) => 0, + }; println!("Found {} investigation keys", inv_keys.len()); println!("Found {} operation tracking keys", op_keys.len()); - println!("Queue length: {queue_len}"); + println!("NATS blue queue depth: {queue_len}"); if dry_run { println!("(dry run - no changes made)"); @@ -160,9 +172,17 @@ pub(crate) async fn blue_cleanup( let count: usize = conn.del("ares:blue:active_investigations").await?; deleted += count; } + // Drain queued investigation requests from the NATS stream if queue_len > 0 { - let _: usize = conn.del("ares:blue:investigations").await?; - deleted += 1; + if let Ok(nats) = ares_core::nats::NatsBroker::connect_from_env().await { + if let Ok(stream) = nats + .jetstream() + .get_stream(ares_core::nats::BLUE_TASKS_STREAM) + .await + { + let _ = stream.purge().await; + } + } } println!("Deleted {deleted} keys"); diff --git a/ares-cli/src/blue/submit.rs b/ares-cli/src/blue/submit.rs index ec2f5957..894c2d3f 100644 --- a/ares-cli/src/blue/submit.rs +++ b/ares-cli/src/blue/submit.rs @@ -3,6 +3,8 @@ use chrono::Utc; use redis::AsyncCommands; use tracing::info; +use ares_core::nats::NatsBroker; +use ares_core::state::blue_task_queue::BlueTaskQueue; use ares_core::state::RedisStateReader; use crate::ops::submit::{collect_env_vars, resolve_model, BLUE_ENV_VAR_NAMES}; @@ -72,11 +74,14 @@ pub(crate) async fn blue_submit( let _: () = conn.expire(&env_vars_key, 3600).await?; } - // Push investigation request to queue - let request_json = serde_json::to_string(&request)?; - let _: () = conn - .rpush("ares:blue:investigations", &request_json) - .await?; + // Push investigation request to NATS investigation queue + let nats = NatsBroker::connect_from_env() + .await + .context("Connect to NATS 
for blue investigation submission")?; + nats.ensure_streams().await?; + BlueTaskQueue::submit_investigation_request(&nats, &request) + .await + .context("Failed to publish investigation request to NATS")?; info!("Investigation submitted: {inv_id}"); println!("Investigation submitted: {inv_id}"); @@ -219,15 +224,18 @@ pub(crate) async fn blue_from_operation( let _: () = conn.expire(&env_vars_key, 3600).await?; } - let request_json = serde_json::to_string(&request)?; - let _: () = conn - .rpush("ares:blue:investigations", &request_json) - .await?; - let op_inv_key = format!("ares:blue:op:{op_id}:investigations"); let _: () = conn.sadd(&op_inv_key, &inv_id).await?; let _: () = conn.expire(&op_inv_key, 7 * 24 * 3600).await?; // 7 day TTL + let nats = NatsBroker::connect_from_env() + .await + .context("Connect to NATS for blue investigation submission")?; + nats.ensure_streams().await?; + BlueTaskQueue::submit_investigation_request(&nats, &request) + .await + .context("Failed to publish investigation request to NATS")?; + info!("Investigation submitted: {inv_id}"); println!("Investigation submitted: {inv_id} (from operation {op_id})"); println!("Status: submitted"); diff --git a/ares-cli/src/orchestrator/blue/auto_submit.rs b/ares-cli/src/orchestrator/blue/auto_submit.rs index 38ccfd73..121d7f3f 100644 --- a/ares-cli/src/orchestrator/blue/auto_submit.rs +++ b/ares-cli/src/orchestrator/blue/auto_submit.rs @@ -231,16 +231,17 @@ async fn submit_investigation( let _: () = conn.expire(&env_key, 3600).await?; } - // Push to investigation queue - let request_json = serde_json::to_string(&request)?; - let _: () = conn - .rpush("ares:blue:investigations", &request_json) - .await?; - - // Track investigation against operation + // Track investigation against operation (Redis state) let op_inv_key = format!("ares:blue:op:{op_id}:investigations"); let _: () = conn.sadd(&op_inv_key, &inv_id).await?; let _: () = conn.expire(&op_inv_key, 7 * 24 * 3600).await?; + // Publish 
investigation request to NATS (reuse the orchestrator's broker) + let nats = queue + .nats_broker() + .ok_or_else(|| anyhow::anyhow!("Orchestrator TaskQueue has no NATS broker"))?; + ares_core::state::blue_task_queue::BlueTaskQueue::submit_investigation_request(&nats, &request) + .await?; + Ok(inv_id) } diff --git a/ares-cli/src/orchestrator/blue/runner.rs b/ares-cli/src/orchestrator/blue/runner.rs index 33181f57..2d724f93 100644 --- a/ares-cli/src/orchestrator/blue/runner.rs +++ b/ares-cli/src/orchestrator/blue/runner.rs @@ -36,6 +36,7 @@ pub struct BlueOrchestrator { model_name: String, dispatcher: Arc, redis_url: String, + nats_url: String, } impl BlueOrchestrator { @@ -44,12 +45,14 @@ impl BlueOrchestrator { model_name: String, dispatcher: Arc, redis_url: String, + nats_url: String, ) -> Self { Self { provider: Arc::from(provider), model_name, dispatcher, redis_url, + nats_url, } } @@ -169,9 +172,9 @@ impl BlueOrchestrator { // Clean up stale investigations from previous runs self.cleanup_stale_investigations().await; - let mut task_queue = BlueTaskQueue::connect(&self.redis_url) + let mut task_queue = BlueTaskQueue::connect_with_nats(&self.redis_url, &self.nats_url) .await - .context("Failed to connect blue task queue to Redis")?; + .context("Failed to connect blue task queue (Redis + NATS)")?; let mut retry_delay = Duration::from_secs(1); let max_retry_delay = Duration::from_secs(30); @@ -362,10 +365,12 @@ impl BlueOrchestrator { // Reconnect the task queue — the previous ConnectionManager // can be stuck after Redis restarts or prolonged outages. 
- match BlueTaskQueue::connect(&self.redis_url).await { + match BlueTaskQueue::connect_with_nats(&self.redis_url, &self.nats_url) + .await + { Ok(new_queue) => { task_queue = new_queue; - info!("Blue orchestrator: reconnected to Redis"); + info!("Blue orchestrator: reconnected to Redis + NATS"); } Err(reconnect_err) => { warn!("Blue orchestrator: reconnect failed: {reconnect_err}"); @@ -392,10 +397,12 @@ pub fn spawn_blue_orchestrator( model_name: String, dispatcher: Arc, redis_url: String, + nats_url: String, shutdown_rx: watch::Receiver, ) -> tokio::task::JoinHandle<()> { tokio::spawn(async move { - let orchestrator = BlueOrchestrator::new(provider, model_name, dispatcher, redis_url); + let orchestrator = + BlueOrchestrator::new(provider, model_name, dispatcher, redis_url, nats_url); if let Err(e) = orchestrator.run(shutdown_rx).await { error!("Blue orchestrator exited with error: {e}"); } diff --git a/ares-cli/src/orchestrator/completion.rs b/ares-cli/src/orchestrator/completion.rs index 32cc293a..6b50ab2f 100644 --- a/ares-cli/src/orchestrator/completion.rs +++ b/ares-cli/src/orchestrator/completion.rs @@ -275,11 +275,17 @@ pub async fn wait_for_completion( .query_async(&mut conn) .await .unwrap_or(0); - let queued: i64 = redis::cmd("LLEN") - .arg("ares:blue:investigations") - .query_async(&mut conn) - .await - .unwrap_or(0); + let queued: i64 = match dispatcher.queue.nats_broker() { + Some(nats) => match nats + .jetstream() + .get_stream(ares_core::nats::BLUE_TASKS_STREAM) + .await + { + Ok(stream) => stream.cached_info().state.messages as i64, + Err(_) => 0, + }, + None => 0, + }; if active == 0 && queued == 0 { info!("All blue investigations finished"); @@ -463,9 +469,9 @@ async fn auto_submit_blue_investigation( let _: () = conn.expire(&env_vars_key, 3600).await?; } - // Pre-register as active BEFORE pushing to queue to avoid TOCTOU race: + // Pre-register as active BEFORE publishing to avoid TOCTOU race: // without this, the completion wait loop can 
observe both queued==0 and - // active==0 in the window between the blue orchestrator's BRPOP (drains + // active==0 in the window between the blue orchestrator's pull (drains // the queue) and its register_investigation (SADDs to active set). let _: () = conn .sadd(ares_core::state::BLUE_ACTIVE_INVESTIGATIONS, &inv_id) @@ -474,17 +480,19 @@ async fn auto_submit_blue_investigation( .expire(ares_core::state::BLUE_ACTIVE_INVESTIGATIONS, 86400) .await?; - // Push investigation request to queue - let request_json = serde_json::to_string(&request)?; - let _: () = conn - .rpush("ares:blue:investigations", &request_json) - .await?; - // Track investigation against operation let op_inv_key = format!("ares:blue:op:{op_id}:investigations"); let _: () = conn.sadd(&op_inv_key, &inv_id).await?; let _: () = conn.expire(&op_inv_key, 7 * 24 * 3600).await?; + // Publish investigation request to NATS + let nats = dispatcher + .queue + .nats_broker() + .ok_or_else(|| anyhow::anyhow!("Dispatcher TaskQueue has no NATS broker"))?; + ares_core::state::blue_task_queue::BlueTaskQueue::submit_investigation_request(&nats, &request) + .await?; + info!( investigation_id = inv_id, operation_id = op_id, diff --git a/ares-cli/src/orchestrator/config.rs b/ares-cli/src/orchestrator/config.rs index 1b467b58..d3c49aad 100644 --- a/ares-cli/src/orchestrator/config.rs +++ b/ares-cli/src/orchestrator/config.rs @@ -16,6 +16,9 @@ pub struct OrchestratorConfig { /// Redis connection URL (supports `redis://` and `redis+sentinel://`). pub redis_url: String, + /// NATS connection URL for the work-queue/result broker. + pub nats_url: String, + /// Operation ID this orchestrator instance manages. 
pub operation_id: String, @@ -92,6 +95,8 @@ impl OrchestratorConfig { .or_else(|_| env::var("REDIS_URL")) .unwrap_or_else(|_| "redis://127.0.0.1:6379/0".to_string()); + let nats_url = ares_core::nats::NatsBroker::url_from_env(); + let raw_op = env::var("ARES_OPERATION_ID") .map_err(|_| anyhow::anyhow!("ARES_OPERATION_ID is required"))?; @@ -196,6 +201,7 @@ impl OrchestratorConfig { Ok(Self { redis_url, + nats_url, operation_id, max_concurrent_tasks, heartbeat_interval: Duration::from_secs(heartbeat_interval_secs), @@ -297,6 +303,7 @@ mod tests { pub(crate) fn make_config(max_tasks: usize) -> OrchestratorConfig { OrchestratorConfig { redis_url: "redis://localhost".into(), + nats_url: "nats://localhost:4222".into(), operation_id: "test-op".into(), max_concurrent_tasks: max_tasks, heartbeat_interval: Duration::from_secs(30), diff --git a/ares-cli/src/orchestrator/deferred.rs b/ares-cli/src/orchestrator/deferred.rs index 48b1b111..b43dad3c 100644 --- a/ares-cli/src/orchestrator/deferred.rs +++ b/ares-cli/src/orchestrator/deferred.rs @@ -7,6 +7,12 @@ //! //! Score formula: `(priority * 1_000_000_000) + (unix_millis)` //! Lower score = higher priority = processed first. +//! +//! Stays on Redis (not NATS): this is operation-scoped throttling state owned +//! by a single orchestrator, not a broker/transport concern. Priority ordering +//! via ZSET score is non-trivial to model in JetStream and offers no benefit +//! here since the queue is in-process. Redis remains for state; NATS handles +//! cross-process queues. 
use anyhow::{Context, Result}; use chrono::Utc; diff --git a/ares-cli/src/orchestrator/mod.rs b/ares-cli/src/orchestrator/mod.rs index 003bd7af..4c57df17 100644 --- a/ares-cli/src/orchestrator/mod.rs +++ b/ares-cli/src/orchestrator/mod.rs @@ -105,9 +105,9 @@ async fn run_inner() -> Result<()> { "Configuration loaded" ); - let queue = TaskQueue::connect(&config.redis_url) + let queue = TaskQueue::connect(&config.redis_url, &config.nats_url) .await - .context("Failed to connect to Redis")?; + .context("Failed to connect to Redis/NATS")?; let acquired = queue .try_acquire_lock(&config.operation_id, config.lock_ttl) @@ -471,6 +471,7 @@ async fn run_inner() -> Result<()> { blue_model, blue_disp, config.redis_url.clone(), + config.nats_url.clone(), shutdown_rx.clone(), ), blue::spawn_blue_auto_submit( @@ -488,7 +489,10 @@ async fn run_inner() -> Result<()> { let blue_handle: Option<(tokio::task::JoinHandle<()>, tokio::task::JoinHandle<()>)> = None; { - let recovery_mgr = recovery::OperationRecoveryManager::new(config.redis_url.clone()); + let recovery_mgr = recovery::OperationRecoveryManager::new( + config.redis_url.clone(), + config.nats_url.clone(), + ); match recovery_mgr.recover(&config.operation_id).await { Ok(recovered) => { if !recovered.requeued_task_ids.is_empty() || !recovered.failed_task_ids.is_empty() @@ -719,6 +723,7 @@ async fn run_blue_only() -> Result<()> { let redis_url = std::env::var("ARES_REDIS_URL") .or_else(|_| std::env::var("REDIS_URL")) .unwrap_or_else(|_| "redis://127.0.0.1:6379/0".to_string()); + let nats_url = ares_core::nats::NatsBroker::url_from_env(); // Load YAML config for observability URLs if let Ok(cfg) = ares_core::config::AresConfig::from_env() { @@ -743,9 +748,9 @@ async fn run_blue_only() -> Result<()> { ares_llm::create_provider(&model_spec).context("Failed to create LLM provider")?; // Blue uses a simple Redis-based tool dispatcher (no operation-scoped auth throttle) - let queue = self::task_queue::TaskQueue::connect(&redis_url) + 
let queue = self::task_queue::TaskQueue::connect(&redis_url, &nats_url) .await - .context("Failed to connect to Redis")?; + .context("Failed to connect to Redis/NATS")?; let auth_throttle = tool_dispatcher::AuthThrottle::new(3, std::time::Duration::from_secs(30)); let blue_disp: Arc = Arc::new(tool_dispatcher::RedisToolDispatcher::new( @@ -758,8 +763,14 @@ async fn run_blue_only() -> Result<()> { let (shutdown_tx, shutdown_rx) = watch::channel(false); - let blue_handle = - blue::spawn_blue_orchestrator(provider, model_name, blue_disp, redis_url, shutdown_rx); + let blue_handle = blue::spawn_blue_orchestrator( + provider, + model_name, + blue_disp, + redis_url, + nats_url, + shutdown_rx, + ); // Wait for shutdown signal signal::ctrl_c().await?; diff --git a/ares-cli/src/orchestrator/monitoring.rs b/ares-cli/src/orchestrator/monitoring.rs index a6e93321..e4473164 100644 --- a/ares-cli/src/orchestrator/monitoring.rs +++ b/ares-cli/src/orchestrator/monitoring.rs @@ -116,7 +116,7 @@ pub fn spawn_lock_keeper( // Create a dedicated Redis connection for the lock keeper so that // EXPIRE commands are not queued behind heavy BRPOP/LPUSH traffic // on the shared connection manager. - let dedicated_queue = match TaskQueue::connect(&config.redis_url).await { + let dedicated_queue = match TaskQueue::connect(&config.redis_url, &config.nats_url).await { Ok(q) => { info!("Lock keeper using dedicated Redis connection"); q diff --git a/ares-cli/src/orchestrator/recovery/manager.rs b/ares-cli/src/orchestrator/recovery/manager.rs index 81101a34..ef1ae783 100644 --- a/ares-cli/src/orchestrator/recovery/manager.rs +++ b/ares-cli/src/orchestrator/recovery/manager.rs @@ -20,12 +20,16 @@ use super::types::{ /// Manages recovery of operation state from Redis after a restart. pub struct OperationRecoveryManager { redis_url: String, + nats_url: String, } impl OperationRecoveryManager { /// Create a new recovery manager. 
- pub fn new(redis_url: String) -> Self { - Self { redis_url } + pub fn new(redis_url: String, nats_url: String) -> Self { + Self { + redis_url, + nats_url, + } } /// Attempt to recover an operation's state from Redis. @@ -43,7 +47,7 @@ impl OperationRecoveryManager { let mut last_err: Option = None; for attempt in 1..=MAX_CONNECTION_RETRIES { - let queue = match TaskQueue::connect(&self.redis_url).await { + let queue = match TaskQueue::connect(&self.redis_url, &self.nats_url).await { Ok(q) => q, Err(e) => { if attempt < MAX_CONNECTION_RETRIES { diff --git a/ares-cli/src/orchestrator/task_queue.rs b/ares-cli/src/orchestrator/task_queue.rs index 45aba1a1..2982072c 100644 --- a/ares-cli/src/orchestrator/task_queue.rs +++ b/ares-cli/src/orchestrator/task_queue.rs @@ -1,41 +1,52 @@ -//! Redis-backed task queue matching the Python `RedisTaskQueue`. +//! Hybrid Redis + NATS JetStream task queue. //! -//! Key patterns: -//! - `ares:tasks:{role}` — List, per-role task queue -//! - `ares:results:{task_id}` — List, per-task result mailbox (TTL 24h) -//! - `ares:heartbeat:{agent}` — String, agent heartbeat (TTL from config) -//! - `ares:task_status:{task_id}` — String, task lifecycle JSON -//! - `ares:lock:{op_id}` — String, operation lock with TTL refresh +//! Work queues and result mailboxes live in NATS JetStream. Operation lock, +//! agent heartbeats, and task-status records stay in Redis (the right tool +//! for ephemeral KV with TTL). //! -//! Workers BRPOP from the right; the orchestrator pushes to the left (LPUSH) -//! for normal priority and to the right (RPUSH) for urgent priority, giving -//! FIFO semantics with priority bypass. +//! NATS subjects: +//! - `ares.tasks.{role}` work queue, normal priority +//! - `ares.tasks.urgent.{role}` work queue, urgent (priority ≤ 2) +//! - `ares.tasks.results.{task_id}` durable result, one per task +//! +//! Redis keys (state only): +//! - `ares:heartbeat:{agent}` string, agent heartbeat (TTL) +//! 
- `ares:task_status:{task_id}` string, task lifecycle JSON (TTL 24h) +//! - `ares:lock:{op_id}` string, operation lock (TTL refresh) +//! +//! The work queue uses JetStream pull consumers with explicit acks and +//! bounded redelivery, replacing the silent-loss `BRPOP` pattern. use std::collections::HashMap; use std::time::Duration; use anyhow::{Context, Result}; +use bytes::Bytes; use chrono::{DateTime, Utc}; +use futures::StreamExt; use redis::aio::{ConnectionLike, ConnectionManager}; use redis::AsyncCommands; use serde::{Deserialize, Serialize}; use tracing::{debug, info, warn}; use uuid::Uuid; -pub const TASK_QUEUE_PREFIX: &str = "ares:tasks"; -pub const RESULT_QUEUE_PREFIX: &str = "ares:results"; +use ares_core::nats::{self, NatsBroker}; + pub const HEARTBEAT_PREFIX: &str = "ares:heartbeat"; pub const TASK_STATUS_PREFIX: &str = "ares:task_status"; pub const LOCK_PREFIX: &str = "ares:lock"; -pub const STATE_UPDATE_CHANNEL_PREFIX: &str = "ares:state:updates"; - -/// Result keys expire after 24 hours. -const RESULT_TTL_SECS: u64 = 60 * 60 * 24; /// Task status keys expire after 24 hours. const TASK_STATUS_TTL_SECS: u64 = 60 * 60 * 24; +/// Default timeout when polling a single result via an ephemeral consumer. +const DEFAULT_RESULT_POLL_TIMEOUT: Duration = Duration::from_secs(1); + /// Task submitted to a role queue. Mirrors `ares.core.task_queue.TaskMessage`. +/// +/// Construction is exercised by tests; production red-team dispatch goes through +/// the in-process LLM runner instead, so the bin build sees this as unused. +#[allow(dead_code)] #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TaskMessage { pub task_id: String, @@ -49,6 +60,7 @@ pub struct TaskMessage { pub callback_queue: Option, } +#[allow(dead_code)] fn default_priority() -> i32 { 5 } @@ -81,87 +93,60 @@ pub struct HeartbeatData { pub pod_name: Option, } -/// Async Redis task queue implementing the Ares queue protocol. +/// Hybrid task queue: NATS for queues, Redis for state. 
/// -/// Generic over connection type to support both production (`ConnectionManager`) -/// and test (`MockRedisConnection`) backends. +/// Generic over the Redis backend so unit tests can use a mock; `nats` is +/// `None` in tests that don't exercise queue methods. #[derive(Clone)] pub struct TaskQueueCore { conn: C, - /// Redis URL retained so we can open dedicated connections for blocking - /// commands (BRPOP) that would otherwise serialize on the shared - /// multiplexed connection. - redis_url: Option, + nats: Option, } -/// Production task queue backed by a Redis `ConnectionManager`. +/// Production task queue. pub type TaskQueue = TaskQueueCore; -// -- ConnectionManager-specific methods ------------------------------------ - impl TaskQueue { - /// Connect to Redis and return a TaskQueue. - pub async fn connect(redis_url: &str) -> Result { + /// Connect to Redis + NATS and return a TaskQueue. + /// + /// Ensures the standard JetStream streams exist before returning. + pub async fn connect(redis_url: &str, nats_url: &str) -> Result { let client = redis::Client::open(redis_url) .with_context(|| format!("Invalid Redis URL: {redis_url}"))?; - // Default response_timeout is 500ms which is too short for BRPOP - // blocking calls (tool results can take minutes). Without this fix, - // the client-side timeout cancels the future but the server-side - // BRPOP remains registered, consuming results that are silently lost. 
- let config = redis::aio::ConnectionManagerConfig::new() - .set_response_timeout(Some(Duration::from_secs(1800))); let conn = client - .get_connection_manager_with_config(config) + .get_connection_manager() .await .with_context(|| format!("Failed to connect to Redis at {redis_url}"))?; - info!(url = %redis_url, "Connected to Redis"); + info!(url = %redis_url, "Connected to Redis (state)"); + + let nats = NatsBroker::connect(nats_url).await?; + nats.ensure_streams().await?; + Ok(Self { conn, - redis_url: Some(redis_url.to_string()), + nats: Some(nats), }) } - - /// Create a dedicated (non-shared) multiplexed connection for blocking - /// commands like BRPOP. Each call opens a fresh TCP connection so - /// concurrent BRPOP calls from different agent loops do not serialize. - pub async fn dedicated_connection(&self) -> Result { - let url = self - .redis_url - .as_deref() - .ok_or_else(|| anyhow::anyhow!("No redis_url stored (test backend?)"))?; - let client = - redis::Client::open(url).with_context(|| format!("Invalid Redis URL: {url}"))?; - let conn = client - .get_multiplexed_async_connection() - .await - .with_context(|| "Failed to open dedicated Redis connection for BRPOP")?; - Ok(conn) - } } -// -- Generic methods (work with any ConnectionLike backend) ---------------- - +// The generic impl exposes both the production NATS path and a Redis-only +// path used by unit tests with a mock connection. Some methods are only +// exercised in the test build; allow that on the impl as a whole. #[allow(dead_code)] impl TaskQueueCore { - /// Create a queue from any ConnectionLike backend (used in tests). + /// Construct from a Redis backend only — used by unit tests that don't + /// exercise queue methods. Queue methods will return an error. 
pub fn from_connection(conn: C) -> Self { - Self { - conn, - redis_url: None, - } + Self { conn, nats: None } } - // === Key helpers ======================================================== - - #[inline] - fn task_queue_key(role: &str) -> String { - format!("{TASK_QUEUE_PREFIX}:{role}") + fn nats(&self) -> Result<&NatsBroker> { + self.nats + .as_ref() + .context("TaskQueue has no NATS broker configured") } - #[inline] - fn result_queue_key(task_id: &str) -> String { - format!("{RESULT_QUEUE_PREFIX}:{task_id}") - } + // === Key helpers ======================================================== #[inline] fn heartbeat_key(agent: &str) -> String { @@ -173,13 +158,12 @@ impl TaskQueueCore { format!("{TASK_STATUS_PREFIX}:{task_id}") } - // === Orchestrator methods =============================================== + // === Queue methods (NATS JetStream) ===================================== /// Submit a task to a role's queue. /// - /// Priority <= 2 (urgent) uses RPUSH so the task is consumed first by - /// workers that BRPOP from the right. All other priorities use LPUSH for - /// FIFO order. + /// Priority ≤ 2 publishes to `ares.tasks.urgent.{role}`, otherwise + /// `ares.tasks.{role}`. Workers bind two consumers and prefer urgent. 
pub async fn submit_task( &self, task_type: &str, @@ -189,7 +173,6 @@ impl TaskQueueCore { priority: i32, ) -> Result { let task_id = format!("{}_{}", task_type, &Uuid::new_v4().to_string()[..12]); - let callback = Self::result_queue_key(&task_id); let msg = TaskMessage { task_id: task_id.clone(), @@ -199,133 +182,169 @@ impl TaskQueueCore { payload, priority, created_at: Some(Utc::now()), - callback_queue: Some(callback), + callback_queue: Some(nats::task_result_subject(&task_id)), }; - let queue_key = Self::task_queue_key(target_role); - let json = serde_json::to_string(&msg).context("Failed to serialize TaskMessage")?; - - let mut conn = self.conn.clone(); - if priority <= 2 { - conn.rpush::<_, _, ()>(&queue_key, &json) - .await - .with_context(|| format!("RPUSH to {queue_key}"))?; - info!(task_id = %task_id, queue = %queue_key, priority, "Urgent task submitted (RPUSH)"); + let subject = if priority <= 2 { + nats::urgent_task_subject(target_role) } else { - conn.lpush::<_, _, ()>(&queue_key, &json) - .await - .with_context(|| format!("LPUSH to {queue_key}"))?; - info!(task_id = %task_id, queue = %queue_key, priority, "Task submitted (LPUSH)"); - } + nats::task_subject(target_role) + }; + let bytes = Bytes::from(serde_json::to_vec(&msg).context("serialize TaskMessage")?); - // Track status - self.set_task_status(&task_id, "pending").await?; + let ack = self + .nats()? + .jetstream() + .publish(subject.clone(), bytes) + .await + .with_context(|| format!("JetStream publish to {subject}"))?; + ack.await + .with_context(|| format!("Awaiting JetStream ack for {subject}"))?; + info!(task_id = %task_id, subject = %subject, priority, "Task submitted"); + self.set_task_status(&task_id, "pending").await?; Ok(task_id) } - /// Non-destructive peek: does a result exist for this task? 
- pub async fn has_pending_result(&self, task_id: &str) -> Result { - let key = Self::result_queue_key(task_id); - let mut conn = self.conn.clone(); - let len: i64 = conn.llen(&key).await.unwrap_or(0); - Ok(len > 0) + /// Non-destructive peek: try to pull a result without consuming it. + /// + /// JetStream WorkQueue retention removes a message on ack, so we never + /// "peek without consuming" — we treat any returned result as "pending" + /// and return it through `check_result` next time. To preserve the old + /// semantic (peek → bool, then consume separately), this method always + /// returns `false` and callers should use `check_result` directly. + /// + /// Kept for API compatibility with the previous Redis implementation. + pub async fn has_pending_result(&self, _task_id: &str) -> Result { + Ok(false) } - /// Non-blocking check for a task result (RPOP). + /// Non-blocking check for a task result. + /// + /// Creates an ephemeral pull consumer filtered to this task's subject, + /// fetches one message with a brief timeout, and deletes the consumer. pub async fn check_result(&self, task_id: &str) -> Result> { - let key = Self::result_queue_key(task_id); - let mut conn = self.conn.clone(); - let data: Option = conn.rpop(&key, None).await?; - match data { - Some(json) => { - let result: TaskResult = serde_json::from_str(&json) - .with_context(|| format!("Bad TaskResult JSON for {task_id}"))?; - Ok(Some(result)) - } - None => Ok(None), - } + self.fetch_result(task_id, DEFAULT_RESULT_POLL_TIMEOUT) + .await } - /// Batch-check results for multiple task IDs using a pipeline. + /// Batch-check results for multiple task IDs. + /// + /// Iterates per-task; JetStream consumers are per-filter-subject so we + /// can't pipeline like Redis. Callers should not rely on this being a + /// single round-trip. 
pub async fn check_results_batch( &self, task_ids: &[String], ) -> Result>> { - if task_ids.is_empty() { - return Ok(HashMap::new()); - } - - let mut pipe = redis::pipe(); + let mut out = HashMap::with_capacity(task_ids.len()); for tid in task_ids { - let key = Self::result_queue_key(tid); - pipe.cmd("RPOP").arg(key); + let r = self.check_result(tid).await.unwrap_or_else(|e| { + warn!(task_id = %tid, err = %e, "check_result failed in batch"); + None + }); + out.insert(tid.clone(), r); } + Ok(out) + } - let mut conn = self.conn.clone(); - let raw: Vec> = pipe - .query_async(&mut conn) + /// Fetch a single result for `task_id` from the JetStream result subject. + /// + /// Implementation: ephemeral consumer with `filter_subject` set to the + /// per-task result subject. WorkQueue retention deletes the message + /// on ack, so this is destructive (matches the old RPOP semantics). + async fn fetch_result(&self, task_id: &str, timeout: Duration) -> Result> { + use async_nats::jetstream::consumer::pull::Config as PullConfig; + use async_nats::jetstream::consumer::{AckPolicy, Consumer}; + + let nats = self.nats()?; + let stream = nats + .jetstream() + .get_stream(nats::TASKS_STREAM) .await - .context("Pipeline check_results_batch failed")?; + .with_context(|| format!("get_stream({})", nats::TASKS_STREAM))?; - let mut out = HashMap::with_capacity(task_ids.len()); - for (tid, data) in task_ids.iter().zip(raw) { - let parsed = match data { - Some(json) => match serde_json::from_str::(&json) { - Ok(r) => Some(r), - Err(e) => { - warn!(task_id = %tid, err = %e, "Ignoring malformed TaskResult"); - None - } - }, - None => None, - }; - out.insert(tid.clone(), parsed); - } - Ok(out) - } + let cfg = PullConfig { + filter_subject: nats::task_result_subject(task_id), + ack_policy: AckPolicy::Explicit, + inactive_threshold: Duration::from_secs(60), + ..Default::default() + }; - /// Blocking wait for a result (BRPOP). Timeout in seconds. 
- pub async fn poll_result( - &self, - task_id: &str, - timeout_secs: f64, - ) -> Result<Option<TaskResult>> { - let key = Self::result_queue_key(task_id); - let mut conn = self.conn.clone(); - let result: Option<(String, String)> = conn - .brpop(&key, timeout_secs) + let consumer: Consumer<PullConfig> = stream + .create_consumer(cfg) .await - .with_context(|| format!("BRPOP on {key}"))?; + .context("create ephemeral result consumer")?; - match result { - Some((_key, json)) => { - let tr: TaskResult = serde_json::from_str(&json) + let mut fetch = consumer + .fetch() + .max_messages(1) + .expires(timeout.max(Duration::from_millis(50))) + .messages() + .await + .context("start fetch")?; + + let msg = fetch.next().await; + match msg { + Some(Ok(m)) => { + let parsed: TaskResult = serde_json::from_slice(&m.payload) .with_context(|| format!("Bad TaskResult JSON for {task_id}"))?; - Ok(Some(tr)) + m.ack().await.map_err(|e| anyhow::anyhow!("ack: {e}")).ok(); + Ok(Some(parsed)) } + Some(Err(e)) => Err(anyhow::anyhow!("JetStream fetch error: {e}")), None => Ok(None), } } - /// Get the length of a role's task queue. - pub async fn queue_length(&self, role: &str) -> Result<usize> { - let key = Self::task_queue_key(role); - let mut conn = self.conn.clone(); - let len: usize = conn.llen(&key).await?; - Ok(len) + /// Send a result to the task's result subject (worker side). + pub async fn send_result(&self, task_id: &str, result: &TaskResult) -> Result<()> { + let subject = nats::task_result_subject(task_id); + let bytes = Bytes::from(serde_json::to_vec(result).context("serialize TaskResult")?); + let ack = self + .nats()?
+ .jetstream() + .publish(subject.clone(), bytes) + .await + .with_context(|| format!("JetStream publish to {subject}"))?; + ack.await + .with_context(|| format!("Awaiting ack for {subject}"))?; + + let final_status = if result.success { + "completed" + } else { + "failed" + }; + debug!( + task_id, + status = final_status, + "Result published; updating status" + ); + self.set_task_status(task_id, final_status).await?; + Ok(()) + } + + /// Publish a state-update notification (NATS core, fire-and-forget). + pub async fn publish_state_update(&self, operation_id: &str) -> Result<()> { + let subject = nats::state_update_subject(operation_id); + self.nats()? + .client() + .publish(subject.clone(), Bytes::from_static(b"updated")) + .await + .with_context(|| format!("PUBLISH to {subject}"))?; + debug!(operation_id, "State update published"); + Ok(()) } + // === Redis-backed state methods (unchanged) ============================ + /// Read heartbeat data for an agent. pub async fn get_heartbeat(&self, agent: &str) -> Result<Option<HeartbeatData>> { let key = Self::heartbeat_key(agent); let mut conn = self.conn.clone(); let data: Option<String> = conn.get(&key).await?; match data { - Some(json) => { - let hb: HeartbeatData = serde_json::from_str(&json)?; - Ok(Some(hb)) - } + Some(json) => Ok(Some(serde_json::from_str(&json)?)), None => Ok(None), } } @@ -355,20 +374,8 @@ impl TaskQueueCore { Ok(()) } - /// Publish a state-update notification on the PubSub channel. - pub async fn publish_state_update(&self, operation_id: &str) -> Result<()> { - let channel = format!("{STATE_UPDATE_CHANNEL_PREFIX}:{operation_id}"); - let mut conn = self.conn.clone(); - conn.publish::<_, _, ()>(&channel, "updated") - .await - .with_context(|| format!("PUBLISH to {channel}"))?; - debug!(operation_id, "State update published"); - Ok(()) - } - // === Operation lock ===================================================== - /// Try to acquire the operation lock. Returns true if acquired.
pub async fn try_acquire_lock(&self, operation_id: &str, ttl: Duration) -> Result<bool> { let key = format!("{LOCK_PREFIX}:{operation_id}"); let holder = format!( @@ -391,7 +398,6 @@ impl TaskQueueCore { Ok(acquired) } - /// Extend the operation lock TTL. Call periodically to keep it alive. pub async fn extend_lock(&self, operation_id: &str, ttl: Duration) -> Result<bool> { let key = format!("{LOCK_PREFIX}:{operation_id}"); let mut conn = self.conn.clone(); @@ -402,22 +408,17 @@ impl TaskQueueCore { Ok(ok) } - // === Task status tracking =============================================== + // === Task status tracking ============================================== - /// Set the status string for a task (with 24h TTL). - /// - /// If a record already exists for this task, preserves existing fields - /// (operation_id, role, task_type, started_at, payload) and updates - /// only the status and timestamps. + /// Update only status + timestamps; preserves any existing fields. pub async fn set_task_status(&self, task_id: &str, status: &str) -> Result<()> { let key = Self::task_status_key(task_id); let mut conn = self.conn.clone(); - // Read-modify-write: preserve existing fields let existing: Option<String> = match conn.get::<_, Option<String>>(&key).await { Ok(v) => v, Err(e) => { - warn!(task_id = task_id, err = %e, "Failed to read existing task status"); + warn!(task_id, err = %e, "Failed to read existing task status"); None } }; @@ -476,7 +477,6 @@ impl TaskQueueCore { Ok(()) } - /// Read task status. pub async fn get_task_status(&self, task_id: &str) -> Result<Option<String>> { let key = Self::task_status_key(task_id); let mut conn = self.conn.clone(); let data: Option<String> = conn.get(&key).await?; Ok(data) } @@ -484,33 +484,14 @@ impl TaskQueueCore { - /// Get a clone of the underlying connection. - /// - /// Used by the deferred queue to run ZSET commands directly. + /// Get a clone of the underlying Redis connection. pub fn connection(&self) -> C { self.conn.clone() } - /// Send a result to the task's result queue (worker side).
- pub async fn send_result(&self, task_id: &str, result: &TaskResult) -> Result<()> { - let key = Self::result_queue_key(task_id); - let json = serde_json::to_string(result)?; - let mut conn = self.conn.clone(); - conn.lpush::<_, _, ()>(&key, &json).await?; - conn.expire::<_, ()>(&key, RESULT_TTL_SECS as i64).await?; - let final_status = if result.success { - "completed" - } else { - "failed" - }; - debug!( - task_id = task_id, - status = final_status, - "Updating task status after send_result" - ); - self.set_task_status(task_id, final_status).await?; - debug!(task_id = task_id, "Task status updated to {}", final_status); - Ok(()) + /// Get a clone of the NATS broker (for callers that need direct access). + pub fn nats_broker(&self) -> Option<NatsBroker> { + self.nats.clone() } } @@ -523,186 +504,6 @@ mod tests { TaskQueueCore::from_connection(MockRedisConnection::new()) } - #[tokio::test] - async fn submit_task_normal_priority() { - let q = mock_queue(); - let task_id = q - .submit_task( - "recon", - "scanner", - serde_json::json!({"target": "192.168.58.1"}), - "orchestrator", - 5, - ) - .await - .unwrap(); - - assert!(task_id.starts_with("recon_")); - // Task should be in the scanner queue (LPUSH for normal priority) - let len = q.queue_length("scanner").await.unwrap(); - assert_eq!(len, 1); - // Status should be set to pending - let status_json = q.get_task_status(&task_id).await.unwrap().unwrap(); - let status: serde_json::Value = serde_json::from_str(&status_json).unwrap(); - assert_eq!(status["status"], "pending"); - } - - #[tokio::test] - async fn submit_task_urgent_priority() { - let q = mock_queue(); - let task_id = q - .submit_task("crack", "cracker", serde_json::json!({}), "orchestrator", 1) - .await - .unwrap(); - - assert!(task_id.starts_with("crack_")); - let len = q.queue_length("cracker").await.unwrap(); - assert_eq!(len, 1); - } - - #[tokio::test] - async fn urgent_tasks_consumed_first() { - let q = mock_queue(); - // Submit normal first, then urgent -
q.submit_task( - "normal", - "worker", - serde_json::json!({"order": 1}), - "orch", - 5, - ) - .await - .unwrap(); - q.submit_task( - "urgent", - "worker", - serde_json::json!({"order": 2}), - "orch", - 1, - ) - .await - .unwrap(); - - // BRPOP consumes from the right — urgent (RPUSH) should come first - let mut conn = q.conn.clone(); - let result: Option<(String, String)> = conn.brpop("ares:tasks:worker", 0.0).await.unwrap(); - let (_, json) = result.unwrap(); - let msg: TaskMessage = serde_json::from_str(&json).unwrap(); - assert!(msg.task_id.starts_with("urgent_")); - } - - #[tokio::test] - async fn has_pending_result_false_when_empty() { - let q = mock_queue(); - assert!(!q.has_pending_result("task-1").await.unwrap()); - } - - #[tokio::test] - async fn send_and_check_result() { - let q = mock_queue(); - let result = TaskResult { - task_id: "task-1".to_string(), - success: true, - result: Some(serde_json::json!({"output": "pwned"})), - error: None, - completed_at: Some(Utc::now()), - worker_pod: None, - agent_name: Some("exploit-agent".to_string()), - }; - q.send_result("task-1", &result).await.unwrap(); - - assert!(q.has_pending_result("task-1").await.unwrap()); - - let checked = q.check_result("task-1").await.unwrap().unwrap(); - assert!(checked.success); - assert_eq!(checked.task_id, "task-1"); - assert_eq!(checked.agent_name.as_deref(), Some("exploit-agent")); - - // After check_result (RPOP), queue should be empty - assert!(!q.has_pending_result("task-1").await.unwrap()); - } - - #[tokio::test] - async fn check_result_returns_none_when_empty() { - let q = mock_queue(); - assert!(q.check_result("nonexistent").await.unwrap().is_none()); - } - - #[tokio::test] - async fn check_results_batch_mixed() { - let q = mock_queue(); - let r1 = TaskResult { - task_id: "t1".to_string(), - success: true, - result: None, - error: None, - completed_at: Some(Utc::now()), - worker_pod: None, - agent_name: None, - }; - q.send_result("t1", &r1).await.unwrap(); - // t2 has no 
result - - let batch = q - .check_results_batch(&["t1".to_string(), "t2".to_string()]) - .await - .unwrap(); - assert!(batch["t1"].is_some()); - assert!(batch["t2"].is_none()); - } - - #[tokio::test] - async fn check_results_batch_empty_input() { - let q = mock_queue(); - let batch = q.check_results_batch(&[]).await.unwrap(); - assert!(batch.is_empty()); - } - - #[tokio::test] - async fn poll_result_returns_result() { - let q = mock_queue(); - let result = TaskResult { - task_id: "task-poll".to_string(), - success: false, - result: None, - error: Some("timeout".to_string()), - completed_at: Some(Utc::now()), - worker_pod: None, - agent_name: None, - }; - q.send_result("task-poll", &result).await.unwrap(); - - let polled = q.poll_result("task-poll", 0.0).await.unwrap().unwrap(); - assert!(!polled.success); - assert_eq!(polled.error.as_deref(), Some("timeout")); - } - - #[tokio::test] - async fn poll_result_returns_none_when_empty() { - let q = mock_queue(); - // BRPOP on empty queue with 0 timeout returns Nil in mock - let polled = q.poll_result("missing", 0.0).await.unwrap(); - assert!(polled.is_none()); - } - - #[tokio::test] - async fn queue_length_empty() { - let q = mock_queue(); - assert_eq!(q.queue_length("scanner").await.unwrap(), 0); - } - - #[tokio::test] - async fn queue_length_after_submit() { - let q = mock_queue(); - q.submit_task("t1", "role", serde_json::json!({}), "src", 5) - .await - .unwrap(); - q.submit_task("t2", "role", serde_json::json!({}), "src", 5) - .await - .unwrap(); - assert_eq!(q.queue_length("role").await.unwrap(), 2); - } - #[tokio::test] async fn heartbeat_roundtrip() { let q = mock_queue(); @@ -734,13 +535,6 @@ mod tests { assert!(q.get_heartbeat("ghost").await.unwrap().is_none()); } - #[tokio::test] - async fn publish_state_update_succeeds() { - let q = mock_queue(); - // PUBLISH returns 0 in mock (no subscribers) — should not error - q.publish_state_update("op-1").await.unwrap(); - } - #[tokio::test] async fn 
try_acquire_lock_succeeds() { let q = mock_queue(); @@ -757,7 +551,6 @@ mod tests { q.try_acquire_lock("op-1", Duration::from_secs(30)) .await .unwrap(); - // Second acquire should fail (NX) let acquired = q .try_acquire_lock("op-1", Duration::from_secs(30)) .await @@ -778,17 +571,6 @@ mod tests { assert!(ok); } - #[tokio::test] - async fn extend_lock_fails_when_missing() { - let q = mock_queue(); - // EXPIRE on nonexistent key in real Redis returns false; - // our mock always returns 1, but this tests the code path - let _ok = q - .extend_lock("no-such-op", Duration::from_secs(60)) - .await - .unwrap(); - } - #[tokio::test] async fn set_task_status_creates_record() { let q = mock_queue(); @@ -807,7 +589,6 @@ mod tests { q.set_task_status_full("task-1", "pending", "op-1", "scanner", "recon", None) .await .unwrap(); - // Now update status — should preserve operation_id, role, etc. q.set_task_status("task-1", "in_progress").await.unwrap(); let raw = q.get_task_status("task-1").await.unwrap().unwrap(); @@ -822,7 +603,6 @@ mod tests { async fn set_task_status_completed_adds_ended_at() { let q = mock_queue(); q.set_task_status("task-1", "completed").await.unwrap(); - let raw = q.get_task_status("task-1").await.unwrap().unwrap(); let v: serde_json::Value = serde_json::from_str(&raw).unwrap(); assert_eq!(v["status"], "completed"); @@ -833,7 +613,6 @@ mod tests { async fn set_task_status_failed_adds_ended_at() { let q = mock_queue(); q.set_task_status("task-1", "failed").await.unwrap(); - let raw = q.get_task_status("task-1").await.unwrap().unwrap(); let v: serde_json::Value = serde_json::from_str(&raw).unwrap(); assert_eq!(v["status"], "failed"); @@ -868,60 +647,6 @@ mod tests { assert!(q.get_task_status("nonexistent").await.unwrap().is_none()); } - #[tokio::test] - async fn send_result_sets_completed_status() { - let q = mock_queue(); - q.set_task_status("task-1", "in_progress").await.unwrap(); - - let result = TaskResult { - task_id: "task-1".to_string(), - success: 
true, - result: None, - error: None, - completed_at: Some(Utc::now()), - worker_pod: None, - agent_name: None, - }; - q.send_result("task-1", &result).await.unwrap(); - - let raw = q.get_task_status("task-1").await.unwrap().unwrap(); - let v: serde_json::Value = serde_json::from_str(&raw).unwrap(); - assert_eq!(v["status"], "completed"); - } - - #[tokio::test] - async fn send_result_sets_failed_status() { - let q = mock_queue(); - let result = TaskResult { - task_id: "task-1".to_string(), - success: false, - result: None, - error: Some("boom".to_string()), - completed_at: Some(Utc::now()), - worker_pod: None, - agent_name: None, - }; - q.send_result("task-1", &result).await.unwrap(); - - let raw = q.get_task_status("task-1").await.unwrap().unwrap(); - let v: serde_json::Value = serde_json::from_str(&raw).unwrap(); - assert_eq!(v["status"], "failed"); - } - - #[tokio::test] - async fn connection_returns_clone() { - let q = mock_queue(); - let mut conn = q.connection(); - // Should be usable as AsyncCommands - let _: () = redis::AsyncCommands::set(&mut conn, "test-key", "test-val") - .await - .unwrap(); - let val: String = redis::AsyncCommands::get(&mut conn, "test-key") - .await - .unwrap(); - assert_eq!(val, "test-val"); - } - #[tokio::test] async fn task_message_serialization() { let msg = TaskMessage { @@ -932,7 +657,7 @@ mod tests { payload: serde_json::json!({"host": "192.168.58.1"}), priority: 5, created_at: None, - callback_queue: Some("ares:results:test_abc".to_string()), + callback_queue: Some("ares.tasks.results.test_abc".to_string()), }; let json = serde_json::to_string(&msg).unwrap(); let parsed: TaskMessage = serde_json::from_str(&json).unwrap(); @@ -960,7 +685,6 @@ mod tests { #[tokio::test] async fn task_result_deserialization_defaults() { - // Minimal JSON — optional fields should default let json = r#"{"task_id":"t1","success":false,"completed_at":null}"#; let parsed: TaskResult = serde_json::from_str(json).unwrap(); assert!(!parsed.success); @@ 
-984,4 +708,14 @@ mod tests { assert!(parsed.current_task.is_none()); assert_eq!(parsed.pod_name.as_deref(), Some("pod-x")); } + + #[tokio::test] + async fn nats_required_for_queue_methods() { + let q = mock_queue(); + let err = q + .submit_task("recon", "scanner", serde_json::json!({}), "orch", 5) + .await + .unwrap_err(); + assert!(err.to_string().contains("NATS")); + } } diff --git a/ares-cli/src/orchestrator/throttling.rs b/ares-cli/src/orchestrator/throttling.rs index ff4ecee8..48b7c87a 100644 --- a/ares-cli/src/orchestrator/throttling.rs +++ b/ares-cli/src/orchestrator/throttling.rs @@ -262,6 +262,7 @@ mod tests { fn make_throttler(max_tasks: usize) -> (Throttler, ActiveTaskTracker) { let config = Arc::new(crate::orchestrator::config::OrchestratorConfig { redis_url: "redis://localhost".into(), + nats_url: "nats://localhost:4222".into(), operation_id: "test-op".into(), max_concurrent_tasks: max_tasks, heartbeat_interval: std::time::Duration::from_secs(30), diff --git a/ares-cli/src/orchestrator/tool_dispatcher/mod.rs b/ares-cli/src/orchestrator/tool_dispatcher/mod.rs index 0e8d4155..5986df3f 100644 --- a/ares-cli/src/orchestrator/tool_dispatcher/mod.rs +++ b/ares-cli/src/orchestrator/tool_dispatcher/mod.rs @@ -1,14 +1,14 @@ -//! Redis-backed tool dispatcher for the LLM agent loop. +//! NATS-backed tool dispatcher for the LLM agent loop. //! -//! Implements `ares_llm::ToolDispatcher` by pushing individual tool calls -//! to a Redis queue (`ares:tool_exec:{role}`) and waiting for results -//! on a per-call mailbox (`ares:tool_results:{call_id}`). +//! Implements `ares_llm::ToolDispatcher` by issuing a NATS request to +//! `ares.tools.exec.{role}` and awaiting the worker reply on the +//! auto-generated reply inbox. //! -//! Rust workers run a tool executor that BRPOPs from `tool_exec`, -//! invokes the tool via `ares_tools::dispatch`, and LPUSHes the result. +//! Rust workers subscribe to `ares.tools.exec.{role}` as a queue group, +//! 
invoke the tool via `ares_tools::dispatch`, and reply on the inbox. //! //! Also provides [`LocalToolDispatcher`] for in-process execution without -//! going through Redis, useful for testing or single-binary deployments. +//! going through NATS, useful for testing or single-binary deployments. use redis::AsyncCommands; use serde::{Deserialize, Serialize}; @@ -53,15 +53,6 @@ pub struct ToolExecResponse { pub discoveries: Option, } -/// Prefix for tool execution request queues. -pub(super) const TOOL_EXEC_PREFIX: &str = "ares:tool_exec"; - -/// Prefix for per-call result mailboxes. -pub(super) const TOOL_RESULT_PREFIX: &str = "ares:tool_results"; - -/// TTL for result keys (1 hour). -pub(super) const RESULT_TTL_SECS: u64 = 3600; - /// Default timeout waiting for a tool result (25 minutes). /// Must exceed queue wait time + longest tool runtime (hashcat can queue /// behind another hashcat, so 2x runtime + buffer). diff --git a/ares-cli/src/orchestrator/tool_dispatcher/redis_dispatcher.rs b/ares-cli/src/orchestrator/tool_dispatcher/redis_dispatcher.rs index ed20330c..96db2961 100644 --- a/ares-cli/src/orchestrator/tool_dispatcher/redis_dispatcher.rs +++ b/ares-cli/src/orchestrator/tool_dispatcher/redis_dispatcher.rs @@ -1,9 +1,19 @@ -//! Redis-backed tool dispatcher. +//! NATS-backed tool dispatcher. +//! +//! Each tool call becomes a NATS request to `ares.tools.exec.{role}` with +//! an auto-generated reply inbox; the worker subscribes to that subject as +//! a queue group and replies on the inbox. This replaces the old Redis +//! BRPOP pattern, eliminating the dedicated-connection-per-waiter +//! requirement (a single multiplexed NATS connection handles arbitrary +//! concurrent request/reply pairs). 
+ +use std::time::Duration; use anyhow::{Context, Result}; -use redis::AsyncCommands; +use bytes::Bytes; use tracing::{debug, warn, Instrument}; +use ares_core::nats; use ares_core::telemetry::propagation::inject_traceparent; use ares_core::telemetry::spans::{producer_span, Team}; use ares_llm::{ToolCall, ToolExecResult}; @@ -12,10 +22,10 @@ use crate::orchestrator::task_queue::TaskQueue; use super::{ extract_credential_key, push_realtime_discoveries, AuthThrottle, ToolExecRequest, - ToolExecResponse, RESULT_TTL_SECS, TOOL_EXEC_PREFIX, TOOL_RESULT_PREFIX, + ToolExecResponse, }; -/// Dispatches tool calls to workers via Redis queues. +/// Dispatches tool calls to workers via NATS request/reply. /// /// When tool results contain structured discoveries (hosts, credentials, etc.), /// they are pushed to the `ares:discoveries:{op_id}` list for real-time @@ -23,7 +33,7 @@ use super::{ /// immediately rather than waiting for the task result consumer. pub struct RedisToolDispatcher { pub(super) queue: TaskQueue, - pub(super) tool_timeout: std::time::Duration, + pub(super) tool_timeout: Duration, pub(super) operation_id: String, pub(super) auth_throttle: AuthThrottle, } @@ -32,7 +42,7 @@ impl RedisToolDispatcher { pub fn new(queue: TaskQueue, operation_id: String, auth_throttle: AuthThrottle) -> Self { Self { queue, - tool_timeout: std::time::Duration::from_secs(super::DEFAULT_TOOL_TIMEOUT_SECS), + tool_timeout: Duration::from_secs(super::DEFAULT_TOOL_TIMEOUT_SECS), operation_id, auth_throttle, } @@ -75,105 +85,93 @@ impl ares_llm::ToolDispatcher for RedisToolDispatcher { operation_id: Some(self.operation_id.clone()), }; - let queue_key = format!("{TOOL_EXEC_PREFIX}:{effective_role}"); - let result_key = format!("{TOOL_RESULT_PREFIX}:{call_id}"); + let subject = nats::tool_exec_subject(effective_role); let payload = - serde_json::to_string(&request).context("Failed to serialize tool exec request")?; + serde_json::to_vec(&request).context("Failed to serialize tool exec 
request")?; debug!( tool = %call.name, call_id = %call_id, - queue = %queue_key, + subject = %subject, effective_role = %effective_role, "Dispatching tool call to worker" ); - // Push request to worker queue (shared multiplexed connection is fine for LPUSH) - let mut conn = self.queue.connection(); - conn.lpush::<_, _, ()>(&queue_key, &payload) - .await - .context("Failed to push tool exec request to Redis")?; - - // BRPOP needs a dedicated connection: it blocks its TCP connection - // until a result arrives, so a shared multiplexed connection would - // serialize all concurrent agent loops behind one waiter. - let timeout_secs = self.tool_timeout.as_secs().max(1) as f64; - let brpop_result: Option<(String, String)> = match self.queue.dedicated_connection().await { - Ok(mut dedicated) => { - redis::cmd("BRPOP") - .arg(&result_key) - .arg(timeout_secs) - .query_async(&mut dedicated) - .await - .context("BRPOP failed for tool result")? - } - Err(e) => { - // Fall back to shared connection if dedicated fails - warn!(err = %e, "Failed to open dedicated BRPOP connection, falling back to shared"); - redis::cmd("BRPOP") - .arg(&result_key) - .arg(timeout_secs) - .query_async(&mut conn) - .await - .context("BRPOP failed for tool result")? 
- } - }; - - match brpop_result { - Some((_key, value)) => { - let response: ToolExecResponse = serde_json::from_str(&value) - .context("Failed to deserialize tool exec response")?; - - debug!( + let nats = self + .queue + .nats_broker() + .context("ToolDispatcher requires NATS broker")?; + let client = nats.client().clone(); + + let timeout = self.tool_timeout; + let response_msg = match tokio::time::timeout( + timeout, + client.request(subject.clone(), Bytes::from(payload)), + ) + .await + { + Ok(Ok(msg)) => msg, + Ok(Err(e)) => { + warn!( tool = %call.name, call_id = %call_id, - has_error = response.error.is_some(), - "Tool result received" + err = %e, + "NATS request failed" ); - - // Push discoveries to the real-time discovery list so - // the discovery poller publishes them to state immediately, - // independent of the task result consumer. - if let Some(ref disc) = response.discoveries { - push_realtime_discoveries( - &self.queue, - &self.operation_id, - disc, - &call.name, - &call.arguments, - ) - .await; - } - - Ok(ToolExecResult { - output: response.output, - error: response.error, - discoveries: response.discoveries, - }) + return Ok(ToolExecResult { + output: String::new(), + error: Some(format!("Tool '{}' dispatch error: {e}", call.name)), + discoveries: None, + }); } - None => { + Err(_) => { warn!( tool = %call.name, call_id = %call_id, - timeout_secs = timeout_secs, + timeout_secs = timeout.as_secs(), "Tool execution timed out" ); - - // Clean up any late result - let _: Result<(), _> = conn - .expire::<_, ()>(&result_key, RESULT_TTL_SECS as i64) - .await; - - Ok(ToolExecResult { + return Ok(ToolExecResult { output: String::new(), error: Some(format!( - "Tool '{}' timed out after {timeout_secs}s", - call.name + "Tool '{}' timed out after {}s", + call.name, + timeout.as_secs() )), discoveries: None, - }) + }); } + }; + + let response: ToolExecResponse = serde_json::from_slice(&response_msg.payload) + .context("Failed to deserialize tool exec 
response")?; + + debug!( + tool = %call.name, + call_id = %call_id, + has_error = response.error.is_some(), + "Tool result received" + ); + + // Push discoveries to the real-time discovery list so the + // discovery poller publishes them to state immediately, + // independent of the task result consumer. + if let Some(ref disc) = response.discoveries { + push_realtime_discoveries( + &self.queue, + &self.operation_id, + disc, + &call.name, + &call.arguments, + ) + .await; } + + Ok(ToolExecResult { + output: response.output, + error: response.error, + discoveries: response.discoveries, + }) } .instrument(span) .await diff --git a/ares-cli/src/worker/blue_task_loop.rs b/ares-cli/src/worker/blue_task_loop.rs index 1ec0dd23..717091d9 100644 --- a/ares-cli/src/worker/blue_task_loop.rs +++ b/ares-cli/src/worker/blue_task_loop.rs @@ -17,6 +17,7 @@ use std::time::Duration; use anyhow::Result; use tracing::{debug, error, info, warn}; +use ares_core::nats::NatsBroker; use ares_core::state::blue_task_queue::{BlueTaskMessage, BlueTaskQueue, BlueTaskResult}; use ares_llm::tool_registry::blue::{self, BlueAgentRole}; use ares_llm::{run_agent_loop, AgentLoopConfig, LlmProvider, LoopEndReason, ToolDispatcher}; @@ -25,9 +26,11 @@ use crate::worker::config::WorkerConfig; use crate::worker::heartbeat::WorkerStatus; /// Run the blue team task consumption loop until shutdown. 
+#[allow(clippy::too_many_arguments)] pub async fn run_blue_task_loop( config: &WorkerConfig, conn: redis::aio::ConnectionManager, + nats: NatsBroker, provider: Box, dispatcher: Arc, model_name: String, @@ -43,7 +46,7 @@ pub async fn run_blue_task_loop( "Starting blue team task loop" ); - let mut task_queue = BlueTaskQueue::from_conn(conn); + let mut task_queue = BlueTaskQueue::from_parts(conn, nats); let mut retry_delay = Duration::from_secs(1); let max_retry_delay = Duration::from_secs(60); diff --git a/ares-cli/src/worker/config.rs b/ares-cli/src/worker/config.rs index e772131e..c90de086 100644 --- a/ares-cli/src/worker/config.rs +++ b/ares-cli/src/worker/config.rs @@ -33,6 +33,9 @@ pub struct WorkerConfig { /// Redis connection URL (ARES_REDIS_URL). pub redis_url: String, + /// NATS connection URL (ARES_NATS_URL). + pub nats_url: String, + /// Worker role matching `AgentRole` values: credential_access, cracker, lateral, acl, privesc, coercion. pub worker_role: String, @@ -97,6 +100,8 @@ impl WorkerConfig { anyhow::anyhow!("Redis URL required: set ARES_REDIS_URL, REDIS_URL, or REDIS_HOST") })?; + let nats_url = ares_core::nats::NatsBroker::url_from_env(); + let worker_role = env::var("ARES_WORKER_ROLE") .or_else(|_| env::var("ARES_ROLE")) .map_err(|_| anyhow::anyhow!("ARES_WORKER_ROLE (or ARES_ROLE) is required"))?; @@ -146,6 +151,7 @@ impl WorkerConfig { Ok(Self { redis_url, + nats_url, worker_role, pod_name, agent_name, diff --git a/ares-cli/src/worker/mod.rs b/ares-cli/src/worker/mod.rs index bf798649..15c13781 100644 --- a/ares-cli/src/worker/mod.rs +++ b/ares-cli/src/worker/mod.rs @@ -50,11 +50,8 @@ pub async fn run() -> anyhow::Result<()> { "Ares worker starting" ); - // Single shared Redis connection — cloned cheaply to all subsystems - // Default response_timeout is 500ms which is too short for BRPOP - // blocking calls (5s+). 
Without this, the client-side timeout cancels - // the future but the server-side BRPOP remains, consuming queue items - // that get silently dropped. + // Single shared Redis connection (state only — heartbeats, task status, + // token usage, hosts sync). Queue traffic moved to NATS JetStream. let redis_client = redis::Client::open(config.redis_url.as_str())?; let cm_config = redis::aio::ConnectionManagerConfig::new() .set_response_timeout(Some(std::time::Duration::from_secs(30))); @@ -62,6 +59,10 @@ pub async fn run() -> anyhow::Result<()> { .get_connection_manager_with_config(cm_config) .await?; + // Single shared NATS connection — multiplexes work-queue pulls, result + // publishes, and tool-exec request/reply over one TCP connection. + let nats = ares_core::nats::NatsBroker::connect(&config.nats_url).await?; + // Shared shutdown signal let shutdown = Arc::new(tokio::sync::Notify::new()); let shutdown_signal = Arc::clone(&shutdown); @@ -103,10 +104,17 @@ pub async fn run() -> anyhow::Result<()> { // Run the appropriate loop based on worker mode let result = match config.mode { config::WorkerMode::Task => { - task_loop::run_task_loop(&config, conn, status_tx, shutdown_signal).await + task_loop::run_task_loop(&config, conn, nats.clone(), status_tx, shutdown_signal).await } config::WorkerMode::ToolExec => { - tool_executor::run_tool_exec_loop(&config, conn, status_tx, shutdown_signal).await + tool_executor::run_tool_exec_loop( + &config, + conn, + nats.clone(), + status_tx, + shutdown_signal, + ) + .await } #[cfg(feature = "blue")] config::WorkerMode::BlueTask => { @@ -125,6 +133,7 @@ pub async fn run() -> anyhow::Result<()> { blue_task_loop::run_blue_task_loop( &config, conn, + nats.clone(), provider, dispatcher, model_name, diff --git a/ares-cli/src/worker/task_loop/mod.rs b/ares-cli/src/worker/task_loop/mod.rs index 129db781..163f1202 100644 --- a/ares-cli/src/worker/task_loop/mod.rs +++ b/ares-cli/src/worker/task_loop/mod.rs @@ -1,16 +1,17 @@ -//! 
Core task consumption loop. +//! Core task consumption loop (NATS JetStream). //! //! ```text //! loop { -//! 1. BRPOP from ares:tasks:{role} +//! 1. Pull batch from urgent + normal subjects on the role queue //! 2. Deserialize TaskMessage -//! 3. Update task status to "running" +//! 3. Update task status to "running" (Redis) //! 4. Execute agent task (native Rust) //! 5. Parse result //! 6. Serialize TaskResult -//! 7. LPUSH to ares:results:{task_id} -//! 8. Update task status to "completed" or "failed" -//! 9. Refresh heartbeat status +//! 7. JetStream publish to ares.tasks.results.{task_id} +//! 8. Update task status to "completed" or "failed" (Redis) +//! 9. Refresh heartbeat status (Redis) +//! 10. Ack JetStream message //! } //! ``` @@ -23,49 +24,46 @@ use types::TaskMessage; use std::sync::Arc; use std::time::Duration; +use anyhow::Context; +use async_nats::jetstream::consumer::pull::Config as PullConfig; +use async_nats::jetstream::consumer::{AckPolicy, Consumer}; +use futures::StreamExt; use tracing::{debug, error, info, warn}; +use ares_core::nats::{self, NatsBroker}; + use crate::worker::config::WorkerConfig; use crate::worker::heartbeat::WorkerStatus; -// ─── Redis key prefixes (must match Python's RedisTaskQueue) ───────────────── - -const TASK_QUEUE_PREFIX: &str = "ares:tasks"; -const RESULT_QUEUE_PREFIX: &str = "ares:results"; -const TASK_STATUS_PREFIX: &str = "ares:task_status"; - /// TTL for task status keys — 24 hours, matches Python. const TASK_STATUS_TTL: i64 = 60 * 60 * 24; -/// TTL for result keys — 24 hours, matches Python's `RESULT_TTL`. -const RESULT_TTL: i64 = 60 * 60 * 24; - // ─── Task loop ─────────────────────────────────────────────────────────────── /// Run the main task consumption loop until shutdown is signalled. 
pub async fn run_task_loop( config: &WorkerConfig, - conn: redis::aio::ConnectionManager, + redis_conn: redis::aio::ConnectionManager, + nats: NatsBroker, status_tx: tokio::sync::watch::Sender, shutdown: Arc, ) -> anyhow::Result<()> { - let queue_key = format!("{TASK_QUEUE_PREFIX}:{}", config.worker_role); + let urgent_consumer = ensure_role_consumer(&nats, &config.worker_role, true).await?; + let normal_consumer = ensure_role_consumer(&nats, &config.worker_role, false).await?; + info!( - queue = %queue_key, + role = %config.worker_role, agent = %config.agent_name, - "Starting task loop" + "Starting task loop (NATS JetStream)" ); - let mut conn = conn; - - // Exponential backoff state for connection errors + let mut redis_conn = redis_conn; let mut retry_delay = Duration::from_secs(1); let max_retry_delay = Duration::from_secs(60); loop { - // Race BRPOP against shutdown signal let poll_result = tokio::select! { - result = poll_task(&mut conn, &queue_key, config.poll_timeout) => result, + r = poll_one_task(&urgent_consumer, &normal_consumer, config.poll_timeout) => r, _ = shutdown.notified() => { info!("Task loop: shutdown signalled, finishing"); break; @@ -73,27 +71,26 @@ pub async fn run_task_loop( }; match poll_result { - Ok(Some(task)) => { - // Reset backoff on successful poll + Ok(Some((task, msg))) => { retry_delay = Duration::from_secs(1); - // Update heartbeat status to busy let _ = status_tx.send(WorkerStatus { status: "busy".to_string(), current_task: Some(task.task_id.clone()), }); - // Execute the task — runs to completion even if shutdown arrives mid-task - result_handler::process_task(&mut conn, config, &task).await; + result_handler::process_task(&mut redis_conn, &nats, config, &task).await; + + if let Err(e) = msg.ack().await { + warn!(task_id = %task.task_id, "Failed to ack JetStream message: {e}"); + } - // Update heartbeat status back to idle let _ = status_tx.send(WorkerStatus { status: "idle".to_string(), current_task: None, }); } Ok(None) => 
{ - // No task available (BRPOP timeout), just loop retry_delay = Duration::from_secs(1); } Err(e) => { @@ -105,15 +102,15 @@ pub async fn run_task_loop( "timeout", "broken pipe", "reset", + "no responders", ] .iter() .any(|kw| error_str.contains(kw)); if is_conn_error { - // ConnectionManager auto-reconnects; just back off before retrying warn!( delay_secs = retry_delay.as_secs(), - "Task loop: connection error, retrying: {e}" + "Task loop: transient broker error, retrying: {e}" ); tokio::select! { _ = tokio::time::sleep(retry_delay) => {} @@ -121,7 +118,7 @@ pub async fn run_task_loop( } retry_delay = (retry_delay * 2).min(max_retry_delay); } else { - error!("Task loop: non-connection error: {e}"); + error!("Task loop: error: {e}"); tokio::select! { _ = tokio::time::sleep(Duration::from_secs(5)) => {} _ = shutdown.notified() => break, @@ -135,26 +132,79 @@ pub async fn run_task_loop( Ok(()) } -/// BRPOP from the task queue with timeout. -/// Returns `Ok(None)` on timeout (no task available). -async fn poll_task( - conn: &mut redis::aio::ConnectionManager, - queue_key: &str, +/// Re-export TTL constant for the result handler. +pub(crate) const fn task_status_ttl() -> i64 { + TASK_STATUS_TTL +} + +/// Ensure a durable pull consumer exists for the given (role, urgency). 
+async fn ensure_role_consumer( + nats: &NatsBroker, + role: &str, + urgent: bool, +) -> anyhow::Result> { + let (filter_subject, suffix) = if urgent { + (nats::urgent_task_subject(role), "urgent") + } else { + (nats::task_subject(role), "normal") + }; + + let durable_name = format!("ares-worker-{role}-{suffix}"); + let stream = nats + .jetstream() + .get_stream(nats::TASKS_STREAM) + .await + .with_context(|| format!("get_stream({})", nats::TASKS_STREAM))?; + + let cfg = PullConfig { + durable_name: Some(durable_name.clone()), + filter_subject, + ack_policy: AckPolicy::Explicit, + ack_wait: Duration::from_secs(60 * 30), + max_deliver: 5, + ..Default::default() + }; + + let consumer = stream + .get_or_create_consumer(&durable_name, cfg) + .await + .with_context(|| format!("ensure consumer {durable_name}"))?; + Ok(consumer) +} + +/// Pull one message, preferring urgent. Returns Ok(None) on idle timeout. +async fn poll_one_task( + urgent: &Consumer, + normal: &Consumer, timeout: Duration, -) -> anyhow::Result> { - // BRPOP returns Option<(key, value)> - let result: Option<(String, String)> = redis::cmd("BRPOP") - .arg(queue_key) - .arg(timeout.as_secs() as i64) - .query_async(conn) - .await?; - - match result { - Some((_key, data)) => { - let task: TaskMessage = serde_json::from_str(&data)?; +) -> anyhow::Result> { + // Try urgent with a tiny expiry first; fall back to normal if empty. + if let Some(item) = fetch_one(urgent, Duration::from_millis(50)).await? 
{ + return Ok(Some(item)); + } + fetch_one(normal, timeout).await +} + +async fn fetch_one( + consumer: &Consumer, + expires: Duration, +) -> anyhow::Result> { + let mut batch = consumer + .fetch() + .max_messages(1) + .expires(expires.max(Duration::from_millis(50))) + .messages() + .await + .context("start fetch")?; + + match batch.next().await { + Some(Ok(msg)) => { + let task: TaskMessage = + serde_json::from_slice(&msg.payload).context("deserialize TaskMessage")?; debug!(task_id = %task.task_id, task_type = %task.task_type, "Received task"); - Ok(Some(task)) + Ok(Some((task, msg))) } + Some(Err(e)) => Err(anyhow::anyhow!("JetStream fetch error: {e}")), None => Ok(None), } } @@ -223,14 +273,6 @@ mod tests { fn task_result_skip_serializing_none() { let r = TaskResult::success("t1", serde_json::json!("ok"), "pod", "agent"); let json = serde_json::to_string(&r).unwrap(); - // error field should be absent (skip_serializing_if = "Option::is_none") assert!(!json.contains("\"error\"")); } - - #[test] - fn redis_key_prefixes() { - assert_eq!(TASK_QUEUE_PREFIX, "ares:tasks"); - assert_eq!(RESULT_QUEUE_PREFIX, "ares:results"); - assert_eq!(TASK_STATUS_PREFIX, "ares:task_status"); - } } diff --git a/ares-cli/src/worker/task_loop/result_handler.rs b/ares-cli/src/worker/task_loop/result_handler.rs index a185d89d..d1643fcf 100644 --- a/ares-cli/src/worker/task_loop/result_handler.rs +++ b/ares-cli/src/worker/task_loop/result_handler.rs @@ -1,20 +1,25 @@ -//! Result processing — build TaskResult, push to Redis, track token usage. +//! Result processing — build TaskResult, publish to NATS, track token usage. 
+use bytes::Bytes; use chrono::Utc; use redis::AsyncCommands; use tracing::{debug, error, info, warn}; +use ares_core::nats::{self, NatsBroker}; use ares_core::token_usage; use crate::worker::config::WorkerConfig; use super::executor::run_agent_task; +use super::task_status_ttl; use super::types::{TaskMessage, TaskResult}; -use super::{RESULT_QUEUE_PREFIX, RESULT_TTL, TASK_STATUS_PREFIX, TASK_STATUS_TTL}; -/// Process a single task: set status, run agent, push result. +const TASK_STATUS_PREFIX: &str = "ares:task_status"; + +/// Process a single task: set status, run agent, publish result. pub async fn process_task( conn: &mut redis::aio::ConnectionManager, + nats: &NatsBroker, config: &WorkerConfig, task: &TaskMessage, ) { @@ -27,7 +32,6 @@ pub async fn process_task( "Processing task" ); - // 1. Set task status to "running" if let Err(e) = set_task_status( conn, &task.task_id, @@ -47,17 +51,13 @@ pub async fn process_task( warn!(task_id = %task.task_id, "Failed to set task status to running: {e}"); } - // 2. Run the agent task let agent_result = run_agent_task(&task.task_type, &task.payload, config.task_timeout).await; - // 3. Extract token usage before consuming agent_result (for Redis tracking) let usage_for_tracking = agent_result.as_ref().ok().and_then(|ar| ar.usage.clone()); - // 4. 
Build the result let (task_result, final_status) = match agent_result { Ok(ar) => { if let Some(ref err) = ar.error { - // Agent returned an error (e.g., unsupported task, max steps, model refusal) let result_payload = serde_json::json!({ "output": ar.output, "task_type": task.task_type, @@ -77,11 +77,9 @@ pub async fn process_task( "output": ar.output, "task_type": task.task_type, }); - // Include usage metrics if available if let Some(ref usage) = ar.usage { result_payload["usage"] = serde_json::to_value(usage).unwrap_or_default(); } - // Include structured discoveries parsed from tool output if let Some(ref disc) = ar.discoveries { if let Some(obj) = disc.as_object() { for (k, v) in obj { @@ -119,7 +117,6 @@ pub async fn process_task( } }; - // 5. Accumulate token usage to Redis (best-effort, never fails the task) if let Some(ref usage) = usage_for_tracking { if usage.total_tokens > 0 { if let Some(ref op_id) = config.operation_id { @@ -139,12 +136,23 @@ pub async fn process_task( } } - // 6. LPUSH result to ares:results:{task_id} - let result_key = format!("{RESULT_QUEUE_PREFIX}:{}", task.task_id); - match serde_json::to_string(&task_result) { - Ok(result_json) => { - if let Err(e) = push_result(conn, &result_key, &result_json).await { - error!(task_id = %task.task_id, "Failed to push result: {e}"); + // Publish result to JetStream result subject + match serde_json::to_vec(&task_result) { + Ok(bytes) => { + let subject = nats::task_result_subject(&task.task_id); + match nats + .jetstream() + .publish(subject.clone(), Bytes::from(bytes)) + .await + { + Ok(ack) => { + if let Err(e) = ack.await { + error!(task_id = %task.task_id, subject = %subject, "JetStream ack failed: {e}"); + } + } + Err(e) => { + error!(task_id = %task.task_id, subject = %subject, "Failed to publish result: {e}"); + } } } Err(e) => { @@ -152,7 +160,6 @@ pub async fn process_task( } } - // 7. 
Update task status to final state if let Err(e) = set_task_status( conn, &task.task_id, @@ -177,19 +184,7 @@ pub async fn process_task( } } -/// Push a result to the result queue and set TTL. -async fn push_result( - conn: &mut redis::aio::ConnectionManager, - result_key: &str, - result_json: &str, -) -> anyhow::Result<()> { - conn.lpush::<_, _, ()>(result_key, result_json).await?; - conn.expire::<_, ()>(result_key, RESULT_TTL).await?; - Ok(()) -} - /// Set task status in Redis with TTL. -/// Matches Python's `set_task_status` — writes JSON to `ares:task_status:{task_id}`. async fn set_task_status( conn: &mut redis::aio::ConnectionManager, task_id: &str, @@ -209,7 +204,7 @@ async fn set_task_status( ); } let json_str = serde_json::to_string(&data)?; - conn.set_ex::<_, _, ()>(&key, &json_str, TASK_STATUS_TTL as u64) + conn.set_ex::<_, _, ()>(&key, &json_str, task_status_ttl() as u64) .await?; Ok(()) } diff --git a/ares-cli/src/worker/tool_executor.rs b/ares-cli/src/worker/tool_executor.rs index 2dcbdf69..09111d85 100644 --- a/ares-cli/src/worker/tool_executor.rs +++ b/ares-cli/src/worker/tool_executor.rs @@ -1,29 +1,29 @@ //! Thin tool executor loop for LLM-driven orchestration. //! //! When the Rust orchestrator drives agent loops via `ARES_LLM_MODEL`, it -//! dispatches individual tool calls to `ares:tool_exec:{role}` and waits -//! for results on `ares:tool_results:{call_id}`. -//! -//! This module implements the worker-side consumer: +//! issues a NATS request to `ares.tools.exec.{role}`. Workers subscribe as +//! a queue group so each request goes to exactly one worker, and reply on +//! the auto-generated reply inbox. //! //! ```text //! loop { -//! 1. BRPOP from ares:tool_exec:{role} +//! 1. Receive NATS request on ares.tools.exec.{role} (queue group) //! 2. Deserialize ToolExecRequest //! 3. Execute tool via ares_tools::dispatch() //! 4. Serialize ToolExecResponse -//! 5. LPUSH to ares:tool_results:{call_id} +//! 5. Reply on msg.reply inbox //! } //! ``` //! 
use std::sync::Arc; -use std::time::Duration; -use redis::AsyncCommands; +use bytes::Bytes; +use futures::StreamExt; use serde::{Deserialize, Serialize}; use tracing::{debug, error, info, warn, Instrument}; +use ares_core::nats::{self, NatsBroker}; use ares_core::telemetry::propagation::set_span_parent; use ares_core::telemetry::spans::{trace_discovery, AgentSpanBuilder, SpanKind, Team}; use ares_core::telemetry::target::{extract_target_info, infer_target_type_from_info}; @@ -31,14 +31,6 @@ use ares_core::telemetry::target::{extract_target_info, infer_target_type_from_i use crate::worker::config::WorkerConfig; use crate::worker::heartbeat::WorkerStatus; -// ─── Redis key prefixes (must match orchestrator's tool_dispatcher.rs) ─────── - -const TOOL_EXEC_PREFIX: &str = "ares:tool_exec"; -const TOOL_RESULT_PREFIX: &str = "ares:tool_results"; - -/// TTL for result keys (1 hour) — matches orchestrator's RESULT_TTL_SECS. -const RESULT_TTL: i64 = 3600; - // ─── Wire types (match orchestrator's tool_dispatcher.rs exactly) ──────────── /// Request from the orchestrator's RedisToolDispatcher. @@ -71,165 +63,107 @@ struct ToolExecResponse { /// Run the tool execution loop until shutdown is signalled. /// -/// Consumes individual tool call requests from `ares:tool_exec:{role}` and -/// dispatches them directly to `ares_tools::dispatch()`. Results are pushed -/// back to the per-call mailbox `ares:tool_results:{call_id}`. +/// Subscribes to `ares.tools.exec.{role}` as a queue group so each request +/// goes to exactly one worker. Replies on the request's reply inbox. 
pub async fn run_tool_exec_loop( config: &WorkerConfig, - conn: redis::aio::ConnectionManager, + _conn: redis::aio::ConnectionManager, + nats: NatsBroker, status_tx: tokio::sync::watch::Sender, shutdown: Arc, ) -> anyhow::Result<()> { - let queue_key = format!("{TOOL_EXEC_PREFIX}:{}", config.worker_role); + let subject = nats::tool_exec_subject(&config.worker_role); + let queue_group = format!("ares-tools-{}", config.worker_role); + + let client = nats.client().clone(); + let mut sub = client + .queue_subscribe(subject.clone(), queue_group.clone()) + .await?; info!( - queue = %queue_key, + subject = %subject, + queue_group = %queue_group, agent = %config.agent_name, - "Starting tool executor loop" + "Starting tool executor loop (NATS queue subscribe)" ); - let mut conn = conn; - - // Track tools that failed with "not installed" so we can short-circuit - // future calls immediately without attempting to spawn the binary. let mut unavailable_tools: std::collections::HashSet = std::collections::HashSet::new(); - // Exponential backoff state for connection errors - let mut retry_delay = Duration::from_secs(1); - let max_retry_delay = Duration::from_secs(60); - loop { - // Check for shutdown via select with zero-timeout - let poll_result = tokio::select! { - result = poll_tool_request(&mut conn, &queue_key, config.poll_timeout) => result, + let next = tokio::select! 
{ + m = sub.next() => m, _ = shutdown.notified() => { info!("Tool executor: shutdown signalled, finishing"); return Ok(()); } }; - match poll_result { - Ok(Some(request)) => { - retry_delay = Duration::from_secs(1); - - // Update heartbeat to busy - let _ = status_tx.send(WorkerStatus { - status: "busy".to_string(), - current_task: Some(format!("{}:{}", request.tool_name, request.call_id)), - }); - - let ti = extract_target_info(&request.arguments); - let tt = infer_target_type_from_info(&ti); - let mut span_builder = - AgentSpanBuilder::new("tool_exec", &config.worker_role, Team::Red) - .tool(&request.tool_name) - .kind(SpanKind::Consumer); - if let Some(ref ip) = ti.target_ip { - span_builder = span_builder.target_ip(ip); - } - if let Some(ref fqdn) = ti.target_fqdn { - span_builder = span_builder.target_fqdn(fqdn); - } - if let Some(ref user) = ti.target_user { - span_builder = span_builder.target_user(user); - } - if let Some(target_type) = tt { - span_builder = span_builder.target_type(target_type); - } - if let Some(ref op) = request.operation_id { - span_builder = span_builder.operation_id(op); - } - let exec_span = span_builder.build(); - if let Some(ref tp) = request.traceparent { - set_span_parent(&exec_span, tp); - } - execute_and_respond(&mut conn, &request, &mut unavailable_tools) - .instrument(exec_span) - .await; - - // Back to idle - let _ = status_tx.send(WorkerStatus { - status: "idle".to_string(), - current_task: None, - }); - } - Ok(None) => { - // BRPOP timeout, no request — just loop - retry_delay = Duration::from_secs(1); + let msg = match next { + Some(m) => m, + None => { + warn!("Tool executor: subscription closed, exiting"); + return Ok(()); } + }; + + let request: ToolExecRequest = match serde_json::from_slice(&msg.payload) { + Ok(r) => r, Err(e) => { - let error_str = e.to_string().to_lowercase(); - let is_conn_error = [ - "connection", - "connect", - "closed", - "timeout", - "broken pipe", - "reset", - ] - .iter() - .any(|kw| 
error_str.contains(kw)); - - if is_conn_error { - // ConnectionManager auto-reconnects; just back off before retrying - warn!( - delay_secs = retry_delay.as_secs(), - "Tool executor: connection error, retrying: {e}" - ); - tokio::select! { - _ = tokio::time::sleep(retry_delay) => {} - _ = shutdown.notified() => return Ok(()), - } - retry_delay = (retry_delay * 2).min(max_retry_delay); - } else { - error!("Tool executor: non-connection error: {e}"); - tokio::select! { - _ = tokio::time::sleep(Duration::from_secs(5)) => {} - _ = shutdown.notified() => return Ok(()), - } - retry_delay = Duration::from_secs(1); - } + warn!(err = %e, "Bad ToolExecRequest payload, skipping"); + continue; } + }; + + let _ = status_tx.send(WorkerStatus { + status: "busy".to_string(), + current_task: Some(format!("{}:{}", request.tool_name, request.call_id)), + }); + + let ti = extract_target_info(&request.arguments); + let tt = infer_target_type_from_info(&ti); + let mut span_builder = AgentSpanBuilder::new("tool_exec", &config.worker_role, Team::Red) + .tool(&request.tool_name) + .kind(SpanKind::Consumer); + if let Some(ref ip) = ti.target_ip { + span_builder = span_builder.target_ip(ip); + } + if let Some(ref fqdn) = ti.target_fqdn { + span_builder = span_builder.target_fqdn(fqdn); + } + if let Some(ref user) = ti.target_user { + span_builder = span_builder.target_user(user); + } + if let Some(target_type) = tt { + span_builder = span_builder.target_type(target_type); + } + if let Some(ref op) = request.operation_id { + span_builder = span_builder.operation_id(op); + } + let exec_span = span_builder.build(); + if let Some(ref tp) = request.traceparent { + set_span_parent(&exec_span, tp); } - } -} -/// BRPOP a single tool execution request from the queue. 
-async fn poll_tool_request( - conn: &mut redis::aio::ConnectionManager, - queue_key: &str, - timeout: Duration, -) -> anyhow::Result> { - let result: Option<(String, String)> = redis::cmd("BRPOP") - .arg(queue_key) - .arg(timeout.as_secs() as i64) - .query_async(conn) - .await?; + let reply_to = msg.reply.clone(); + let client_for_reply = client.clone(); - match result { - Some((_key, data)) => { - let request: ToolExecRequest = serde_json::from_str(&data)?; - debug!( - tool = %request.tool_name, - call_id = %request.call_id, - task_id = %request.task_id, - "Received tool exec request" - ); - Ok(Some(request)) - } - None => Ok(None), + execute_and_respond(client_for_reply, reply_to, &request, &mut unavailable_tools) + .instrument(exec_span) + .await; + + let _ = status_tx.send(WorkerStatus { + status: "idle".to_string(), + current_task: None, + }); } } -/// Execute a tool call and push the result to Redis. -/// -/// If the tool has previously failed with "not installed", short-circuits -/// immediately without attempting to spawn the binary. +/// Execute a tool call and reply on the NATS inbox. 
async fn execute_and_respond( - conn: &mut redis::aio::ConnectionManager, + client: async_nats::Client, + reply_to: Option, request: &ToolExecRequest, unavailable_tools: &mut std::collections::HashSet, ) { - // Short-circuit if this tool is known to be unavailable if unavailable_tools.contains(&request.tool_name) { debug!( tool = %request.tool_name, @@ -246,10 +180,7 @@ async fn execute_and_respond( )), discoveries: None, }; - let result_key = format!("{TOOL_RESULT_PREFIX}:{}", request.call_id); - if let Ok(json) = serde_json::to_string(&response) { - let _ = push_result(conn, &result_key, &json).await; - } + send_reply(&client, reply_to.as_ref(), &response).await; return; } @@ -265,9 +196,7 @@ async fn execute_and_respond( let response = match ares_tools::dispatch(&request.tool_name, &request.arguments).await { Ok(output) => { - // Raw output for structured parsers (need unfiltered data) let raw = output.combined_raw(); - // Filtered output for LLM (strips MOTD, noise, etc.) let combined = output.combined(); let error = if output.success { None @@ -275,7 +204,6 @@ async fn execute_and_respond( Some(format!("tool exited with code {:?}", output.exit_code)) }; - // Parse structured discoveries from raw (unfiltered) tool output let discoveries = ares_tools::parsers::parse_tool_output( &request.tool_name, &raw, @@ -287,7 +215,6 @@ async fn execute_and_respond( Some(discoveries) }; - // Emit discovery spans for observability if let Some(ref disc) = discoveries { if let Some(obj) = disc.as_object() { for (disc_type, items) in obj { @@ -318,7 +245,6 @@ async fn execute_and_respond( } Err(e) => { let err_str = e.to_string(); - // Track tools that fail because the binary is missing if err_str.contains("failed to spawn") || err_str.contains("not installed") { warn!( tool = %request.tool_name, @@ -341,45 +267,36 @@ async fn execute_and_respond( } }; - let has_error = response.error.is_some(); - let result_key = format!("{TOOL_RESULT_PREFIX}:{}", request.call_id); + debug!( + 
tool = %request.tool_name, + call_id = %request.call_id, + has_error = response.error.is_some(), + "Tool result ready" + ); + send_reply(&client, reply_to.as_ref(), &response).await; +} - match serde_json::to_string(&response) { - Ok(json) => { - if let Err(e) = push_result(conn, &result_key, &json).await { - error!( - call_id = %request.call_id, - "Failed to push tool result: {e}" - ); - } else { - debug!( - tool = %request.tool_name, - call_id = %request.call_id, - has_error = has_error, - "Tool result pushed" - ); +async fn send_reply( + client: &async_nats::Client, + reply_to: Option<&async_nats::Subject>, + response: &ToolExecResponse, +) { + let Some(reply) = reply_to else { + warn!(call_id = %response.call_id, "No reply subject — orchestrator will time out"); + return; + }; + match serde_json::to_vec(response) { + Ok(bytes) => { + if let Err(e) = client.publish(reply.clone(), Bytes::from(bytes)).await { + error!(call_id = %response.call_id, "Failed to publish reply: {e}"); } } Err(e) => { - error!( - call_id = %request.call_id, - "Failed to serialize tool result: {e}" - ); + error!(call_id = %response.call_id, "Failed to serialize reply: {e}"); } } } -/// LPUSH result and set TTL. 
-async fn push_result( - conn: &mut redis::aio::ConnectionManager, - result_key: &str, - result_json: &str, -) -> anyhow::Result<()> { - conn.lpush::<_, _, ()>(result_key, result_json).await?; - conn.expire::<_, ()>(result_key, RESULT_TTL).await?; - Ok(()) -} - // ─── Tests ────────────────────────────────────────────────────────────────── #[cfg(test)] @@ -443,18 +360,6 @@ mod tests { assert!(json.contains("192.168.58.10")); } - #[test] - fn redis_key_prefixes_match_orchestrator() { - // These must match crate::orchestrator::tool_dispatcher - assert_eq!(TOOL_EXEC_PREFIX, "ares:tool_exec"); - assert_eq!(TOOL_RESULT_PREFIX, "ares:tool_results"); - } - - #[test] - fn result_ttl_is_one_hour() { - assert_eq!(RESULT_TTL, 3600); - } - #[test] fn tool_exec_request_deserialize_with_traceparent() { let json = r#"{ @@ -587,54 +492,10 @@ mod tests { } #[test] - fn queue_key_format() { + fn nats_subject_format() { let role = "recon"; - let key = format!("{TOOL_EXEC_PREFIX}:{role}"); - assert_eq!(key, "ares:tool_exec:recon"); - } - - #[test] - fn result_key_format() { - let call_id = "nmap_scan_abc123"; - let key = format!("{TOOL_RESULT_PREFIX}:{call_id}"); - assert_eq!(key, "ares:tool_results:nmap_scan_abc123"); - } - - #[test] - fn connection_error_detection_keywords() { - // Verify the connection error detection logic from the main loop - let conn_keywords = [ - "connection", - "connect", - "closed", - "timeout", - "broken pipe", - "reset", - ]; - - let test_errors = [ - ("connection refused", true), - ("failed to connect", true), - ("connection closed", true), - ("operation timeout", true), - ("broken pipe", true), - ("connection reset by peer", true), - ("invalid argument", false), - ("permission denied", false), - ("key not found", false), - ]; - - for (error_str, expected_is_conn) in test_errors { - let error_lower = error_str.to_lowercase(); - let is_conn = conn_keywords.iter().any(|kw| error_lower.contains(kw)); - assert_eq!( - is_conn, - expected_is_conn, - "Error '{}' 
should {}be a connection error",
-                error_str,
-                if expected_is_conn { "" } else { "NOT " }
-            );
-        }
+        let subj = nats::tool_exec_subject(role);
+        assert_eq!(subj, "ares.tools.exec.recon");
     }
 
     #[test]
diff --git a/ares-core/Cargo.toml b/ares-core/Cargo.toml
index bb9cb90c..ae244d81 100644
--- a/ares-core/Cargo.toml
+++ b/ares-core/Cargo.toml
@@ -9,6 +9,9 @@ serde = { workspace = true }
 serde_json = { workspace = true }
 tokio = { workspace = true }
 redis = { workspace = true }
+async-nats = { workspace = true }
+futures = { workspace = true }
+bytes = { workspace = true }
 chrono = { workspace = true }
 tracing = { workspace = true }
 uuid = { workspace = true }
diff --git a/ares-core/src/lib.rs b/ares-core/src/lib.rs
index 989cf7f3..da8fdee4 100644
--- a/ares-core/src/lib.rs
+++ b/ares-core/src/lib.rs
@@ -16,6 +16,7 @@ pub mod detection;
 #[cfg(feature = "blue")]
 pub mod eval;
 pub mod models;
+pub mod nats;
 pub mod parsing;
 pub mod persistent_store;
 pub mod reports;
diff --git a/ares-core/src/nats.rs b/ares-core/src/nats.rs
new file mode 100644
index 00000000..c3107339
--- /dev/null
+++ b/ares-core/src/nats.rs
@@ -0,0 +1,369 @@
+//! NATS / JetStream broker integration for the Ares queue protocol.
+//!
+//! Replaces the Redis List + BRPOP work queue and per-call mailbox patterns:
+//!
+//! | Old Redis key                          | New NATS subject                 | Persistence  |
+//! |----------------------------------------|----------------------------------|--------------|
+//! | `ares:tasks:{role}`                    | `ares.tasks.{role}`              | JetStream    |
+//! | `ares:results:{task_id}`               | `ares.tasks.results.{task_id}`   | JetStream    |
+//! | `ares:tool_exec:{role}`                | `ares.tools.exec.{role}`         | core         |
+//! | `ares:tool_results:{call_id}`          | NATS reply inbox                 | core (sync)  |
+//! | `ares:blue:tasks:global:{role}`        | `ares.blue.tasks.{role}`         | JetStream    |
+//! | `ares:blue:results:{task_id}`          | `ares.blue.tasks.results.{task_id}` | JetStream |
+//! | `ares:deferred:{op}:{type}` (ZSET)     | `ares.deferred.{op}.{type}`      | JetStream    |
+//! 
| `ares:state:updates:{op}` (PUBLISH) | `ares.state.updates.{op}` | core | +//! +//! Tool calls and result mailboxes use NATS request/reply, which removes the +//! "BRPOP needs a dedicated TCP connection" workaround in the Redis path +//! (a single multiplexed NATS connection handles arbitrary concurrent +//! request/reply pairs because each reply uses its own auto-generated inbox +//! subject). +//! +//! Work queues use JetStream with a pull consumer per worker role. Acks are +//! explicit and `max_deliver` triggers redelivery on worker crash, replacing +//! the silent message loss of `BRPOP`. + +use std::time::Duration; + +use anyhow::{Context, Result}; +use async_nats::jetstream::consumer::pull::Config as PullConfig; +use async_nats::jetstream::consumer::AckPolicy; +use async_nats::jetstream::stream::{Config as StreamConfig, RetentionPolicy, StorageType}; +use async_nats::jetstream::{self, Context as JetStreamContext}; +use async_nats::Client; +use tracing::{info, warn}; + +/// Default NATS URL used when neither `ARES_NATS_URL` nor an explicit URL is provided. +pub const DEFAULT_NATS_URL: &str = "nats://127.0.0.1:4222"; + +// === Subject taxonomy ===================================================== + +/// Red team task work queue. `ares.tasks.{role}` (e.g. `ares.tasks.recon`). +pub const TASK_SUBJECT_PREFIX: &str = "ares.tasks"; +/// Tool dispatch RPC. `ares.tools.exec.{role}`. +pub const TOOL_EXEC_SUBJECT_PREFIX: &str = "ares.tools.exec"; +/// Blue team task work queue. `ares.blue.tasks.{role}`. +pub const BLUE_TASK_SUBJECT_PREFIX: &str = "ares.blue.tasks"; +/// Blue investigation request queue. `ares.blue.investigations`. +pub const BLUE_INVESTIGATION_SUBJECT: &str = "ares.blue.investigations"; +/// Deferred (delayed re-dispatch) tasks. `ares.deferred.{op}.{type}`. +pub const DEFERRED_SUBJECT_PREFIX: &str = "ares.deferred"; +/// State change notifications. `ares.state.updates.{op}` (core, fire-and-forget). 
+pub const STATE_UPDATE_SUBJECT_PREFIX: &str = "ares.state.updates"; +/// Real-time discovery forwarding. `ares.discoveries.{op}`. +pub const DISCOVERY_SUBJECT_PREFIX: &str = "ares.discoveries"; +/// Per-task result subject. `ares.tasks.results.{task_id}`. +/// Lives on the `ARES_TASKS` stream so results survive orchestrator restart. +pub const TASK_RESULT_SUBJECT_PREFIX: &str = "ares.tasks.results"; +/// Urgent task subject (priority ≤ 2). `ares.tasks.urgent.{role}`. +pub const URGENT_TASK_SUBJECT_PREFIX: &str = "ares.tasks.urgent"; +/// Blue task result subject. `ares.blue.tasks.results.{task_id}`. +pub const BLUE_TASK_RESULT_SUBJECT_PREFIX: &str = "ares.blue.tasks.results"; + +// === Stream names ========================================================= + +/// JetStream stream containing all red-team task subjects. +pub const TASKS_STREAM: &str = "ARES_TASKS"; +/// JetStream stream containing all blue-team task subjects. +pub const BLUE_TASKS_STREAM: &str = "ARES_BLUE_TASKS"; +/// JetStream stream containing deferred-task subjects. +pub const DEFERRED_STREAM: &str = "ARES_DEFERRED"; +/// JetStream stream containing real-time discoveries. 
+pub const DISCOVERIES_STREAM: &str = "ARES_DISCOVERIES"; + +// === Subject builders ===================================================== + +#[inline] +pub fn task_subject(role: &str) -> String { + format!("{TASK_SUBJECT_PREFIX}.{role}") +} + +#[inline] +pub fn urgent_task_subject(role: &str) -> String { + format!("{URGENT_TASK_SUBJECT_PREFIX}.{role}") +} + +#[inline] +pub fn task_result_subject(task_id: &str) -> String { + format!("{TASK_RESULT_SUBJECT_PREFIX}.{task_id}") +} + +#[inline] +pub fn blue_task_result_subject(task_id: &str) -> String { + format!("{BLUE_TASK_RESULT_SUBJECT_PREFIX}.{task_id}") +} + +#[inline] +pub fn tool_exec_subject(role: &str) -> String { + format!("{TOOL_EXEC_SUBJECT_PREFIX}.{role}") +} + +#[inline] +pub fn blue_task_subject(role: &str) -> String { + format!("{BLUE_TASK_SUBJECT_PREFIX}.{role}") +} + +#[inline] +pub fn deferred_subject(operation_id: &str, task_type: &str) -> String { + format!("{DEFERRED_SUBJECT_PREFIX}.{operation_id}.{task_type}") +} + +#[inline] +pub fn state_update_subject(operation_id: &str) -> String { + format!("{STATE_UPDATE_SUBJECT_PREFIX}.{operation_id}") +} + +#[inline] +pub fn discovery_subject(operation_id: &str) -> String { + format!("{DISCOVERY_SUBJECT_PREFIX}.{operation_id}") +} + +// === Connection =========================================================== + +/// Shared NATS broker handle. +/// +/// `async_nats::Client` is already cheaply cloneable and multiplexes all +/// subscriptions and requests over a single TCP connection — we just keep +/// the JetStream context alongside it for convenience. +#[derive(Clone)] +pub struct NatsBroker { + client: Client, + jetstream: JetStreamContext, +} + +impl NatsBroker { + /// Connect to NATS at the given URL (e.g. `nats://nats.attack-simulation.svc:4222`). 
+ pub async fn connect(url: &str) -> Result { + let client = async_nats::connect(url) + .await + .with_context(|| format!("Failed to connect to NATS at {url}"))?; + let jetstream = jetstream::new(client.clone()); + info!(url, "Connected to NATS"); + Ok(Self { client, jetstream }) + } + + /// Resolve URL from `ARES_NATS_URL` then `NATS_URL`, falling back to localhost. + pub fn url_from_env() -> String { + std::env::var("ARES_NATS_URL") + .or_else(|_| std::env::var("NATS_URL")) + .unwrap_or_else(|_| DEFAULT_NATS_URL.to_string()) + } + + /// Connect using `ARES_NATS_URL` / `NATS_URL` / default. + pub async fn connect_from_env() -> Result { + Self::connect(&Self::url_from_env()).await + } + + pub fn client(&self) -> &Client { + &self.client + } + + pub fn jetstream(&self) -> &JetStreamContext { + &self.jetstream + } + + /// Ensure the standard Ares streams exist with sensible defaults. + /// + /// Idempotent — safe to call from every process on startup. The + /// orchestrator typically calls this; workers can rely on the stream + /// already existing but calling again is harmless. + pub async fn ensure_streams(&self) -> Result<()> { + self.ensure_stream(StreamSpec::tasks()).await?; + self.ensure_stream(StreamSpec::blue_tasks()).await?; + self.ensure_stream(StreamSpec::deferred()).await?; + self.ensure_stream(StreamSpec::discoveries()).await?; + Ok(()) + } + + /// Create or update a single stream. + pub async fn ensure_stream(&self, spec: StreamSpec) -> Result<()> { + match self.jetstream.get_or_create_stream(spec.to_config()).await { + Ok(_) => { + info!(stream = spec.name, "JetStream ready"); + Ok(()) + } + Err(e) => { + warn!(stream = spec.name, err = %e, "Failed to create/get stream"); + Err(anyhow::anyhow!( + "JetStream stream {} unavailable: {e}", + spec.name + )) + } + } + } + + /// Ensure a durable pull consumer exists on the given stream + filter. + /// + /// Returns the consumer name. Idempotent on repeated calls with the same + /// configuration. 
+    pub async fn ensure_pull_consumer(
+        &self,
+        stream: &str,
+        durable_name: &str,
+        filter_subject: &str,
+    ) -> Result<String> {
+        let stream_handle = self
+            .jetstream
+            .get_stream(stream)
+            .await
+            .with_context(|| format!("get_stream({stream})"))?;
+
+        let cfg = PullConfig {
+            durable_name: Some(durable_name.to_string()),
+            filter_subject: filter_subject.to_string(),
+            ack_policy: AckPolicy::Explicit,
+            ack_wait: Duration::from_secs(60 * 30), // tools can take minutes
+            max_deliver: 5, // bounded redelivery on worker crash
+            ..Default::default()
+        };
+
+        stream_handle
+            .get_or_create_consumer(durable_name, cfg)
+            .await
+            .with_context(|| format!("ensure consumer {durable_name} on {stream}"))?;
+        Ok(durable_name.to_string())
+    }
+}
+
+/// Stream definition. One per logical broker workload.
+pub struct StreamSpec {
+    pub name: &'static str,
+    pub subjects: Vec<String>,
+    pub max_age: Duration,
+    pub storage: StorageType,
+}
+
+impl StreamSpec {
+    /// Red team task queue stream.
+    pub fn tasks() -> Self {
+        Self {
+            name: TASKS_STREAM,
+            subjects: vec![format!("{TASK_SUBJECT_PREFIX}.>")],
+            max_age: Duration::from_secs(60 * 60 * 24), // 24h
+            storage: StorageType::File,
+        }
+    }
+
+    /// Blue team task queue stream.
+    pub fn blue_tasks() -> Self {
+        Self {
+            name: BLUE_TASKS_STREAM,
+            subjects: vec![
+                format!("{BLUE_TASK_SUBJECT_PREFIX}.>"),
+                BLUE_INVESTIGATION_SUBJECT.to_string(),
+            ],
+            max_age: Duration::from_secs(60 * 60 * 24),
+            storage: StorageType::File,
+        }
+    }
+
+    /// Deferred-task stream. Messages represent delayed re-dispatches:
+    /// consumers fetch them and re-publish to the live
+    /// `ares.tasks.{role}` subject when their deadline arrives.
+    pub fn deferred() -> Self {
+        Self {
+            name: DEFERRED_STREAM,
+            subjects: vec![format!("{DEFERRED_SUBJECT_PREFIX}.>")],
+            max_age: Duration::from_secs(60 * 60 * 6), // shorter — deferred tasks are short-lived
+            storage: StorageType::File,
+        }
+    }
+
+    /// Real-time discovery forwarding stream. 
+ pub fn discoveries() -> Self { + Self { + name: DISCOVERIES_STREAM, + subjects: vec![format!("{DISCOVERY_SUBJECT_PREFIX}.>")], + max_age: Duration::from_secs(60 * 60 * 12), + storage: StorageType::File, + } + } + + fn to_config(&self) -> StreamConfig { + StreamConfig { + name: self.name.to_string(), + subjects: self.subjects.clone(), + retention: RetentionPolicy::WorkQueue, + max_age: self.max_age, + storage: self.storage, + ..Default::default() + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn task_subject_format() { + assert_eq!(task_subject("recon"), "ares.tasks.recon"); + assert_eq!(task_subject("lateral"), "ares.tasks.lateral"); + } + + #[test] + fn tool_exec_subject_format() { + assert_eq!(tool_exec_subject("recon"), "ares.tools.exec.recon"); + } + + #[test] + fn blue_task_subject_format() { + assert_eq!(blue_task_subject("triage"), "ares.blue.tasks.triage"); + } + + #[test] + fn deferred_subject_format() { + assert_eq!( + deferred_subject("op-1", "recon"), + "ares.deferred.op-1.recon" + ); + } + + #[test] + fn state_update_subject_format() { + assert_eq!(state_update_subject("op-1"), "ares.state.updates.op-1"); + } + + #[test] + fn discovery_subject_format() { + assert_eq!(discovery_subject("op-1"), "ares.discoveries.op-1"); + } + + #[test] + fn url_from_env_default() { + // Must not panic when neither var is set; we don't assert exact value + // because the test environment may have one set. 
+ let url = NatsBroker::url_from_env(); + assert!(!url.is_empty()); + } + + #[test] + fn tasks_stream_spec_covers_all_roles() { + let spec = StreamSpec::tasks(); + assert_eq!(spec.name, "ARES_TASKS"); + assert_eq!(spec.subjects, vec!["ares.tasks.>"]); + } + + #[test] + fn blue_tasks_stream_includes_investigation_subject() { + let spec = StreamSpec::blue_tasks(); + assert_eq!(spec.name, "ARES_BLUE_TASKS"); + assert!(spec + .subjects + .iter() + .any(|s| s == "ares.blue.investigations")); + assert!(spec.subjects.iter().any(|s| s == "ares.blue.tasks.>")); + } + + #[test] + fn deferred_stream_subject_pattern() { + let spec = StreamSpec::deferred(); + assert_eq!(spec.subjects, vec!["ares.deferred.>"]); + } + + #[test] + fn discoveries_stream_subject_pattern() { + let spec = StreamSpec::discoveries(); + assert_eq!(spec.subjects, vec!["ares.discoveries.>"]); + } +} diff --git a/ares-core/src/state/blue_task_queue.rs b/ares-core/src/state/blue_task_queue.rs index 122b1125..b888eb1a 100644 --- a/ares-core/src/state/blue_task_queue.rs +++ b/ares-core/src/state/blue_task_queue.rs @@ -1,15 +1,30 @@ -//! Blue team task queue for distributed investigation workers. +//! Blue team task queue. //! -//! Matches the Python `BlueTaskQueue` key patterns for task submission, -//! result polling, heartbeat, and investigation registration. +//! Hybrid Redis + NATS JetStream. Queues live on JetStream; heartbeats and +//! investigation registration stay on Redis. +//! +//! NATS subjects: +//! - `ares.blue.tasks.{role}` global per-role work queue +//! - `ares.blue.tasks.results.{task_id}` durable per-task result subject +//! - `ares.blue.investigations` investigation request queue +//! +//! Redis keys (state only): +//! - `ares:blue:heartbeat:{agent}` agent heartbeat (TTL 60s) +//! - `ares:blue:active_investigations` SET of active investigation IDs +//! 
- `ares:blue:inv:{id}:queue_meta` investigation metadata (HASH) use std::time::Duration; +use anyhow::{Context, Result}; +use bytes::Bytes; use chrono::Utc; +use futures::StreamExt; use redis::aio::ConnectionManagerConfig; use redis::AsyncCommands; use serde::{Deserialize, Serialize}; -use tracing::{debug, warn}; +use tracing::{debug, info, warn}; + +use crate::nats::{self, NatsBroker}; use super::keys::*; @@ -70,41 +85,59 @@ impl BlueTaskResult { } } -/// Blue team task queue backed by Redis. -/// -/// Queue naming: -/// ares:blue:tasks:global:{role} Global queue per role -/// ares:blue:results:{task_id} Result queue (TTL) -/// ares:blue:heartbeat:{agent} Agent heartbeat (TTL 60s) -/// ares:blue:active_investigations Active investigation IDs (SET, TTL 24h) -/// ares:blue:inv:{id}:queue_meta Investigation queue metadata (HASH, TTL 24h) +/// Blue team task queue — NATS for queues, Redis for state. pub struct BlueTaskQueue { conn: redis::aio::ConnectionManager, + nats: Option, } impl BlueTaskQueue { + /// Connect to Redis only (state methods work, queue methods will error). + /// Used by callers that only need heartbeat/investigation registration. pub async fn connect(redis_url: &str) -> anyhow::Result { let client = redis::Client::open(redis_url)?; - // Default response_timeout is 500ms which is too short for BRPOP - // blocking calls. Set to 30s to accommodate blocking operations. let config = ConnectionManagerConfig::new().set_response_timeout(Some(Duration::from_secs(30))); let conn = client.get_connection_manager_with_config(config).await?; - Ok(Self { conn }) + Ok(Self { conn, nats: None }) + } + + /// Connect to both Redis (state) and NATS (queues). 
+ pub async fn connect_with_nats(redis_url: &str, nats_url: &str) -> anyhow::Result { + let mut q = Self::connect(redis_url).await?; + let nats = NatsBroker::connect(nats_url).await?; + nats.ensure_streams().await?; + q.nats = Some(nats); + Ok(q) } pub fn from_conn(conn: redis::aio::ConnectionManager) -> Self { - Self { conn } + Self { conn, nats: None } + } + + pub fn from_parts(conn: redis::aio::ConnectionManager, nats: NatsBroker) -> Self { + Self { + conn, + nats: Some(nats), + } } pub fn conn_mut(&mut self) -> &mut redis::aio::ConnectionManager { &mut self.conn } + fn nats(&self) -> Result<&NatsBroker> { + self.nats + .as_ref() + .context("BlueTaskQueue has no NATS broker configured") + } + + // === Queue methods (NATS JetStream) ===================================== + /// Submit a task to the global role queue. pub async fn submit_task(&mut self, task: &BlueTaskMessage) -> anyhow::Result<()> { - let queue_key = format!("{BLUE_TASK_QUEUE_PREFIX}:global:{}", task.role); - let data = serde_json::to_string(task)?; + let subject = nats::blue_task_subject(&task.role); + let bytes = Bytes::from(serde_json::to_vec(task).context("serialize BlueTaskMessage")?); debug!( task_id = %task.task_id, @@ -113,40 +146,69 @@ impl BlueTaskQueue { "submitting blue team task" ); - let _: () = self.conn.lpush(&queue_key, &data).await?; - let _: () = self.conn.expire(&queue_key, 86400).await?; + let ack = self + .nats()? + .jetstream() + .publish(subject.clone(), bytes) + .await + .with_context(|| format!("JetStream publish to {subject}"))?; + ack.await + .with_context(|| format!("Awaiting JetStream ack for {subject}"))?; Ok(()) } - /// Poll for a task from the global role queue (blocking). + /// Poll for a task from the global role queue (blocking up to `timeout_secs`). 
pub async fn poll_global_task( &mut self, role: &str, timeout_secs: f64, ) -> anyhow::Result> { - let queue_key = format!("{BLUE_TASK_QUEUE_PREFIX}:global:{role}"); - let result: Option<(String, String)> = redis::cmd("BRPOP") - .arg(&queue_key) - .arg(timeout_secs) - .query_async(&mut self.conn) + let nats = self.nats()?; + let subject = nats::blue_task_subject(role); + let consumer_name = format!("blue-tasks-{role}"); + nats.ensure_pull_consumer(nats::BLUE_TASKS_STREAM, &consumer_name, &subject) .await?; - match result { - Some((_key, data)) => { - let task: BlueTaskMessage = serde_json::from_str(&data)?; + let stream = nats.jetstream().get_stream(nats::BLUE_TASKS_STREAM).await?; + let consumer = stream + .get_consumer::(&consumer_name) + .await + .map_err(|e| anyhow::anyhow!("get_consumer({consumer_name}): {e}"))?; + + let timeout = Duration::from_secs_f64(timeout_secs.max(0.05)); + let mut fetch = consumer + .fetch() + .max_messages(1) + .expires(timeout) + .messages() + .await + .context("start fetch")?; + + match fetch.next().await { + Some(Ok(m)) => { + let task: BlueTaskMessage = serde_json::from_slice(&m.payload) + .with_context(|| format!("Bad BlueTaskMessage JSON on {subject}"))?; + m.ack().await.map_err(|e| anyhow::anyhow!("ack: {e}")).ok(); Ok(Some(task)) } + Some(Err(e)) => Err(anyhow::anyhow!("JetStream fetch error: {e}")), None => Ok(None), } } - /// Send a task result. + /// Send a task result to its dedicated result subject. pub async fn send_result(&mut self, result: &BlueTaskResult) -> anyhow::Result<()> { - let result_key = format!("{BLUE_RESULT_QUEUE_PREFIX}:{}", result.task_id); - let data = serde_json::to_string(result)?; - - let _: () = self.conn.lpush(&result_key, &data).await?; - let _: () = self.conn.expire(&result_key, 3600).await?; // 1h TTL + let subject = nats::blue_task_result_subject(&result.task_id); + let bytes = Bytes::from(serde_json::to_vec(result).context("serialize BlueTaskResult")?); + + let ack = self + .nats()? 
+ .jetstream() + .publish(subject.clone(), bytes) + .await + .with_context(|| format!("JetStream publish to {subject}"))?; + ack.await + .with_context(|| format!("Awaiting ack for {subject}"))?; Ok(()) } @@ -156,36 +218,139 @@ impl BlueTaskQueue { task_id: &str, timeout_secs: f64, ) -> anyhow::Result> { - let result_key = format!("{BLUE_RESULT_QUEUE_PREFIX}:{task_id}"); - let result: Option<(String, String)> = redis::cmd("BRPOP") - .arg(&result_key) - .arg(timeout_secs) - .query_async(&mut self.conn) - .await?; + self.fetch_result(task_id, Duration::from_secs_f64(timeout_secs.max(0.0))) + .await + } - match result { - Some((_key, data)) => { - let task_result: BlueTaskResult = serde_json::from_str(&data)?; - Ok(Some(task_result)) + /// Check for a result without blocking. + pub async fn check_result(&mut self, task_id: &str) -> anyhow::Result> { + self.fetch_result(task_id, Duration::from_millis(100)).await + } + + async fn fetch_result( + &mut self, + task_id: &str, + timeout: Duration, + ) -> anyhow::Result> { + use async_nats::jetstream::consumer::pull::Config as PullConfig; + use async_nats::jetstream::consumer::AckPolicy; + + let nats = self.nats()?; + let stream = nats.jetstream().get_stream(nats::BLUE_TASKS_STREAM).await?; + + let cfg = PullConfig { + filter_subject: nats::blue_task_result_subject(task_id), + ack_policy: AckPolicy::Explicit, + inactive_threshold: Duration::from_secs(60), + ..Default::default() + }; + + let consumer = stream + .create_consumer(cfg) + .await + .context("create ephemeral blue result consumer")?; + + let mut fetch = consumer + .fetch() + .max_messages(1) + .expires(timeout.max(Duration::from_millis(50))) + .messages() + .await + .context("start fetch")?; + + match fetch.next().await { + Some(Ok(m)) => { + let parsed: BlueTaskResult = serde_json::from_slice(&m.payload) + .with_context(|| format!("Bad BlueTaskResult JSON for {task_id}"))?; + m.ack().await.map_err(|e| anyhow::anyhow!("ack: {e}")).ok(); + Ok(Some(parsed)) } + 
Some(Err(e)) => Err(anyhow::anyhow!("JetStream fetch error: {e}")), None => Ok(None), } } - /// Check for a result without blocking. - pub async fn check_result(&mut self, task_id: &str) -> anyhow::Result> { - let result_key = format!("{BLUE_RESULT_QUEUE_PREFIX}:{task_id}"); - let result: Option = self.conn.rpop(&result_key, None).await?; - - match result { - Some(data) => { - let task_result: BlueTaskResult = serde_json::from_str(&data)?; - Ok(Some(task_result)) + /// Pop an investigation request from the queue. + pub async fn pop_investigation_request( + &mut self, + timeout_secs: f64, + ) -> anyhow::Result> { + let nats = self.nats()?; + let consumer_name = "blue-investigations"; + nats.ensure_pull_consumer( + nats::BLUE_TASKS_STREAM, + consumer_name, + nats::BLUE_INVESTIGATION_SUBJECT, + ) + .await?; + + let stream = nats.jetstream().get_stream(nats::BLUE_TASKS_STREAM).await?; + let consumer = stream + .get_consumer::(consumer_name) + .await + .map_err(|e| anyhow::anyhow!("get_consumer({consumer_name}): {e}"))?; + + let timeout = Duration::from_secs_f64(timeout_secs.max(0.05)); + let mut fetch = consumer + .fetch() + .max_messages(1) + .expires(timeout) + .messages() + .await + .context("start fetch")?; + + match fetch.next().await { + Some(Ok(m)) => { + match serde_json::from_slice::(&m.payload) { + Ok(val) => { + m.ack().await.map_err(|e| anyhow::anyhow!("ack: {e}")).ok(); + Ok(Some(val)) + } + Err(e) => { + warn!("Failed to parse investigation request: {e}"); + // Ack and skip the malformed message + m.ack().await.map_err(|e| anyhow::anyhow!("ack: {e}")).ok(); + Ok(None) + } + } } + Some(Err(e)) => Err(anyhow::anyhow!("JetStream fetch error: {e}")), None => Ok(None), } } + /// Submit an investigation request via the supplied NATS broker. No Redis + /// connection required — used by CLI submission paths (`ares blue submit`, + /// `ares blue from-operation`, auto-submit) which only need to publish. 
+ pub async fn submit_investigation_request( + broker: &NatsBroker, + request: &serde_json::Value, + ) -> anyhow::Result<()> { + let bytes = Bytes::from(serde_json::to_vec(request).context("serialize request")?); + let ack = broker + .jetstream() + .publish(nats::BLUE_INVESTIGATION_SUBJECT, bytes) + .await + .with_context(|| { + format!("JetStream publish to {}", nats::BLUE_INVESTIGATION_SUBJECT) + })?; + ack.await.context("ack investigation request")?; + Ok(()) + } + + /// Get the global role queue length (best-effort; returns total stream depth). + pub async fn queue_length(&mut self, _role: &str) -> anyhow::Result { + let stream = self + .nats()? + .jetstream() + .get_stream(nats::BLUE_TASKS_STREAM) + .await?; + let info = stream.cached_info(); + Ok(info.state.messages as usize) + } + + // === Redis-backed state methods ======================================== + /// Send a heartbeat for a blue team agent. pub async fn send_heartbeat( &mut self, @@ -234,14 +399,12 @@ impl BlueTaskQueue { alert: &serde_json::Value, model: &str, ) -> anyhow::Result<()> { - // Add to active set let _: () = self .conn .sadd(BLUE_ACTIVE_INVESTIGATIONS, investigation_id) .await?; let _: () = self.conn.expire(BLUE_ACTIVE_INVESTIGATIONS, 86400).await?; - // Store investigation metadata let meta_key = format!("{BLUE_KEY_PREFIX}:{investigation_id}:queue_meta"); let _: () = self .conn @@ -257,6 +420,8 @@ impl BlueTaskQueue { .hset(&meta_key, "registered_at", Utc::now().to_rfc3339()) .await?; let _: () = self.conn.expire(&meta_key, 86400).await?; + + info!(investigation_id, "Investigation registered as active"); Ok(()) } @@ -288,36 +453,6 @@ impl BlueTaskQueue { let model: Option = self.conn.hget(&meta_key, "model").await?; Ok(model) } - - /// Pop an investigation request from the queue. 
- pub async fn pop_investigation_request( - &mut self, - timeout_secs: f64, - ) -> anyhow::Result> { - let result: Option<(String, String)> = redis::cmd("BRPOP") - .arg(BLUE_INVESTIGATION_QUEUE) - .arg(timeout_secs) - .query_async(&mut self.conn) - .await?; - - match result { - Some((_key, data)) => match serde_json::from_str(&data) { - Ok(val) => Ok(Some(val)), - Err(e) => { - warn!("Failed to parse investigation request: {e}"); - Ok(None) - } - }, - None => Ok(None), - } - } - - /// Get the global role queue length. - pub async fn queue_length(&mut self, role: &str) -> anyhow::Result { - let queue_key = format!("{BLUE_TASK_QUEUE_PREFIX}:global:{role}"); - let len: usize = self.conn.llen(&queue_key).await?; - Ok(len) - } } #[cfg(test)] @@ -357,7 +492,6 @@ mod tests { let success = BlueTaskResult::success("t", "i", serde_json::Value::Null, "a"); let failure = BlueTaskResult::failure("t", "i", "err".to_string(), "a"); - // Both should have a non-empty RFC 3339 timestamp. assert!(!success.completed_at.is_empty()); assert!(!failure.completed_at.is_empty()); assert!(chrono::DateTime::parse_from_rfc3339(&success.completed_at).is_ok()); diff --git a/ares-core/src/state/operations.rs b/ares-core/src/state/operations.rs index 06ae4452..8b4c9730 100644 --- a/ares-core/src/state/operations.rs +++ b/ares-core/src/state/operations.rs @@ -8,25 +8,42 @@ use redis::AsyncCommands; use super::keys::*; use super::{build_key, build_lock_key}; -/// Publish a state update notification via Redis PUBLISH. +/// Publish a state update notification via NATS. /// -/// Channel: `ares:state:updates:{operation_id}` +/// Subject: `ares.state.updates.{operation_id}` (core publish, fire-and-forget). /// Message: `{"type":"state_update","operation_id":"...","ts":"..."}` /// -/// Returns the number of subscribers that received the message. +/// Returns 0 on success (no per-subscriber count; NATS core publish is async). +/// Connects to NATS via `ARES_NATS_URL` / `NATS_URL` on every call (best-effort). 
pub async fn publish_state_update( - conn: &mut impl AsyncCommands, + _conn: &mut impl AsyncCommands, operation_id: &str, ) -> Result { - let channel = format!("{STATE_UPDATE_CHANNEL_PREFIX}:{operation_id}"); + use bytes::Bytes; let message = serde_json::json!({ "type": "state_update", "operation_id": operation_id, "ts": chrono::Utc::now().to_rfc3339(), }); - let msg_str = serde_json::to_string(&message).unwrap_or_default(); - let count: i64 = conn.publish(&channel, &msg_str).await?; - Ok(count) + let msg_bytes = Bytes::from(serde_json::to_vec(&message).unwrap_or_default()); + let subject = format!( + "{}.{operation_id}", + crate::nats::STATE_UPDATE_SUBJECT_PREFIX + ); + + // Best-effort one-shot publish. Errors here are not fatal — state writes + // already succeeded and the subscriber-count signal isn't load-bearing. + match crate::nats::NatsBroker::connect_from_env().await { + Ok(broker) => { + if let Err(e) = broker.client().publish(subject, msg_bytes).await { + tracing::debug!(operation_id, "NATS publish_state_update failed: {e}"); + } + } + Err(e) => { + tracing::debug!(operation_id, "NATS unavailable for state update: {e}"); + } + } + Ok(0) } /// Set the operation status JSON string. 
diff --git a/docs/blue.md b/docs/blue.md index 18e10fb3..b4322a16 100644 --- a/docs/blue.md +++ b/docs/blue.md @@ -34,7 +34,7 @@ The investigation orchestrator manages the full investigation lifecycle: - Chains follow-up investigations based on discovered evidence types - Enforces hard timeout watchdog (1 min/step + 2 min buffer) - Generates partial reports on timeout -- Handles investigation state persistence via Redis +- Handles investigation state persistence via Redis (task queues run on NATS JetStream) #### Blue Worker Task Loop diff --git a/docs/infrastructure.md b/docs/infrastructure.md index c491619f..689f9252 100644 --- a/docs/infrastructure.md +++ b/docs/infrastructure.md @@ -57,7 +57,7 @@ ansible/ Ansible collection (dreadnode.nimbus_range v warpgate-templates/ Container image build templates ares-base/ Base: Kali + Ansible base role + security tools - ares-orchestrator/ Orchestrator: unified Ares binary + Redis client + ares-orchestrator/ Orchestrator: unified Ares binary + Redis & NATS clients ares-worker/ Generic worker (inherits ares-base) ares-{recon,credential-access,cracker,acl,privesc,lateral-movement,coercion}-agent/ ares-cracker-{agent-gpu,base-gpu}/ @@ -231,6 +231,7 @@ kubectl run ares-orchestrator \ --image=ghcr.io/dreadnode/ares-python-orchestrator:latest \ -it --rm \ --env="REDIS_URL=redis://redis:6379" \ + --env="NATS_URL=nats://nats:4222" \ --env="ANTHROPIC_API_KEY=$ANTHROPIC_API_KEY" \ -- ares orchestrator @@ -247,19 +248,26 @@ services: image: redis:7-alpine ports: ["6379:6379"] + nats: + image: nats:2.10-alpine + command: ["-js"] # enable JetStream + ports: ["4222:4222"] + orchestrator: image: ghcr.io/dreadnode/ares-orchestrator:latest command: ["ares", "orchestrator"] environment: REDIS_URL: redis://redis:6379 + NATS_URL: nats://nats:4222 ANTHROPIC_API_KEY: ${ANTHROPIC_API_KEY} - depends_on: [redis] + depends_on: [redis, nats] recon-worker: image: ghcr.io/dreadnode/ares-recon-agent:latest command: ["ares", "worker"] environment: 
REDIS_URL: redis://redis:6379 + NATS_URL: nats://nats:4222 ARES_WORKER_ROLE: recon - depends_on: [redis] + depends_on: [redis, nats] ``` diff --git a/docs/red.md b/docs/red.md index da3a3c3f..aad4882d 100644 --- a/docs/red.md +++ b/docs/red.md @@ -25,7 +25,7 @@ installed. │ - Operation completion decision │ │ - Does NOT execute exploitation tools directly │ └──────────────────────────────┬─────────────────────────────────────────┘ - │ Redis pub/sub + task queues + │ NATS JetStream tasks + Redis state ┌───────────────────────┼─────────────┬─────────────┬─────────────┬─────────────┐ ▼ ▼ ▼ ▼ ▼ ▼ ▼ ┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────────┐ @@ -60,11 +60,15 @@ Each worker agent has: - No knowledge of other workers' activities (except via shared state) - Responsibility to report results back to the orchestrator -### 3. Shared State via Redis +### 3. Shared State via Redis, Tasks via NATS -All agents share state through Redis: +Ares splits transport from state: -- Discovered credentials are automatically broadcast +- **NATS JetStream** carries task dispatch and tool RPC between orchestrator + and workers (durable work queues, pull consumers, explicit acks) +- **Redis** holds durable shared state: credentials, hashes, hosts, + vulnerabilities, locks, heartbeats, and operation metadata +- Discovered credentials are automatically broadcast via Redis state updates - Hashes are tracked for cracking status - Hosts and vulnerabilities are cataloged - Task status is visible to all agents @@ -446,7 +450,7 @@ coverage. Regardless of mode, conditions are checked in this order: -1. External stop signal (CLI `stop` command or Redis flag) +1. External stop signal (CLI `stop` command or Redis stop flag) 2. Max runtime exceeded (`timeouts.operation_timeout`) 3. 
Mode-specific DA/GT/forest check (described above) @@ -532,6 +536,22 @@ INFO | Operation phase transition: enumeration → privilege_escalation ## State Management +### Broker vs. State Split + +Ares uses two backends with distinct roles: + +- **NATS JetStream** — broker/transport for queues and RPC. Carries task + dispatch (`ares.tasks.{role}`, `ares.blue.tasks.{role}`), tool result + streams (`ares.{red,blue}.tasks.results.{task_id}`), and investigation + requests. Work-queue retention auto-deletes acked messages. +- **Redis** — durable, queryable state. Holds operation state, credentials, + hosts, hashes, vulnerabilities, heartbeats, locks, task status, and the + per-orchestrator deferred priority queue. + +Workers connect to both. The orchestrator owns one shared `NatsBroker` and +threads it through dispatcher, completion checks, and the embedded blue +auto-submit task. + ### Pattern: Write-Through Cache Redis is the **durable store**. In-memory dicts are **write-through caches**. @@ -573,8 +593,8 @@ SharedRedTeamState: When any agent discovers a credential: -1. Credential is added to shared state -2. Redis pub/sub broadcasts to all agents +1. Credential is added to shared state (Redis) +2. Other agents observe it on their next state read 3. 
All agents can use the credential immediately ## Task Flow Example @@ -703,7 +723,8 @@ kubectl -n attack-simulation exec -it ares-recon-agent-0 -- \ - `ares-cli/src/orchestrator/state/` - Operation state management - `ares-cli/src/orchestrator/config.rs` - Orchestrator configuration - `ares-cli/src/worker/` - Worker agent task loop, tool execution -- `ares-core/src/` - Shared models, state, Redis schema, telemetry +- `ares-core/src/` - Shared models, state, Redis/NATS schemas, telemetry +- `ares-core/src/nats/` - NATS JetStream broker, stream/subject taxonomy **CLI**: diff --git a/warpgate-templates/templates/ares-blue-agent/README.md b/warpgate-templates/templates/ares-blue-agent/README.md index d26af5fe..a0f19aab 100644 --- a/warpgate-templates/templates/ares-blue-agent/README.md +++ b/warpgate-templates/templates/ares-blue-agent/README.md @@ -58,6 +58,7 @@ After building the image, you can test it locally: # Run the agent container interactively docker run -it --rm \ -e REDIS_URL="redis://localhost:6379" \ + -e NATS_URL="nats://localhost:4222" \ -e ANTHROPIC_API_KEY="your-api-key" \ ares-blue-agent:latest diff --git a/warpgate-templates/templates/ares-blue-lateral-analyst-agent/README.md b/warpgate-templates/templates/ares-blue-lateral-analyst-agent/README.md index e94c64e4..9aa54698 100644 --- a/warpgate-templates/templates/ares-blue-lateral-analyst-agent/README.md +++ b/warpgate-templates/templates/ares-blue-lateral-analyst-agent/README.md @@ -59,6 +59,7 @@ After building the image, you can test it locally: # Run the agent container interactively docker run -it --rm \ -e REDIS_URL="redis://localhost:6379" \ + -e NATS_URL="nats://localhost:4222" \ -e ANTHROPIC_API_KEY="your-api-key" \ ares-blue-lateral-analyst-agent:latest diff --git a/warpgate-templates/templates/ares-blue-threat-hunter-agent/README.md b/warpgate-templates/templates/ares-blue-threat-hunter-agent/README.md index ee2557d3..fc3a144b 100644 --- 
a/warpgate-templates/templates/ares-blue-threat-hunter-agent/README.md +++ b/warpgate-templates/templates/ares-blue-threat-hunter-agent/README.md @@ -59,6 +59,7 @@ After building the image, you can test it locally: # Run the agent container interactively docker run -it --rm \ -e REDIS_URL="redis://localhost:6379" \ + -e NATS_URL="nats://localhost:4222" \ -e ANTHROPIC_API_KEY="your-api-key" \ ares-blue-threat-hunter-agent:latest diff --git a/warpgate-templates/templates/ares-blue-triage-agent/README.md b/warpgate-templates/templates/ares-blue-triage-agent/README.md index 5b499ab6..608c558e 100644 --- a/warpgate-templates/templates/ares-blue-triage-agent/README.md +++ b/warpgate-templates/templates/ares-blue-triage-agent/README.md @@ -59,6 +59,7 @@ After building the image, you can test it locally: # Run the agent container interactively docker run -it --rm \ -e REDIS_URL="redis://localhost:6379" \ + -e NATS_URL="nats://localhost:4222" \ -e ANTHROPIC_API_KEY="your-api-key" \ ares-blue-triage-agent:latest diff --git a/warpgate-templates/templates/ares-orchestrator/README.md b/warpgate-templates/templates/ares-orchestrator/README.md index e740d032..28640dfd 100644 --- a/warpgate-templates/templates/ares-orchestrator/README.md +++ b/warpgate-templates/templates/ares-orchestrator/README.md @@ -2,8 +2,9 @@ This template builds **Ares Orchestrator** images using Warp Gate. The orchestrator coordinates multi-agent red team operations, dispatching tasks to -specialized worker agents via Redis, using a compiled Rust binary with embedded -Python for LLM agent steps. +specialized worker agents via NATS JetStream (broker) with Redis as the durable +state store, using a compiled Rust binary with embedded Python for LLM agent +steps. 
--- @@ -82,9 +83,10 @@ After building the image, you can test it locally: **Run the orchestrator container interactively:** ```bash -# Run with Redis and API key for testing +# Run with Redis, NATS, and API key for testing docker run -it --rm \ -e REDIS_URL="redis://localhost:6379" \ + -e NATS_URL="nats://localhost:4222" \ -e ANTHROPIC_API_KEY="your-api-key" \ --entrypoint /bin/bash \ ares-orchestrator:latest @@ -99,16 +101,20 @@ docker run --rm --entrypoint ares ares-orchestrator:latest orchestrator --versio docker run --rm --entrypoint bash ares-orchestrator:latest -c "curl --version && jq --version" ``` -**Test with local Redis:** +**Test with local Redis and NATS:** ```bash # Start Redis in Docker docker run -d --name redis -p 6379:6379 redis:7-alpine -# Run the orchestrator connected to local Redis +# Start NATS with JetStream enabled +docker run -d --name nats -p 4222:4222 nats:2.10-alpine -js + +# Run the orchestrator connected to local Redis and NATS docker run -it --rm \ --network host \ -e REDIS_URL="redis://localhost:6379" \ + -e NATS_URL="nats://localhost:4222" \ -e ANTHROPIC_API_KEY="your-api-key" \ -e ARES_NAMESPACE="default" \ ares-orchestrator:latest @@ -146,7 +152,8 @@ ares orchestrator multi-agent contoso.local "192.168.58.10,192.168.58.11"``` The pod has the following environment variables pre-configured: -- `REDIS_URL`: Redis connection string with authentication +- `REDIS_URL`: Redis connection string with authentication (durable state store) +- `NATS_URL`: NATS server URL (task + RPC broker, e.g. 
`nats://nats:4222`) - `ANTHROPIC_API_KEY`: API key for Claude models - `ARES_NAMESPACE`: Kubernetes namespace for agent discovery @@ -171,7 +178,8 @@ The pod has the following environment variables pre-configured: - **Directory Structure:** - `/root/` - Default working directory - `/usr/local/bin/ares` - Compiled Ares binary - Python packages installed system-wide -- The orchestrator requires Redis, an Anthropic API key, and access to worker agents to function. +- The orchestrator requires Redis (state), NATS JetStream (broker), an + Anthropic API key, and access to worker agents to function. --- diff --git a/warpgate-templates/templates/ares-worker/README.md b/warpgate-templates/templates/ares-worker/README.md index 6a5e17d6..60997d83 100644 --- a/warpgate-templates/templates/ares-worker/README.md +++ b/warpgate-templates/templates/ares-worker/README.md @@ -1,9 +1,10 @@ # Ares Worker Warp Gate Template This template builds **Ares Worker** images using Warp Gate. It supports -building **Docker images** (for `amd64` and `arm64`). The worker agent polls -Redis for tasks and orchestrates tool execution across the Ares framework, -using a compiled Rust binary with embedded Python for LLM agent steps. +building **Docker images** (for `amd64` and `arm64`). The worker agent pulls +tasks from NATS JetStream, reads/writes shared state in Redis, and orchestrates +tool execution across the Ares framework, using a compiled Rust binary with +embedded Python for LLM agent steps. 
--- @@ -81,9 +82,10 @@ After building the image, you can test it locally: **Run the worker container interactively:** ```bash -# Run with Redis connection for testing +# Run with Redis + NATS connections for testing docker run -it --rm \ -e REDIS_URL="redis://localhost:6379" \ + -e NATS_URL="nats://localhost:4222" \ -e ANTHROPIC_API_KEY="your-api-key" \ ares-worker:latest ``` @@ -94,16 +96,20 @@ docker run -it --rm \ # Check the Rust binary is available docker run --rm ares-worker:latest ares worker --version``` -**Test with local Redis:** +**Test with local Redis and NATS:** ```bash # Start Redis in Docker docker run -d --name redis -p 6379:6379 redis:7-alpine -# Run the worker connected to local Redis +# Start NATS with JetStream enabled +docker run -d --name nats -p 4222:4222 nats:2.10-alpine -js + +# Run the worker connected to local services docker run -it --rm \ --network host \ -e REDIS_URL="redis://localhost:6379" \ + -e NATS_URL="nats://localhost:4222" \ -e ANTHROPIC_API_KEY="your-api-key" \ ares-worker:latest ``` @@ -145,7 +151,8 @@ warpgate validate ares-worker - **Directory Structure:** - `/root/` - Default working directory - `/usr/local/bin/ares` - Compiled Ares binary - Python packages installed system-wide -- The worker requires Redis and an Anthropic API key to function. +- The worker requires Redis (state), NATS JetStream (broker), and an + Anthropic API key to function. 
--- From 929254977bca28e75c9cff9f20587fc5ce29a87e Mon Sep 17 00:00:00 2001 From: Jayson Grace Date: Wed, 29 Apr 2026 16:32:44 -0600 Subject: [PATCH 2/5] test: add and improve unit test coverage for dispatcher, result handler, and blue queue **Added:** - Unit tests for credential extraction logic in tool dispatcher, including cases for various tools, username/domain presence, and field aliases - Tests for traceparent and operation_id serialization in ToolExecRequest - Tests for ToolExecResponse discovery field handling and default behaviors - Extensive integration tests for push_realtime_discoveries, covering host, credential, hash, vulnerability, share, user, trust, and various error cases - Unit tests for AuthThrottle covering limits, credential separation, and window expiry logic - Tests for set_task_status in result_handler, including overwriting, merging, and handling non-object extras - Unit tests for BlueTaskQueueCore covering serialization, heartbeat, active investigation, alert/model retrieval, and error handling for missing NATS **Changed:** - push_realtime_discoveries and set_task_status made generic over Redis connection type to support mock connections in tests - BlueTaskQueue refactored to use a generic BlueTaskQueueCore with production and test implementations, enabling better unit testability - Imports updated to include ConnectionLike and support new generic types --- .../src/orchestrator/tool_dispatcher/mod.rs | 11 +- .../src/orchestrator/tool_dispatcher/tests.rs | 403 ++++++++++++++++++ .../src/worker/task_loop/result_handler.rs | 70 ++- ares-core/src/state/blue_task_queue.rs | 228 +++++++++- 4 files changed, 699 insertions(+), 13 deletions(-) diff --git a/ares-cli/src/orchestrator/tool_dispatcher/mod.rs b/ares-cli/src/orchestrator/tool_dispatcher/mod.rs index 5986df3f..13b71738 100644 --- a/ares-cli/src/orchestrator/tool_dispatcher/mod.rs +++ b/ares-cli/src/orchestrator/tool_dispatcher/mod.rs @@ -10,12 +10,13 @@ //! 
Also provides [`LocalToolDispatcher`] for in-process execution without //! going through NATS, useful for testing or single-binary deployments. +use redis::aio::ConnectionLike; use redis::AsyncCommands; use serde::{Deserialize, Serialize}; use tracing::debug; use crate::orchestrator::state::DISCOVERY_KEY_PREFIX; -use crate::orchestrator::task_queue::TaskQueue; +use crate::orchestrator::task_queue::TaskQueueCore; mod auth_throttle; mod local; @@ -144,13 +145,15 @@ pub(super) fn resolve_queue_role<'a>(role: &'a str, tool_name: &str) -> &'a str /// /// `tool_args` carries the tool call's input arguments — used to extract /// the authenticating credential (username/domain) for lineage tracking. -pub(super) async fn push_realtime_discoveries( - queue: &TaskQueue, +pub(super) async fn push_realtime_discoveries( + queue: &TaskQueueCore, operation_id: &str, discoveries: &serde_json::Value, tool_name: &str, tool_args: &serde_json::Value, -) { +) where + C: ConnectionLike + Clone + Send + Sync + 'static, +{ let discovery_key = format!("{DISCOVERY_KEY_PREFIX}:{operation_id}"); let mut conn = queue.connection(); diff --git a/ares-cli/src/orchestrator/tool_dispatcher/tests.rs b/ares-cli/src/orchestrator/tool_dispatcher/tests.rs index eeabb95a..7dcbda0f 100644 --- a/ares-cli/src/orchestrator/tool_dispatcher/tests.rs +++ b/ares-cli/src/orchestrator/tool_dispatcher/tests.rs @@ -1,4 +1,7 @@ use super::*; +use crate::orchestrator::task_queue::TaskQueueCore; +use ares_core::state::mock_redis::MockRedisConnection; +use redis::AsyncCommands; #[test] fn tool_exec_request_serialization() { @@ -96,3 +99,403 @@ fn cross_role_routing_recon_stays_recon() { "recon" ); } + +#[test] +fn extract_credential_key_returns_none_for_non_auth_tool() { + let call = ares_llm::ToolCall { + id: "1".into(), + name: "nmap_scan".into(), + arguments: serde_json::json!({"target": "192.168.58.0/24"}), + }; + assert!(extract_credential_key(&call).is_none()); +} + +#[test] +fn 
extract_credential_key_returns_none_when_username_missing() { + let call = ares_llm::ToolCall { + id: "1".into(), + name: "secretsdump".into(), + arguments: serde_json::json!({"target": "192.168.58.10"}), + }; + assert!(extract_credential_key(&call).is_none()); +} + +#[test] +fn extract_credential_key_lowercases_username_and_domain() { + let call = ares_llm::ToolCall { + id: "1".into(), + name: "password_spray".into(), + arguments: serde_json::json!({ + "username": "Administrator", + "domain": "CONTOSO.LOCAL", + "passwords": ["P@ss"] + }), + }; + let key = extract_credential_key(&call).expect("key extracted"); + assert_eq!(key, "administrator@contoso.local"); +} + +#[test] +fn extract_credential_key_uses_unknown_when_domain_missing() { + let call = ares_llm::ToolCall { + id: "1".into(), + name: "secretsdump".into(), + arguments: serde_json::json!({"username": "admin", "target": "10.0.0.1"}), + }; + let key = extract_credential_key(&call).expect("key extracted"); + assert_eq!(key, "admin@unknown"); +} + +#[test] +fn extract_credential_key_uses_unknown_when_domain_empty() { + let call = ares_llm::ToolCall { + id: "1".into(), + name: "kerberoast".into(), + arguments: serde_json::json!({"username": "user1", "domain": ""}), + }; + let key = extract_credential_key(&call).expect("key extracted"); + assert_eq!(key, "user1@unknown"); +} + +#[test] +fn extract_credential_key_recognizes_lateral_tools() { + for tool in ["smbexec", "psexec", "wmiexec", "dcomexec", "atexec"] { + let call = ares_llm::ToolCall { + id: "1".into(), + name: tool.into(), + arguments: serde_json::json!({"username": "u", "domain": "d", "target": "x"}), + }; + assert_eq!( + extract_credential_key(&call).as_deref(), + Some("u@d"), + "tool {tool} should be auth-bearing" + ); + } +} + +#[test] +fn extract_credential_key_recognizes_netexec_tools() { + for tool in [ + "ldap_search_descriptions", + "username_as_password", + "gpp_password_finder", + "sysvol_script_search", + "password_policy", + "laps_dump", + 
"smbclient_spider", + "check_credman_entries", + "check_autologon_registry", + "domain_admin_checker", + "gmsa_dump_passwords", + ] { + let call = ares_llm::ToolCall { + id: "1".into(), + name: tool.into(), + arguments: serde_json::json!({"username": "u", "domain": "d"}), + }; + assert_eq!( + extract_credential_key(&call).as_deref(), + Some("u@d"), + "tool {tool} should be auth-bearing" + ); + } +} + +#[test] +fn extract_credential_key_recognizes_impacket_tools() { + for tool in [ + "secretsdump", + "secretsdump_kerberos", + "kerberoast", + "asrep_roast", + "lsassy", + "ntds_dit_extract", + ] { + let call = ares_llm::ToolCall { + id: "1".into(), + name: tool.into(), + arguments: serde_json::json!({"username": "u", "domain": "d"}), + }; + assert_eq!( + extract_credential_key(&call).as_deref(), + Some("u@d"), + "tool {tool} should be auth-bearing" + ); + } +} + +#[test] +fn cross_role_routing_lateral_movement_stays() { + // Lateral tools should stay on the calling role + assert_eq!(resolve_queue_role("lateral", "smbexec"), "lateral"); + assert_eq!(resolve_queue_role("lateral", "psexec"), "lateral"); + assert_eq!(resolve_queue_role("lateral", "wmiexec"), "lateral"); +} + +#[test] +fn cross_role_routing_native_recon_tools() { + // Pure recon tools (not in RECON_ROUTED_TOOLS) stay on whatever role calls them. 
+ assert_eq!(resolve_queue_role("custom", "nmap_scan"), "custom"); + assert_eq!(resolve_queue_role("custom", "secretsdump"), "custom"); +} + +#[test] +fn tool_exec_request_omits_traceparent_when_none() { + let req = ToolExecRequest { + call_id: "c".into(), + task_id: "t".into(), + tool_name: "nmap_scan".into(), + arguments: serde_json::json!({}), + traceparent: None, + operation_id: None, + }; + let json = serde_json::to_string(&req).unwrap(); + assert!(!json.contains("traceparent")); + assert!(!json.contains("operation_id")); +} + +#[test] +fn tool_exec_request_includes_traceparent_when_some() { + let req = ToolExecRequest { + call_id: "c".into(), + task_id: "t".into(), + tool_name: "nmap_scan".into(), + arguments: serde_json::json!({}), + traceparent: Some("00-trace-span-01".into()), + operation_id: Some("op-123".into()), + }; + let json = serde_json::to_string(&req).unwrap(); + assert!(json.contains("traceparent")); + assert!(json.contains("00-trace-span-01")); + assert!(json.contains("op-123")); +} + +#[test] +fn tool_exec_response_with_discoveries_field() { + let json = r#"{ + "call_id":"c", + "output":"out", + "error":null, + "discoveries":{"hosts":[{"ip":"10.0.0.1"}]} + }"#; + let resp: ToolExecResponse = serde_json::from_str(json).unwrap(); + assert!(resp.discoveries.is_some()); + let disc = resp.discoveries.unwrap(); + assert_eq!(disc["hosts"][0]["ip"], "10.0.0.1"); +} + +#[test] +fn tool_exec_response_default_discoveries_none() { + // discoveries field is optional with #[serde(default)] + let json = r#"{"call_id":"c","output":"out","error":null}"#; + let resp: ToolExecResponse = serde_json::from_str(json).unwrap(); + assert!(resp.discoveries.is_none()); +} + +#[tokio::test] +async fn auth_throttle_allows_under_limit() { + let throttle = AuthThrottle::new(3, std::time::Duration::from_secs(60)); + let start = std::time::Instant::now(); + throttle.acquire("admin@contoso").await; + throttle.acquire("admin@contoso").await; + 
throttle.acquire("admin@contoso").await; + // All three within window — should not have slept + assert!(start.elapsed() < std::time::Duration::from_millis(500)); +} + +#[tokio::test] +async fn auth_throttle_separate_credentials_dont_interfere() { + let throttle = AuthThrottle::new(2, std::time::Duration::from_secs(60)); + let start = std::time::Instant::now(); + throttle.acquire("user1@d").await; + throttle.acquire("user1@d").await; + // user1 is at limit, but user2 should be free + throttle.acquire("user2@d").await; + throttle.acquire("user2@d").await; + assert!(start.elapsed() < std::time::Duration::from_millis(500)); +} + +#[tokio::test] +async fn auth_throttle_window_pruning_allows_more_after_expiry() { + // 2 attempts in a 100ms window + let throttle = AuthThrottle::new(2, std::time::Duration::from_millis(100)); + throttle.acquire("u@d").await; + throttle.acquire("u@d").await; + // Sleep past the window + tokio::time::sleep(std::time::Duration::from_millis(150)).await; + let start = std::time::Instant::now(); + // Old attempts pruned — this should not block + throttle.acquire("u@d").await; + assert!(start.elapsed() < std::time::Duration::from_millis(50)); +} + +fn mock_queue() -> TaskQueueCore { + TaskQueueCore::from_connection(MockRedisConnection::new()) +} + +#[tokio::test] +async fn push_realtime_discoveries_pushes_hosts() { + let q = mock_queue(); + let discoveries = serde_json::json!({ + "hosts": [ + {"ip": "192.168.58.10", "hostname": "dc01"}, + {"ip": "192.168.58.11", "hostname": "ws01"} + ] + }); + let args = serde_json::json!({}); + + push_realtime_discoveries(&q, "op-1", &discoveries, "nmap_scan", &args).await; + + let mut conn = q.connection(); + let key = format!("{DISCOVERY_KEY_PREFIX}:op-1"); + let entries: Vec = conn.lrange(&key, 0, -1).await.unwrap(); + assert_eq!(entries.len(), 2); + let parsed0: serde_json::Value = serde_json::from_str(&entries[0]).unwrap(); + assert_eq!(parsed0["type"], "host"); + assert_eq!(parsed0["source_tool"], 
"nmap_scan"); + assert!(parsed0["data"]["ip"].is_string()); +} + +#[tokio::test] +async fn push_realtime_discoveries_pushes_credentials_with_input_context() { + let q = mock_queue(); + let discoveries = serde_json::json!({ + "credentials": [ + {"username": "svc_admin", "password": "P@ss"} + ] + }); + let args = serde_json::json!({ + "username": "Administrator", + "domain": "contoso.local" + }); + + push_realtime_discoveries(&q, "op-2", &discoveries, "secretsdump", &args).await; + + let mut conn = q.connection(); + let key = format!("{DISCOVERY_KEY_PREFIX}:op-2"); + let entries: Vec = conn.lrange(&key, 0, -1).await.unwrap(); + assert_eq!(entries.len(), 1); + let parsed: serde_json::Value = serde_json::from_str(&entries[0]).unwrap(); + assert_eq!(parsed["type"], "credential"); + assert_eq!(parsed["input_username"], "Administrator"); + assert_eq!(parsed["input_domain"], "contoso.local"); +} + +#[tokio::test] +async fn push_realtime_discoveries_handles_multiple_types() { + let q = mock_queue(); + let discoveries = serde_json::json!({ + "hosts": [{"ip": "10.0.0.1"}], + "credentials": [{"username": "u"}], + "hashes": [{"hash": "aad3..."}], + "vulnerabilities": [{"id": "CVE-1"}], + "shares": [{"name": "C$"}], + "discovered_users": [{"username": "u2"}], + "trusted_domains": [{"name": "child.contoso"}] + }); + let args = serde_json::json!({}); + + push_realtime_discoveries(&q, "op-3", &discoveries, "tool", &args).await; + + let mut conn = q.connection(); + let key = format!("{DISCOVERY_KEY_PREFIX}:op-3"); + let entries: Vec = conn.lrange(&key, 0, -1).await.unwrap(); + assert_eq!(entries.len(), 7); + let types: Vec = entries + .iter() + .map(|e| { + serde_json::from_str::(e).unwrap()["type"] + .as_str() + .unwrap() + .to_string() + }) + .collect(); + for expected in [ + "host", + "credential", + "hash", + "vulnerability", + "share", + "user", + "trust", + ] { + assert!( + types.contains(&expected.to_string()), + "missing type: {expected}" + ); + } +} + +#[tokio::test] +async 
fn push_realtime_discoveries_skips_non_array_fields() { + let q = mock_queue(); + // hosts is a string instead of an array — should be skipped + let discoveries = serde_json::json!({ + "hosts": "not-an-array", + "credentials": [{"username": "u"}] + }); + let args = serde_json::json!({}); + + push_realtime_discoveries(&q, "op-4", &discoveries, "tool", &args).await; + + let mut conn = q.connection(); + let key = format!("{DISCOVERY_KEY_PREFIX}:op-4"); + let entries: Vec = conn.lrange(&key, 0, -1).await.unwrap(); + assert_eq!(entries.len(), 1); + let parsed: serde_json::Value = serde_json::from_str(&entries[0]).unwrap(); + assert_eq!(parsed["type"], "credential"); +} + +#[tokio::test] +async fn push_realtime_discoveries_no_input_context_when_args_lack_username() { + let q = mock_queue(); + let discoveries = serde_json::json!({ + "hosts": [{"ip": "10.0.0.1"}] + }); + let args = serde_json::json!({"target": "10.0.0.0/24"}); + + push_realtime_discoveries(&q, "op-5", &discoveries, "nmap_scan", &args).await; + + let mut conn = q.connection(); + let key = format!("{DISCOVERY_KEY_PREFIX}:op-5"); + let entries: Vec = conn.lrange(&key, 0, -1).await.unwrap(); + assert_eq!(entries.len(), 1); + let parsed: serde_json::Value = serde_json::from_str(&entries[0]).unwrap(); + assert!(parsed.get("input_username").is_none()); + assert!(parsed.get("input_domain").is_none()); +} + +#[tokio::test] +async fn push_realtime_discoveries_uses_user_alias_when_username_missing() { + let q = mock_queue(); + let discoveries = serde_json::json!({ + "hosts": [{"ip": "10.0.0.1"}] + }); + // Some tools call it "user" instead of "username" + let args = serde_json::json!({"user": "fallback_user", "domain": "d"}); + + push_realtime_discoveries(&q, "op-6", &discoveries, "tool", &args).await; + + let mut conn = q.connection(); + let key = format!("{DISCOVERY_KEY_PREFIX}:op-6"); + let entries: Vec = conn.lrange(&key, 0, -1).await.unwrap(); + let parsed: serde_json::Value = 
serde_json::from_str(&entries[0]).unwrap(); + assert_eq!(parsed["input_username"], "fallback_user"); + assert_eq!(parsed["input_domain"], "d"); +} + +#[tokio::test] +async fn push_realtime_discoveries_no_op_when_no_known_keys() { + let q = mock_queue(); + let discoveries = serde_json::json!({ + "unknown_field": [{"x": 1}] + }); + let args = serde_json::json!({}); + + push_realtime_discoveries(&q, "op-7", &discoveries, "tool", &args).await; + + let mut conn = q.connection(); + let key = format!("{DISCOVERY_KEY_PREFIX}:op-7"); + let exists: bool = conn.exists(&key).await.unwrap(); + assert!(!exists, "should not have created discovery list"); +} diff --git a/ares-cli/src/worker/task_loop/result_handler.rs b/ares-cli/src/worker/task_loop/result_handler.rs index d1643fcf..80e7f824 100644 --- a/ares-cli/src/worker/task_loop/result_handler.rs +++ b/ares-cli/src/worker/task_loop/result_handler.rs @@ -2,6 +2,7 @@ use bytes::Bytes; use chrono::Utc; +use redis::aio::ConnectionLike; use redis::AsyncCommands; use tracing::{debug, error, info, warn}; @@ -185,12 +186,15 @@ pub async fn process_task( } /// Set task status in Redis with TTL. 
-async fn set_task_status( - conn: &mut redis::aio::ConnectionManager, +async fn set_task_status( + conn: &mut C, task_id: &str, status: &str, extra_fields: &serde_json::Value, -) -> anyhow::Result<()> { +) -> anyhow::Result<()> +where + C: ConnectionLike + Send + Sync, +{ let key = format!("{TASK_STATUS_PREFIX}:{task_id}"); let mut data = extra_fields.clone(); if let Some(obj) = data.as_object_mut() { @@ -208,3 +212,63 @@ async fn set_task_status( .await?; Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use ares_core::state::mock_redis::MockRedisConnection; + + #[tokio::test] + async fn set_task_status_writes_status_and_timestamps() { + let mut conn = MockRedisConnection::new(); + let extra = serde_json::json!({ + "operation_id": "op-1", + "role": "recon", + "agent_name": "agent-0", + }); + set_task_status(&mut conn, "task-123", "running", &extra) + .await + .unwrap(); + + let raw: Option = conn.get("ares:task_status:task-123").await.unwrap(); + let raw = raw.expect("status written"); + let v: serde_json::Value = serde_json::from_str(&raw).unwrap(); + assert_eq!(v["status"], "running"); + assert_eq!(v["operation_id"], "op-1"); + assert_eq!(v["role"], "recon"); + assert!(v["updated_at"].is_string()); + } + + #[tokio::test] + async fn set_task_status_overwrites_status_field_in_extra() { + let mut conn = MockRedisConnection::new(); + // If extra has a "status" key, set_task_status overrides it + let extra = serde_json::json!({ + "status": "pending", + "task_type": "recon", + }); + set_task_status(&mut conn, "t-1", "completed", &extra) + .await + .unwrap(); + + let raw: Option = conn.get("ares:task_status:t-1").await.unwrap(); + let v: serde_json::Value = serde_json::from_str(&raw.unwrap()).unwrap(); + assert_eq!(v["status"], "completed"); + assert_eq!(v["task_type"], "recon"); + } + + #[tokio::test] + async fn set_task_status_handles_non_object_extra() { + let mut conn = MockRedisConnection::new(); + // If extra isn't an object, status/updated_at can't be 
merged but + // we should not panic — the value is serialized as-is. + let extra = serde_json::json!("not-an-object"); + set_task_status(&mut conn, "t-2", "running", &extra) + .await + .unwrap(); + + let raw: Option = conn.get("ares:task_status:t-2").await.unwrap(); + // Stored as the raw string, no merge happened + assert_eq!(raw.as_deref(), Some("\"not-an-object\"")); + } +} diff --git a/ares-core/src/state/blue_task_queue.rs b/ares-core/src/state/blue_task_queue.rs index b888eb1a..9b58f9b9 100644 --- a/ares-core/src/state/blue_task_queue.rs +++ b/ares-core/src/state/blue_task_queue.rs @@ -19,7 +19,7 @@ use anyhow::{Context, Result}; use bytes::Bytes; use chrono::Utc; use futures::StreamExt; -use redis::aio::ConnectionManagerConfig; +use redis::aio::{ConnectionLike, ConnectionManager, ConnectionManagerConfig}; use redis::AsyncCommands; use serde::{Deserialize, Serialize}; use tracing::{debug, info, warn}; @@ -86,11 +86,17 @@ impl BlueTaskResult { } /// Blue team task queue — NATS for queues, Redis for state. -pub struct BlueTaskQueue { - conn: redis::aio::ConnectionManager, +/// +/// Generic over the Redis backend so unit tests can use a mock; `nats` is +/// `None` in tests that don't exercise queue methods. +pub struct BlueTaskQueueCore { + conn: C, nats: Option, } +/// Production blue team task queue. +pub type BlueTaskQueue = BlueTaskQueueCore; + impl BlueTaskQueue { /// Connect to Redis only (state methods work, queue methods will error). /// Used by callers that only need heartbeat/investigation registration. 
@@ -111,20 +117,28 @@ impl BlueTaskQueue { Ok(q) } - pub fn from_conn(conn: redis::aio::ConnectionManager) -> Self { + pub fn from_conn(conn: ConnectionManager) -> Self { Self { conn, nats: None } } - pub fn from_parts(conn: redis::aio::ConnectionManager, nats: NatsBroker) -> Self { + pub fn from_parts(conn: ConnectionManager, nats: NatsBroker) -> Self { Self { conn, nats: Some(nats), } } - pub fn conn_mut(&mut self) -> &mut redis::aio::ConnectionManager { + pub fn conn_mut(&mut self) -> &mut ConnectionManager { &mut self.conn } +} + +impl BlueTaskQueueCore { + /// Construct from a Redis backend only — used by unit tests that don't + /// exercise queue methods. Queue methods will return an error. + pub fn from_connection(conn: C) -> Self { + Self { conn, nats: None } + } fn nats(&self) -> Result<&NatsBroker> { self.nats @@ -458,6 +472,11 @@ impl BlueTaskQueue { #[cfg(test)] mod tests { use super::*; + use crate::state::mock_redis::MockRedisConnection; + + fn mock_queue() -> BlueTaskQueueCore { + BlueTaskQueueCore::from_connection(MockRedisConnection::new()) + } #[test] fn success_sets_success_true_and_stores_result() { @@ -497,4 +516,201 @@ mod tests { assert!(chrono::DateTime::parse_from_rfc3339(&success.completed_at).is_ok()); assert!(chrono::DateTime::parse_from_rfc3339(&failure.completed_at).is_ok()); } + + #[test] + fn blue_task_message_serialization_roundtrip() { + let msg = BlueTaskMessage { + task_id: "btask-1".into(), + investigation_id: "inv-1".into(), + task_type: "log_search".into(), + role: "triage".into(), + params: serde_json::json!({"query": "alertname=Foo"}), + created_at: "2026-04-29T20:00:00Z".into(), + }; + let json = serde_json::to_string(&msg).unwrap(); + let parsed: BlueTaskMessage = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.task_id, "btask-1"); + assert_eq!(parsed.investigation_id, "inv-1"); + assert_eq!(parsed.role, "triage"); + assert_eq!(parsed.params["query"], "alertname=Foo"); + } + + #[test] + fn 
blue_task_result_skips_none_fields_in_serialization() { + let r = BlueTaskResult::success("t", "i", serde_json::json!({"ok": true}), "a"); + let json = serde_json::to_string(&r).unwrap(); + // error is None and has skip_serializing_if + assert!(!json.contains("\"error\"")); + } + + #[test] + fn blue_task_result_failure_omits_result_field() { + let r = BlueTaskResult::failure("t", "i", "boom".into(), "a"); + let json = serde_json::to_string(&r).unwrap(); + assert!(!json.contains("\"result\"")); + assert!(json.contains("\"error\"")); + } + + #[tokio::test] + async fn send_heartbeat_roundtrip() { + let mut q = mock_queue(); + q.send_heartbeat("blue-agent-1", "idle", None, "triage", None) + .await + .unwrap(); + + let hb = q + .get_heartbeat("blue-agent-1") + .await + .unwrap() + .expect("heartbeat present"); + assert_eq!(hb["status"], "idle"); + assert_eq!(hb["role"], "triage"); + assert!(hb["current_task"].is_null()); + assert!(hb["timestamp"].is_string()); + } + + #[tokio::test] + async fn send_heartbeat_with_current_task_and_investigation() { + let mut q = mock_queue(); + q.send_heartbeat( + "blue-agent-2", + "busy", + Some("btask-9"), + "log_analyst", + Some("inv-42"), + ) + .await + .unwrap(); + + let hb = q.get_heartbeat("blue-agent-2").await.unwrap().unwrap(); + assert_eq!(hb["status"], "busy"); + assert_eq!(hb["current_task"], "btask-9"); + assert_eq!(hb["investigation_id"], "inv-42"); + } + + #[tokio::test] + async fn get_heartbeat_returns_none_when_missing() { + let mut q = mock_queue(); + assert!(q.get_heartbeat("ghost").await.unwrap().is_none()); + } + + #[tokio::test] + async fn register_investigation_then_discover() { + let mut q = mock_queue(); + let alert = serde_json::json!({"alertname": "SuspiciousLogon"}); + q.register_investigation("inv-100", &alert, "openai/gpt-4.1-mini") + .await + .unwrap(); + + let active = q.discover_active_investigation().await.unwrap(); + assert_eq!(active.as_deref(), Some("inv-100")); + } + + #[tokio::test] + async fn 
discover_active_investigation_returns_none_when_empty() { + let mut q = mock_queue(); + assert!(q.discover_active_investigation().await.unwrap().is_none()); + } + + #[tokio::test] + async fn get_investigation_alert_returns_registered_alert() { + let mut q = mock_queue(); + let alert = serde_json::json!({ + "alertname": "FailedLogons", + "severity": "high", + }); + q.register_investigation("inv-200", &alert, "openai/gpt-4.1-mini") + .await + .unwrap(); + + let stored = q.get_investigation_alert("inv-200").await.unwrap().unwrap(); + assert_eq!(stored["alertname"], "FailedLogons"); + assert_eq!(stored["severity"], "high"); + } + + #[tokio::test] + async fn get_investigation_alert_returns_none_for_unknown_id() { + let mut q = mock_queue(); + assert!(q + .get_investigation_alert("nonexistent") + .await + .unwrap() + .is_none()); + } + + #[tokio::test] + async fn get_investigation_model_returns_registered_model() { + let mut q = mock_queue(); + q.register_investigation( + "inv-300", + &serde_json::json!({}), + "anthropic/claude-sonnet-4-5", + ) + .await + .unwrap(); + + let model = q.get_investigation_model("inv-300").await.unwrap(); + assert_eq!(model.as_deref(), Some("anthropic/claude-sonnet-4-5")); + } + + #[tokio::test] + async fn get_investigation_model_returns_none_for_unknown_id() { + let mut q = mock_queue(); + assert!(q + .get_investigation_model("nonexistent") + .await + .unwrap() + .is_none()); + } + + #[tokio::test] + async fn submit_task_errors_when_no_nats_configured() { + let mut q = mock_queue(); + let task = BlueTaskMessage { + task_id: "t".into(), + investigation_id: "i".into(), + task_type: "log_search".into(), + role: "triage".into(), + params: serde_json::json!({}), + created_at: "2026-04-29T00:00:00Z".into(), + }; + let err = q.submit_task(&task).await.unwrap_err(); + assert!(err.to_string().contains("NATS")); + } + + #[tokio::test] + async fn poll_global_task_errors_when_no_nats_configured() { + let mut q = mock_queue(); + let err = 
q.poll_global_task("triage", 1.0).await.unwrap_err(); + assert!(err.to_string().contains("NATS")); + } + + #[tokio::test] + async fn send_result_errors_when_no_nats_configured() { + let mut q = mock_queue(); + let r = BlueTaskResult::success("t", "i", serde_json::Value::Null, "a"); + let err = q.send_result(&r).await.unwrap_err(); + assert!(err.to_string().contains("NATS")); + } + + #[tokio::test] + async fn check_result_errors_when_no_nats_configured() { + let mut q = mock_queue(); + let err = q.check_result("t").await.unwrap_err(); + assert!(err.to_string().contains("NATS")); + } + + #[tokio::test] + async fn pop_investigation_request_errors_when_no_nats_configured() { + let mut q = mock_queue(); + let err = q.pop_investigation_request(1.0).await.unwrap_err(); + assert!(err.to_string().contains("NATS")); + } + + #[tokio::test] + async fn queue_length_errors_when_no_nats_configured() { + let mut q = mock_queue(); + let err = q.queue_length("triage").await.unwrap_err(); + assert!(err.to_string().contains("NATS")); + } } From 0398259bc1e6ae1b7df948256b21d5dc785e1bf1 Mon Sep 17 00:00:00 2001 From: Jayson Grace Date: Wed, 29 Apr 2026 16:45:42 -0600 Subject: [PATCH 3/5] test: add comprehensive unit tests for orchestrator and worker modules **Added:** - Added extensive unit tests for `task_queue` covering result checks, batch error handling, lock extension, status management, and serialization behaviors - Added tests for `redis_dispatcher.rs` helper functions, including dispatch error and timeout result formatting, and subject/stream configuration - Added tests for `is_transient_broker_error` logic, task status TTL, message priority overrides, and task result serialization in `task_loop` - Added tests for `tool_executor` helpers: unavailable tool responses, error classification, and discoveries serialization logic - Added tests for NATS subject and stream formatting, retention, and uniqueness in the core NATS module, including environment variable fallback handling 
**Changed:** - Refactored tool dispatcher to use helper functions for error and timeout result construction, ensuring consistent formatting and easier testability - Replaced inline connection error detection in worker task loop with `is_transient_broker_error` helper for improved maintainability - Refactored tool executor to use helper functions for unavailable tool responses and discoveries serialization, improving clarity and test coverage **Removed:** - Removed inline duplicate logic for error result and discoveries handling in tool dispatcher and tool executor, consolidating into reusable functions --- ares-cli/src/orchestrator/task_queue.rs | 189 ++++++++++++++++++ .../tool_dispatcher/redis_dispatcher.rs | 43 ++-- .../src/orchestrator/tool_dispatcher/tests.rs | 57 ++++++ ares-cli/src/worker/task_loop/mod.rs | 118 +++++++++-- ares-cli/src/worker/tool_executor.rs | 124 ++++++++++-- ares-core/src/nats.rs | 128 ++++++++++++ 6 files changed, 615 insertions(+), 44 deletions(-) diff --git a/ares-cli/src/orchestrator/task_queue.rs b/ares-cli/src/orchestrator/task_queue.rs index 2982072c..9d61539a 100644 --- a/ares-cli/src/orchestrator/task_queue.rs +++ b/ares-cli/src/orchestrator/task_queue.rs @@ -718,4 +718,193 @@ mod tests { .unwrap_err(); assert!(err.to_string().contains("NATS")); } + + #[tokio::test] + async fn has_pending_result_always_false() { + // Documented "always returns false" semantic kept for API compat with + // the old Redis implementation. 
+ let q = mock_queue(); + for tid in ["t1", "t2", "anything"] { + assert!(!q.has_pending_result(tid).await.unwrap()); + } + } + + #[tokio::test] + async fn check_result_errors_without_nats() { + let q = mock_queue(); + let err = q.check_result("t1").await.unwrap_err(); + assert!(err.to_string().contains("NATS")); + } + + #[tokio::test] + async fn check_results_batch_empty_returns_empty_map() { + let q = mock_queue(); + let map = q.check_results_batch(&[]).await.unwrap(); + assert!(map.is_empty()); + } + + #[tokio::test] + async fn check_results_batch_swallows_per_task_errors() { + // Without NATS, each per-task fetch errors. The batch method logs + // and treats those as None rather than propagating. + let q = mock_queue(); + let ids = vec!["t1".to_string(), "t2".to_string(), "t3".to_string()]; + let map = q.check_results_batch(&ids).await.unwrap(); + assert_eq!(map.len(), 3); + for id in &ids { + assert!(map.contains_key(id)); + assert!(map.get(id).unwrap().is_none()); + } + } + + #[tokio::test] + async fn send_result_errors_without_nats() { + let q = mock_queue(); + let r = TaskResult { + task_id: "t1".into(), + success: true, + result: None, + error: None, + completed_at: Some(Utc::now()), + worker_pod: None, + agent_name: None, + }; + let err = q.send_result("t1", &r).await.unwrap_err(); + assert!(err.to_string().contains("NATS")); + } + + #[tokio::test] + async fn publish_state_update_errors_without_nats() { + let q = mock_queue(); + let err = q.publish_state_update("op-1").await.unwrap_err(); + assert!(err.to_string().contains("NATS")); + } + + #[tokio::test] + async fn nats_broker_is_none_for_mock_queue() { + let q = mock_queue(); + assert!(q.nats_broker().is_none()); + } + + #[tokio::test] + async fn connection_returns_independent_clone() { + // The connection() accessor should hand back a clone the caller can + // hold without invalidating the queue's own conn. 
+ let q = mock_queue(); + let mut c = q.connection(); + let _: () = c.set_ex::<_, _, ()>("x", "y", 30).await.unwrap(); + // queue still works after caller used the cloned conn + q.set_task_status("after", "pending").await.unwrap(); + let raw = q.get_task_status("after").await.unwrap().unwrap(); + let v: serde_json::Value = serde_json::from_str(&raw).unwrap(); + assert_eq!(v["status"], "pending"); + } + + #[tokio::test] + async fn set_task_status_pending_does_not_set_started_or_ended() { + let q = mock_queue(); + q.set_task_status("t1", "pending").await.unwrap(); + let raw = q.get_task_status("t1").await.unwrap().unwrap(); + let v: serde_json::Value = serde_json::from_str(&raw).unwrap(); + assert_eq!(v["status"], "pending"); + assert!(v.get("started_at").is_none()); + assert!(v.get("ended_at").is_none()); + } + + #[tokio::test] + async fn set_task_status_in_progress_does_not_overwrite_started_at() { + let q = mock_queue(); + // First in_progress sets started_at + q.set_task_status("t1", "in_progress").await.unwrap(); + let raw1 = q.get_task_status("t1").await.unwrap().unwrap(); + let v1: serde_json::Value = serde_json::from_str(&raw1).unwrap(); + let started_first = v1["started_at"].as_str().unwrap().to_string(); + + // sleep briefly so timestamps would differ + tokio::time::sleep(Duration::from_millis(20)).await; + + // Second in_progress preserves the original started_at + q.set_task_status("t1", "in_progress").await.unwrap(); + let raw2 = q.get_task_status("t1").await.unwrap().unwrap(); + let v2: serde_json::Value = serde_json::from_str(&raw2).unwrap(); + assert_eq!(v2["started_at"].as_str().unwrap(), started_first); + assert_ne!(v1["updated_at"], v2["updated_at"]); + } + + #[tokio::test] + async fn set_task_status_full_without_payload_omits_payload_field() { + let q = mock_queue(); + q.set_task_status_full("t1", "pending", "op-1", "scanner", "recon", None) + .await + .unwrap(); + let raw = q.get_task_status("t1").await.unwrap().unwrap(); + let v: 
serde_json::Value = serde_json::from_str(&raw).unwrap(); + assert!(v.get("payload").is_none()); + assert!(v.get("started_at").is_none()); // pending != in_progress + } + + #[tokio::test] + async fn extend_lock_against_mock_redis_succeeds() { + // Mock EXPIRE always reports success; this test pins the call shape + // (i64 TTL conversion, Result return type). + let q = mock_queue(); + let ok = q + .extend_lock("op-1", Duration::from_secs(60)) + .await + .unwrap(); + assert!(ok); + } + + #[tokio::test] + async fn try_acquire_lock_uses_separate_keys_per_operation() { + let q = mock_queue(); + assert!(q + .try_acquire_lock("op-a", Duration::from_secs(30)) + .await + .unwrap()); + // Different op id is independent of op-a + assert!(q + .try_acquire_lock("op-b", Duration::from_secs(30)) + .await + .unwrap()); + } + + #[test] + fn task_message_default_priority_in_constants() { + assert_eq!(default_priority(), 5); + } + + #[test] + fn task_message_serialize_includes_callback_queue() { + let msg = TaskMessage { + task_id: "t".into(), + task_type: "recon".into(), + source_agent: "orch".into(), + target_agent: "scanner".into(), + payload: serde_json::json!({}), + priority: 5, + created_at: None, + callback_queue: Some("ares.tasks.results.t".to_string()), + }; + let json = serde_json::to_string(&msg).unwrap(); + assert!(json.contains("ares.tasks.results.t")); + } + + #[test] + fn task_result_serializes_none_fields_as_null() { + let r = TaskResult { + task_id: "t".into(), + success: true, + result: None, + error: None, + completed_at: None, + worker_pod: None, + agent_name: None, + }; + let v: serde_json::Value = serde_json::to_value(&r).unwrap(); + assert!(v["result"].is_null()); + assert!(v["error"].is_null()); + assert!(v["worker_pod"].is_null()); + assert!(v["agent_name"].is_null()); + } } diff --git a/ares-cli/src/orchestrator/tool_dispatcher/redis_dispatcher.rs b/ares-cli/src/orchestrator/tool_dispatcher/redis_dispatcher.rs index 96db2961..377123d9 100644 --- 
a/ares-cli/src/orchestrator/tool_dispatcher/redis_dispatcher.rs +++ b/ares-cli/src/orchestrator/tool_dispatcher/redis_dispatcher.rs @@ -49,6 +49,33 @@ impl RedisToolDispatcher { } } +/// Synthetic ToolExecResult returned when the NATS request itself fails +/// (broker disconnect, no responders, etc.). Free function so the wording +/// is testable and stays in lock-step with the agent-facing error message. +pub(super) fn dispatch_error_result( + tool_name: &str, + err: impl std::fmt::Display, +) -> ToolExecResult { + ToolExecResult { + output: String::new(), + error: Some(format!("Tool '{tool_name}' dispatch error: {err}")), + discoveries: None, + } +} + +/// Synthetic ToolExecResult returned when the request times out waiting for +/// a worker to reply. +pub(super) fn dispatch_timeout_result(tool_name: &str, timeout: Duration) -> ToolExecResult { + ToolExecResult { + output: String::new(), + error: Some(format!( + "Tool '{tool_name}' timed out after {}s", + timeout.as_secs() + )), + discoveries: None, + } +} + #[async_trait::async_trait] impl ares_llm::ToolDispatcher for RedisToolDispatcher { async fn dispatch_tool( @@ -118,11 +145,7 @@ impl ares_llm::ToolDispatcher for RedisToolDispatcher { err = %e, "NATS request failed" ); - return Ok(ToolExecResult { - output: String::new(), - error: Some(format!("Tool '{}' dispatch error: {e}", call.name)), - discoveries: None, - }); + return Ok(dispatch_error_result(&call.name, e)); } Err(_) => { warn!( @@ -131,15 +154,7 @@ impl ares_llm::ToolDispatcher for RedisToolDispatcher { timeout_secs = timeout.as_secs(), "Tool execution timed out" ); - return Ok(ToolExecResult { - output: String::new(), - error: Some(format!( - "Tool '{}' timed out after {}s", - call.name, - timeout.as_secs() - )), - discoveries: None, - }); + return Ok(dispatch_timeout_result(&call.name, timeout)); } }; diff --git a/ares-cli/src/orchestrator/tool_dispatcher/tests.rs b/ares-cli/src/orchestrator/tool_dispatcher/tests.rs index 7dcbda0f..a81f65b8 100644 
--- a/ares-cli/src/orchestrator/tool_dispatcher/tests.rs +++ b/ares-cli/src/orchestrator/tool_dispatcher/tests.rs @@ -499,3 +499,60 @@ async fn push_realtime_discoveries_no_op_when_no_known_keys() { let exists: bool = conn.exists(&key).await.unwrap(); assert!(!exists, "should not have created discovery list"); } + +#[test] +fn dispatch_error_result_includes_tool_name_and_underlying_error() { + use redis_dispatcher::dispatch_error_result; + let r = dispatch_error_result("nmap_scan", "no responders available"); + assert_eq!(r.output, ""); + assert!(r.discoveries.is_none()); + let err = r.error.as_deref().unwrap(); + assert!(err.contains("nmap_scan"), "missing tool name in {err}"); + assert!(err.contains("dispatch error")); + assert!(err.contains("no responders available")); +} + +#[test] +fn dispatch_error_result_handles_anyhow_errors() { + use redis_dispatcher::dispatch_error_result; + let upstream = anyhow::anyhow!("upstream broken pipe"); + let r = dispatch_error_result("certipy", upstream); + assert!(r.error.unwrap().contains("upstream broken pipe")); +} + +#[test] +fn dispatch_timeout_result_renders_seconds() { + use redis_dispatcher::dispatch_timeout_result; + let r = dispatch_timeout_result("hashcat", std::time::Duration::from_secs(1500)); + assert_eq!(r.output, ""); + assert!(r.discoveries.is_none()); + let err = r.error.as_deref().unwrap(); + assert!(err.contains("hashcat")); + assert!(err.contains("1500s")); + assert!(err.contains("timed out")); +} + +#[test] +fn dispatch_timeout_result_zero_seconds_still_well_formed() { + use redis_dispatcher::dispatch_timeout_result; + let r = dispatch_timeout_result("nmap", std::time::Duration::from_secs(0)); + assert!(r.error.unwrap().contains("0s")); +} + +#[test] +fn default_tool_timeout_is_25_minutes() { + // 1500s = 25min — must exceed worst-case hashcat queue + run time. 
+ assert_eq!(DEFAULT_TOOL_TIMEOUT_SECS, 25 * 60); +} + +#[test] +fn dispatch_error_and_timeout_results_share_shape() { + // Both helpers must produce the same shape so the agent loop can treat + // them uniformly: empty output, no discoveries, non-empty error. + use redis_dispatcher::{dispatch_error_result, dispatch_timeout_result}; + let e = dispatch_error_result("t", "oops"); + let t = dispatch_timeout_result("t", std::time::Duration::from_secs(60)); + assert_eq!(e.output, t.output); + assert!(e.discoveries.is_none() && t.discoveries.is_none()); + assert!(e.error.is_some() && t.error.is_some()); +} diff --git a/ares-cli/src/worker/task_loop/mod.rs b/ares-cli/src/worker/task_loop/mod.rs index 163f1202..d9637f25 100644 --- a/ares-cli/src/worker/task_loop/mod.rs +++ b/ares-cli/src/worker/task_loop/mod.rs @@ -94,18 +94,7 @@ pub async fn run_task_loop( retry_delay = Duration::from_secs(1); } Err(e) => { - let error_str = e.to_string().to_lowercase(); - let is_conn_error = [ - "connection", - "connect", - "closed", - "timeout", - "broken pipe", - "reset", - "no responders", - ] - .iter() - .any(|kw| error_str.contains(kw)); + let is_conn_error = is_transient_broker_error(&e.to_string()); if is_conn_error { warn!( @@ -137,6 +126,23 @@ pub(crate) const fn task_status_ttl() -> i64 { TASK_STATUS_TTL } +/// Classify a broker-side error as a transient connectivity/timeout failure +/// (worth retrying with backoff) versus a logic error (worth surfacing fast). +fn is_transient_broker_error(err: &str) -> bool { + let lower = err.to_lowercase(); + [ + "connection", + "connect", + "closed", + "timeout", + "broken pipe", + "reset", + "no responders", + ] + .iter() + .any(|kw| lower.contains(kw)) +} + /// Ensure a durable pull consumer exists for the given (role, urgency). 
async fn ensure_role_consumer( nats: &NatsBroker, @@ -275,4 +281,92 @@ mod tests { let json = serde_json::to_string(&r).unwrap(); assert!(!json.contains("\"error\"")); } + + #[test] + fn is_transient_broker_error_recognizes_connection_terms() { + for kw in [ + "connection refused", + "Connection reset by peer", + "broken pipe", + "request timeout", + "stream closed unexpectedly", + "no responders available", + "Failed to connect to NATS", + ] { + assert!( + is_transient_broker_error(kw), + "expected {kw:?} to be classified as transient" + ); + } + } + + #[test] + fn is_transient_broker_error_rejects_logic_errors() { + for kw in [ + "deserialize TaskMessage: missing field", + "JetStream consumer not found", + "stream ARES_TASKS does not exist", + "permission denied: not authorized", + "ack returned NACK", + ] { + // None of these contain the transient keywords. + assert!( + !is_transient_broker_error(kw), + "expected {kw:?} to be classified as non-transient" + ); + } + } + + #[test] + fn is_transient_broker_error_is_case_insensitive() { + assert!(is_transient_broker_error("BROKEN PIPE")); + assert!(is_transient_broker_error("Timeout while waiting")); + assert!(is_transient_broker_error("No Responders")); + } + + #[test] + fn task_status_ttl_is_24_hours() { + assert_eq!(task_status_ttl(), 60 * 60 * 24); + } + + #[test] + fn task_message_with_explicit_priority_overrides_default() { + let json = r#"{ + "task_id": "t1", + "task_type": "recon", + "source_agent": "orch", + "target_agent": "recon-0", + "payload": {}, + "priority": 1 + }"#; + let msg: TaskMessage = serde_json::from_str(json).unwrap(); + assert_eq!(msg.priority, 1); + } + + #[test] + fn task_result_success_carries_completed_at() { + let r = TaskResult::success( + "t1", + serde_json::json!({"output": "done"}), + "pod-0", + "ares-recon", + ); + let parsed_at = chrono::DateTime::parse_from_rfc3339(r.completed_at.as_deref().unwrap()); + assert!(parsed_at.is_ok()); + } + + #[test] + fn 
task_result_failure_with_partial_output() { + let partial = serde_json::json!({"partial_output": "ran 3/5 steps"}); + let r = TaskResult::failure( + "t1", + "agent crashed".into(), + Some(partial.clone()), + "pod-0", + "ares-recon", + ); + assert!(!r.success); + assert_eq!(r.error.as_deref(), Some("agent crashed")); + assert_eq!(r.result, Some(partial)); + } } diff --git a/ares-cli/src/worker/tool_executor.rs b/ares-cli/src/worker/tool_executor.rs index 09111d85..8e4f6e31 100644 --- a/ares-cli/src/worker/tool_executor.rs +++ b/ares-cli/src/worker/tool_executor.rs @@ -157,6 +157,37 @@ pub async fn run_tool_exec_loop( } } +/// Build the error response sent when a tool was previously found to be +/// unavailable on this worker (binary missing). Surfaced as a free function +/// so the wording stays in lock-step with tests. +fn unavailable_tool_response(tool_name: &str, call_id: &str) -> ToolExecResponse { + ToolExecResponse { + call_id: call_id.to_string(), + output: String::new(), + error: Some(format!( + "Tool '{tool_name}' is not installed on this worker. \ + Do not call this tool again — it failed to spawn previously." + )), + discoveries: None, + } +} + +/// Tool execution failures that indicate the binary is not present should +/// be marked unavailable so we don't keep retrying it. +fn is_tool_unavailable_error(err_str: &str) -> bool { + err_str.contains("failed to spawn") || err_str.contains("not installed") +} + +/// Convert a parsed-discoveries value into `Some(_)` only when it carries +/// at least one entry — avoids serialising an empty `discoveries: {}` blob. +fn discoveries_or_none(parsed: serde_json::Value) -> Option<serde_json::Value> { + if parsed.as_object().is_none_or(|o| o.is_empty()) { + None + } else { + Some(parsed) + } +} + +/// Execute a tool call and reply on the NATS inbox. 
async fn execute_and_respond( client: async_nats::Client, @@ -170,16 +201,7 @@ async fn execute_and_respond( call_id = %request.call_id, "Skipping unavailable tool (previously failed to spawn)" ); - let response = ToolExecResponse { - call_id: request.call_id.clone(), - output: String::new(), - error: Some(format!( - "Tool '{}' is not installed on this worker. \ - Do not call this tool again — it failed to spawn previously.", - request.tool_name - )), - discoveries: None, - }; + let response = unavailable_tool_response(&request.tool_name, &request.call_id); send_reply(&client, reply_to.as_ref(), &response).await; return; } @@ -204,16 +226,11 @@ async fn execute_and_respond( Some(format!("tool exited with code {:?}", output.exit_code)) }; - let discoveries = ares_tools::parsers::parse_tool_output( + let discoveries = discoveries_or_none(ares_tools::parsers::parse_tool_output( &request.tool_name, &raw, &request.arguments, - ); - let discoveries = if discoveries.as_object().is_none_or(|o| o.is_empty()) { - None - } else { - Some(discoveries) - }; + )); if let Some(ref disc) = discoveries { if let Some(obj) = disc.as_object() { @@ -245,7 +262,7 @@ async fn execute_and_respond( } Err(e) => { let err_str = e.to_string(); - if err_str.contains("failed to spawn") || err_str.contains("not installed") { + if is_tool_unavailable_error(&err_str) { warn!( tool = %request.tool_name, "Tool binary not found — marking as unavailable for this session" @@ -543,4 +560,75 @@ mod tests { let result: Result = serde_json::from_str(json); assert!(result.is_err()); } + + #[test] + fn unavailable_tool_response_contains_tool_name() { + let resp = unavailable_tool_response("certipy", "call_42"); + assert_eq!(resp.call_id, "call_42"); + assert_eq!(resp.output, ""); + assert!(resp.discoveries.is_none()); + let err = resp.error.as_deref().unwrap(); + assert!(err.contains("certipy")); + assert!(err.contains("not installed")); + assert!(err.contains("Do not call this tool again")); + } + + #[test] 
+ fn unavailable_tool_response_round_trips_via_json() { + let resp = unavailable_tool_response("hashcat", "abc"); + let json = serde_json::to_string(&resp).unwrap(); + // discoveries omitted when None + assert!(!json.contains("discoveries")); + assert!(json.contains("hashcat")); + } + + #[test] + fn is_tool_unavailable_error_classifies_spawn_failures() { + assert!(is_tool_unavailable_error( + "failed to spawn 'nmap' — is it installed?" + )); + assert!(is_tool_unavailable_error("tool not installed: certipy")); + assert!(is_tool_unavailable_error( + "failed to spawn process: No such file" + )); + } + + #[test] + fn is_tool_unavailable_error_rejects_unrelated_errors() { + assert!(!is_tool_unavailable_error("connection refused")); + assert!(!is_tool_unavailable_error("permission denied")); + assert!(!is_tool_unavailable_error("invalid arguments")); + assert!(!is_tool_unavailable_error("command not found")); // different wording + } + + #[test] + fn discoveries_or_none_drops_empty_object() { + let v = serde_json::json!({}); + assert!(discoveries_or_none(v).is_none()); + } + + #[test] + fn discoveries_or_none_drops_non_object() { + // Arrays / strings / numbers should all be treated as "no discoveries" + assert!(discoveries_or_none(serde_json::json!(null)).is_none()); + assert!(discoveries_or_none(serde_json::json!([])).is_none()); + assert!(discoveries_or_none(serde_json::json!("hi")).is_none()); + assert!(discoveries_or_none(serde_json::json!(42)).is_none()); + } + + #[test] + fn discoveries_or_none_keeps_non_empty_object() { + let v = serde_json::json!({"hosts": [{"ip": "10.0.0.1"}]}); + let kept = discoveries_or_none(v.clone()); + assert!(kept.is_some()); + assert_eq!(kept.unwrap(), v); + } + + #[test] + fn discoveries_or_none_keeps_empty_array_inside_object() { + // Object with even an empty array is still non-empty at the top level + let v = serde_json::json!({"credentials": []}); + let kept = discoveries_or_none(v.clone()); + assert_eq!(kept, Some(v)); + } } diff 
--git a/ares-core/src/nats.rs b/ares-core/src/nats.rs index c3107339..559b2aaa 100644 --- a/ares-core/src/nats.rs +++ b/ares-core/src/nats.rs @@ -366,4 +366,132 @@ mod tests { let spec = StreamSpec::discoveries(); assert_eq!(spec.subjects, vec!["ares.discoveries.>"]); } + + #[test] + fn urgent_task_subject_format() { + assert_eq!(urgent_task_subject("recon"), "ares.tasks.urgent.recon"); + assert_eq!(urgent_task_subject("lateral"), "ares.tasks.urgent.lateral"); + } + + #[test] + fn task_result_subject_format() { + assert_eq!( + task_result_subject("recon_abc123"), + "ares.tasks.results.recon_abc123" + ); + } + + #[test] + fn blue_task_result_subject_format() { + assert_eq!( + blue_task_result_subject("btask_abc"), + "ares.blue.tasks.results.btask_abc" + ); + } + + #[test] + fn subject_prefixes_are_unique() { + // Sanity check that the subject namespaces don't overlap, which would + // cause cross-stream collisions. + let prefixes = [ + TASK_SUBJECT_PREFIX, + TOOL_EXEC_SUBJECT_PREFIX, + BLUE_TASK_SUBJECT_PREFIX, + DEFERRED_SUBJECT_PREFIX, + STATE_UPDATE_SUBJECT_PREFIX, + DISCOVERY_SUBJECT_PREFIX, + ]; + for (i, p1) in prefixes.iter().enumerate() { + for p2 in &prefixes[i + 1..] 
{ + assert!( + !p1.starts_with(p2) && !p2.starts_with(p1), + "subject prefixes {p1} and {p2} overlap" + ); + } + } + } + + #[test] + fn tasks_stream_uses_work_queue_retention_and_file_storage() { + let spec = StreamSpec::tasks(); + let cfg = spec.to_config(); + assert_eq!(cfg.name, "ARES_TASKS"); + assert!(matches!(cfg.retention, RetentionPolicy::WorkQueue)); + assert!(matches!(cfg.storage, StorageType::File)); + // 24h retention + assert_eq!(cfg.max_age, Duration::from_secs(60 * 60 * 24)); + } + + #[test] + fn blue_tasks_stream_to_config_carries_subjects() { + let cfg = StreamSpec::blue_tasks().to_config(); + assert_eq!(cfg.name, "ARES_BLUE_TASKS"); + assert!(cfg.subjects.iter().any(|s| s == BLUE_INVESTIGATION_SUBJECT)); + assert!(cfg.subjects.iter().any(|s| s == "ares.blue.tasks.>")); + assert!(matches!(cfg.retention, RetentionPolicy::WorkQueue)); + } + + #[test] + fn deferred_stream_max_age_is_six_hours() { + let spec = StreamSpec::deferred(); + assert_eq!(spec.max_age, Duration::from_secs(60 * 60 * 6)); + assert!(matches!(spec.storage, StorageType::File)); + } + + #[test] + fn discoveries_stream_max_age_is_twelve_hours() { + let spec = StreamSpec::discoveries(); + assert_eq!(spec.max_age, Duration::from_secs(60 * 60 * 12)); + assert!(matches!(spec.storage, StorageType::File)); + } + + #[test] + fn url_from_env_falls_back_to_default_when_unset() { + // We can't safely toggle process-wide env vars in parallel tests, so + // this only asserts that the function returns a non-empty URL string. 
+ let url = NatsBroker::url_from_env(); + assert!(!url.is_empty()); + // Default contains nats:// scheme + if std::env::var("ARES_NATS_URL").is_err() && std::env::var("NATS_URL").is_err() { + assert_eq!(url, DEFAULT_NATS_URL); + assert!(url.starts_with("nats://")); + } + } + + #[test] + fn task_subject_distinguishes_urgent_from_normal() { + let normal = task_subject("recon"); + let urgent = urgent_task_subject("recon"); + assert_ne!(normal, urgent); + // Both must start with the task prefix + assert!(normal.starts_with(TASK_SUBJECT_PREFIX)); + assert!(urgent.starts_with(TASK_SUBJECT_PREFIX)); + } + + #[test] + fn deferred_subject_includes_both_op_and_type() { + let s = deferred_subject("op-20260429-abc", "lateral"); + assert!(s.contains("op-20260429-abc")); + assert!(s.contains("lateral")); + assert!(s.starts_with(DEFERRED_SUBJECT_PREFIX)); + } + + #[test] + fn stream_names_are_uppercase_and_distinct() { + let names = [ + TASKS_STREAM, + BLUE_TASKS_STREAM, + DEFERRED_STREAM, + DISCOVERIES_STREAM, + ]; + for n in &names { + assert_eq!(*n, n.to_uppercase(), "stream name {n} must be uppercase"); + } + // All distinct + for (i, a) in names.iter().enumerate() { + for b in &names[i + 1..] 
{ + assert_ne!(a, b); + } + } + } } From fd770f7a2e344b69ce45f755f01f6165d54d7df9 Mon Sep 17 00:00:00 2001 From: Jayson Grace Date: Wed, 29 Apr 2026 17:00:24 -0600 Subject: [PATCH 4/5] refactor: extract serialization, message building, and status logic for testability **Added:** - Introduced helper functions in `task_queue.rs` to build task messages, select task subjects based on priority, and determine final status strings, allowing unit testing of wire message shapes and subject routing logic - Added public functions in `blue_task_queue.rs` to serialize/deserialize task and result messages, enabling easier unit testing without a broker - Implemented builder functions in `redis_dispatcher.rs` for call IDs, tool exec requests, and tool result conversions for improved testability - Added free functions in `result_handler.rs` for building task results from agent outcomes, supporting test coverage of branching logic - Provided construction helpers in `tool_executor.rs` for tool exit errors and response objects, allowing isolated unit tests of response shape logic - Added comprehensive unit tests for all new helper functions and message builders in affected modules **Changed:** - Refactored `TaskQueueCore::submit_task` and related logic to use extracted helper functions for message building and subject selection, improving clarity and testability - Updated `RedisToolDispatcher` to use new builder functions for call IDs and tool exec requests, reducing duplication and improving unit test coverage - Changed `process_task` in `result_handler.rs` to delegate result building and status computation to an extracted function, simplifying main logic - Refactored tool execution response construction in `tool_executor.rs` to use dedicated builder functions, clarifying error and success handling - Modified `BlueTaskQueueCore` methods to use new serialization/deserialization helpers, increasing code clarity and maintainability **Removed:** - Eliminated inlined message construction, 
subject routing, and status logic from main queue, dispatcher, and worker flows in favor of extracted functions - Removed duplicate code for serializing/deserializing messages within queue implementations, consolidating in free functions for testability --- ares-cli/src/orchestrator/task_queue.rs | 160 +++++++++-- .../tool_dispatcher/redis_dispatcher.rs | 57 +++- .../src/orchestrator/tool_dispatcher/tests.rs | 79 ++++++ .../src/worker/task_loop/result_handler.rs | 254 +++++++++++++----- ares-cli/src/worker/tool_executor.rs | 141 ++++++++-- ares-core/src/state/blue_task_queue.rs | 131 ++++++++- 6 files changed, 705 insertions(+), 117 deletions(-) diff --git a/ares-cli/src/orchestrator/task_queue.rs b/ares-cli/src/orchestrator/task_queue.rs index 9d61539a..968f1033 100644 --- a/ares-cli/src/orchestrator/task_queue.rs +++ b/ares-cli/src/orchestrator/task_queue.rs @@ -129,6 +129,55 @@ impl TaskQueue { } } +/// Build the [`TaskMessage`] that `submit_task` publishes to JetStream. +/// +/// Pulled out so the wire shape (priority → subject mapping, callback queue +/// generation, default field values) can be unit-tested without a broker. +#[allow(dead_code)] +pub(crate) fn build_task_message( + task_id: &str, + task_type: &str, + target_role: &str, + payload: serde_json::Value, + source_agent: &str, + priority: i32, +) -> TaskMessage { + TaskMessage { + task_id: task_id.to_string(), + task_type: task_type.to_string(), + source_agent: source_agent.to_string(), + target_agent: target_role.to_string(), + payload, + priority, + created_at: Some(Utc::now()), + callback_queue: Some(nats::task_result_subject(task_id)), + } +} + +/// Choose the work subject for a task based on its priority. +/// +/// Priority ≤ 2 publishes to the urgent subject so workers that bind two +/// consumers can prefer urgent work; everything else goes to the normal +/// subject. 
+#[allow(dead_code)] +pub(crate) fn task_subject_for_priority(target_role: &str, priority: i32) -> String { + if priority <= 2 { + nats::urgent_task_subject(target_role) + } else { + nats::task_subject(target_role) + } +} + +/// Lifecycle status string written to Redis after a result is published. +#[allow(dead_code)] +pub(crate) const fn final_status_for(success: bool) -> &'static str { + if success { + "completed" + } else { + "failed" + } +} + // The generic impl exposes both the production NATS path and a Redis-only // path used by unit tests with a mock connection. Some methods are only // exercised in the test build; allow that on the impl as a whole. @@ -174,22 +223,16 @@ impl TaskQueueCore { ) -> Result { let task_id = format!("{}_{}", task_type, &Uuid::new_v4().to_string()[..12]); - let msg = TaskMessage { - task_id: task_id.clone(), - task_type: task_type.to_string(), - source_agent: source_agent.to_string(), - target_agent: target_role.to_string(), + let msg = build_task_message( + &task_id, + task_type, + target_role, payload, + source_agent, priority, - created_at: Some(Utc::now()), - callback_queue: Some(nats::task_result_subject(&task_id)), - }; + ); - let subject = if priority <= 2 { - nats::urgent_task_subject(target_role) - } else { - nats::task_subject(target_role) - }; + let subject = task_subject_for_priority(target_role, priority); let bytes = Bytes::from(serde_json::to_vec(&msg).context("serialize TaskMessage")?); let ack = self @@ -310,11 +353,7 @@ impl TaskQueueCore { ack.await .with_context(|| format!("Awaiting ack for {subject}"))?; - let final_status = if result.success { - "completed" - } else { - "failed" - }; + let final_status = final_status_for(result.success); debug!( task_id, status = final_status, @@ -890,6 +929,91 @@ mod tests { assert!(json.contains("ares.tasks.results.t")); } + #[test] + fn task_subject_for_priority_routes_urgent_below_threshold() { + // Priority ≤ 2 ⇒ urgent subject, otherwise the normal subject + assert_eq!( 
+ task_subject_for_priority("scanner", 1), + "ares.tasks.urgent.scanner" + ); + assert_eq!( + task_subject_for_priority("scanner", 2), + "ares.tasks.urgent.scanner" + ); + assert_eq!( + task_subject_for_priority("scanner", 3), + "ares.tasks.scanner" + ); + assert_eq!( + task_subject_for_priority("scanner", 5), + "ares.tasks.scanner" + ); + assert_eq!( + task_subject_for_priority("scanner", 10), + "ares.tasks.scanner" + ); + } + + #[test] + fn final_status_for_maps_success_flag() { + assert_eq!(final_status_for(true), "completed"); + assert_eq!(final_status_for(false), "failed"); + } + + #[test] + fn build_task_message_populates_callback_queue_with_result_subject() { + let msg = build_task_message( + "recon_abcdef123456", + "recon", + "scanner", + serde_json::json!({"target": "10.0.0.1"}), + "orchestrator", + 5, + ); + assert_eq!(msg.task_id, "recon_abcdef123456"); + assert_eq!(msg.task_type, "recon"); + assert_eq!(msg.source_agent, "orchestrator"); + assert_eq!(msg.target_agent, "scanner"); + assert_eq!(msg.priority, 5); + assert_eq!( + msg.callback_queue.as_deref(), + Some("ares.tasks.results.recon_abcdef123456"), + ); + assert!(msg.created_at.is_some()); + assert_eq!(msg.payload["target"], "10.0.0.1"); + } + + #[test] + fn build_task_message_preserves_priority_zero() { + // Priority 0 is allowed (super urgent); make sure we don't clamp. 
+ let msg = build_task_message( + "t", + "exploit", + "exploiter", + serde_json::json!({}), + "orch", + 0, + ); + assert_eq!(msg.priority, 0); + } + + #[test] + fn build_task_message_serializes_round_trip_with_callback() { + let msg = build_task_message( + "lateral_xyz", + "lateral_movement", + "lateral", + serde_json::json!({"host": "dc01"}), + "orch", + 2, + ); + let json = serde_json::to_string(&msg).unwrap(); + assert!(json.contains("ares.tasks.results.lateral_xyz")); + let parsed: TaskMessage = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.priority, 2); + assert_eq!(parsed.task_type, "lateral_movement"); + } + #[test] fn task_result_serializes_none_fields_as_null() { let r = TaskResult { diff --git a/ares-cli/src/orchestrator/tool_dispatcher/redis_dispatcher.rs b/ares-cli/src/orchestrator/tool_dispatcher/redis_dispatcher.rs index 377123d9..e83300ad 100644 --- a/ares-cli/src/orchestrator/tool_dispatcher/redis_dispatcher.rs +++ b/ares-cli/src/orchestrator/tool_dispatcher/redis_dispatcher.rs @@ -76,6 +76,41 @@ pub(super) fn dispatch_timeout_result(tool_name: &str, timeout: Duration) -> Too } } +/// Per-call request id for correlating worker replies with outstanding calls. +pub(super) fn build_call_id(tool_name: &str) -> String { + format!("{tool_name}_{}", uuid::Uuid::new_v4().simple()) +} + +/// Build the wire request that `dispatch_tool` sends to the worker. Pulled +/// out so the request shape can be unit-tested without a NATS broker. +pub(super) fn build_tool_exec_request( + call_id: String, + task_id: &str, + tool_name: &str, + arguments: serde_json::Value, + traceparent: Option<String>, + operation_id: Option<String>, +) -> ToolExecRequest { + ToolExecRequest { + call_id, + task_id: task_id.to_string(), + tool_name: tool_name.to_string(), + arguments, + traceparent, + operation_id, + } +} + +/// Convert a deserialized worker reply into the [`ToolExecResult`] returned +/// to the LLM agent loop. 
+pub(super) fn tool_exec_result_from_response(response: ToolExecResponse) -> ToolExecResult { + ToolExecResult { + output: response.output, + error: response.error, + discoveries: response.discoveries, + } +} + #[async_trait::async_trait] impl ares_llm::ToolDispatcher for RedisToolDispatcher { async fn dispatch_tool( @@ -98,19 +133,19 @@ impl ares_llm::ToolDispatcher for RedisToolDispatcher { self.auth_throttle.acquire(&cred_key).await; } - let call_id = format!("{}_{}", call.name, uuid::Uuid::new_v4().simple()); + let call_id = build_call_id(&call.name); // Inject trace context for cross-service span linking let traceparent = inject_traceparent(&tracing::Span::current()); - let request = ToolExecRequest { - call_id: call_id.clone(), - task_id: task_id.to_string(), - tool_name: call.name.clone(), - arguments: call.arguments.clone(), + let request = build_tool_exec_request( + call_id.clone(), + task_id, + &call.name, + call.arguments.clone(), traceparent, - operation_id: Some(self.operation_id.clone()), - }; + Some(self.operation_id.clone()), + ); let subject = nats::tool_exec_subject(effective_role); let payload = @@ -182,11 +217,7 @@ impl ares_llm::ToolDispatcher for RedisToolDispatcher { .await; } - Ok(ToolExecResult { - output: response.output, - error: response.error, - discoveries: response.discoveries, - }) + Ok(tool_exec_result_from_response(response)) } .instrument(span) .await diff --git a/ares-cli/src/orchestrator/tool_dispatcher/tests.rs b/ares-cli/src/orchestrator/tool_dispatcher/tests.rs index a81f65b8..2a7e3b61 100644 --- a/ares-cli/src/orchestrator/tool_dispatcher/tests.rs +++ b/ares-cli/src/orchestrator/tool_dispatcher/tests.rs @@ -556,3 +556,82 @@ fn dispatch_error_and_timeout_results_share_shape() { assert!(e.discoveries.is_none() && t.discoveries.is_none()); assert!(e.error.is_some() && t.error.is_some()); } + +#[test] +fn build_call_id_includes_tool_name_prefix() { + use redis_dispatcher::build_call_id; + let id = build_call_id("nmap_scan"); + 
assert!(id.starts_with("nmap_scan_"), "got {id}"); + // simple uuid is 32 hex chars after the prefix + underscore + let suffix = id.strip_prefix("nmap_scan_").unwrap(); + assert_eq!(suffix.len(), 32); + assert!(suffix.chars().all(|c| c.is_ascii_hexdigit())); +} + +#[test] +fn build_call_id_is_unique_per_invocation() { + use redis_dispatcher::build_call_id; + let a = build_call_id("hashcat"); + let b = build_call_id("hashcat"); + assert_ne!(a, b); +} + +#[test] +fn build_tool_exec_request_carries_all_inputs() { + use redis_dispatcher::build_tool_exec_request; + let req = build_tool_exec_request( + "nmap_scan_abc".into(), + "task-1", + "nmap_scan", + serde_json::json!({"target": "10.0.0.1"}), + Some("00-trace-span-01".into()), + Some("op-2026".into()), + ); + assert_eq!(req.call_id, "nmap_scan_abc"); + assert_eq!(req.task_id, "task-1"); + assert_eq!(req.tool_name, "nmap_scan"); + assert_eq!(req.arguments["target"], "10.0.0.1"); + assert_eq!(req.traceparent.as_deref(), Some("00-trace-span-01")); + assert_eq!(req.operation_id.as_deref(), Some("op-2026")); +} + +#[test] +fn build_tool_exec_request_with_no_traceparent_or_operation() { + use redis_dispatcher::build_tool_exec_request; + let req = build_tool_exec_request("c".into(), "t", "whoami", serde_json::json!({}), None, None); + assert!(req.traceparent.is_none()); + assert!(req.operation_id.is_none()); + let json = serde_json::to_string(&req).unwrap(); + // Optional fields skip when None + assert!(!json.contains("traceparent")); + assert!(!json.contains("operation_id")); +} + +#[test] +fn tool_exec_result_from_response_passes_through_all_fields() { + use redis_dispatcher::tool_exec_result_from_response; + let resp = ToolExecResponse { + call_id: "c".into(), + output: "out".into(), + error: None, + discoveries: Some(serde_json::json!({"hosts": [{"ip": "10.0.0.1"}]})), + }; + let r = tool_exec_result_from_response(resp); + assert_eq!(r.output, "out"); + assert!(r.error.is_none()); + 
assert_eq!(r.discoveries.unwrap()["hosts"][0]["ip"], "10.0.0.1"); +} + +#[test] +fn tool_exec_result_from_response_preserves_error_string() { + use redis_dispatcher::tool_exec_result_from_response; + let resp = ToolExecResponse { + call_id: "c".into(), + output: String::new(), + error: Some("connection refused".into()), + discoveries: None, + }; + let r = tool_exec_result_from_response(resp); + assert_eq!(r.error.as_deref(), Some("connection refused")); + assert!(r.discoveries.is_none()); +} diff --git a/ares-cli/src/worker/task_loop/result_handler.rs b/ares-cli/src/worker/task_loop/result_handler.rs index 80e7f824..677c28ad 100644 --- a/ares-cli/src/worker/task_loop/result_handler.rs +++ b/ares-cli/src/worker/task_loop/result_handler.rs @@ -56,67 +56,13 @@ pub async fn process_task( let usage_for_tracking = agent_result.as_ref().ok().and_then(|ar| ar.usage.clone()); - let (task_result, final_status) = match agent_result { - Ok(ar) => { - if let Some(ref err) = ar.error { - let result_payload = serde_json::json!({ - "output": ar.output, - "task_type": task.task_type, - }); - ( - TaskResult::failure( - &task.task_id, - err.clone(), - Some(result_payload), - &config.pod_name, - &config.agent_name, - ), - "failed", - ) - } else { - let mut result_payload = serde_json::json!({ - "output": ar.output, - "task_type": task.task_type, - }); - if let Some(ref usage) = ar.usage { - result_payload["usage"] = serde_json::to_value(usage).unwrap_or_default(); - } - if let Some(ref disc) = ar.discoveries { - if let Some(obj) = disc.as_object() { - for (k, v) in obj { - result_payload[k] = v.clone(); - } - } - } - ( - TaskResult::success( - &task.task_id, - result_payload, - &config.pod_name, - &config.agent_name, - ), - "completed", - ) - } - } - Err(e) => { - let error_msg = format!("{e}"); - error!( - task_id = %task.task_id, - "Agent task failed: {error_msg}" - ); - ( - TaskResult::failure( - &task.task_id, - error_msg, - None, - &config.pod_name, - &config.agent_name, - ), - 
"failed", - ) - } - }; + let (task_result, final_status) = build_task_result_for_agent_outcome( + &task.task_id, + &config.pod_name, + &config.agent_name, + &task.task_type, + agent_result, + ); if let Some(ref usage) = usage_for_tracking { if usage.total_tokens > 0 { @@ -185,6 +131,60 @@ pub async fn process_task( } } +/// Build the final `TaskResult` and lifecycle status string from a single +/// agent execution. Pulled out as a free function so the branching logic +/// (success / agent-reported error / dispatch error) can be unit tested +/// without a NATS broker. +pub(super) fn build_task_result_for_agent_outcome( + task_id: &str, + pod_name: &str, + agent_name: &str, + task_type: &str, + agent_outcome: anyhow::Result, +) -> (TaskResult, &'static str) { + match agent_outcome { + Ok(ar) => { + if let Some(ref err) = ar.error { + let payload = serde_json::json!({ + "output": ar.output, + "task_type": task_type, + }); + ( + TaskResult::failure(task_id, err.clone(), Some(payload), pod_name, agent_name), + "failed", + ) + } else { + let mut payload = serde_json::json!({ + "output": ar.output, + "task_type": task_type, + }); + if let Some(ref usage) = ar.usage { + payload["usage"] = serde_json::to_value(usage).unwrap_or_default(); + } + if let Some(ref disc) = ar.discoveries { + if let Some(obj) = disc.as_object() { + for (k, v) in obj { + payload[k] = v.clone(); + } + } + } + ( + TaskResult::success(task_id, payload, pod_name, agent_name), + "completed", + ) + } + } + Err(e) => { + let msg = format!("{e}"); + error!(task_id = %task_id, "Agent task failed: {msg}"); + ( + TaskResult::failure(task_id, msg, None, pod_name, agent_name), + "failed", + ) + } + } +} + /// Set task status in Redis with TTL. 
async fn set_task_status( conn: &mut C, @@ -216,8 +216,140 @@ where #[cfg(test)] mod tests { use super::*; + use crate::worker::task_loop::types::{AgentResult, TokenUsage}; use ares_core::state::mock_redis::MockRedisConnection; + fn agent_ok(output: &str) -> AgentResult { + AgentResult { + output: output.to_string(), + error: None, + usage: None, + discoveries: None, + } + } + + #[test] + fn build_task_result_success_marks_completed_and_carries_payload() { + let ar = agent_ok("nmap output"); + let (tr, status) = + build_task_result_for_agent_outcome("t1", "pod-0", "ares-recon", "recon", Ok(ar)); + assert_eq!(status, "completed"); + assert!(tr.success); + assert!(tr.error.is_none()); + let payload = tr.result.expect("result payload present"); + assert_eq!(payload["output"], "nmap output"); + assert_eq!(payload["task_type"], "recon"); + assert!(payload.get("usage").is_none()); + } + + #[test] + fn build_task_result_success_includes_usage_when_present() { + let ar = AgentResult { + output: "out".into(), + error: None, + usage: Some(TokenUsage { + input_tokens: 12, + output_tokens: 34, + total_tokens: 46, + model: Some("openai/gpt-4.1-mini".into()), + }), + discoveries: None, + }; + let (tr, status) = + build_task_result_for_agent_outcome("t1", "pod-0", "ares-recon", "recon", Ok(ar)); + assert_eq!(status, "completed"); + let payload = tr.result.unwrap(); + assert_eq!(payload["usage"]["input_tokens"], 12); + assert_eq!(payload["usage"]["total_tokens"], 46); + assert_eq!(payload["usage"]["model"], "openai/gpt-4.1-mini"); + } + + #[test] + fn build_task_result_success_merges_discoveries_into_payload() { + let discoveries = serde_json::json!({ + "hosts": [{"ip": "10.0.0.1"}], + "credentials": [{"username": "alice"}], + }); + let ar = AgentResult { + output: "scan".into(), + error: None, + usage: None, + discoveries: Some(discoveries.clone()), + }; + let (tr, status) = + build_task_result_for_agent_outcome("t1", "pod-0", "ares-recon", "recon", Ok(ar)); + assert_eq!(status, 
"completed"); + let payload = tr.result.unwrap(); + assert_eq!(payload["hosts"], discoveries["hosts"]); + assert_eq!(payload["credentials"], discoveries["credentials"]); + assert_eq!(payload["task_type"], "recon"); + } + + #[test] + fn build_task_result_success_ignores_non_object_discoveries() { + let ar = AgentResult { + output: "scan".into(), + error: None, + usage: None, + discoveries: Some(serde_json::json!([1, 2, 3])), + }; + let (tr, status) = + build_task_result_for_agent_outcome("t1", "pod-0", "ares-recon", "recon", Ok(ar)); + assert_eq!(status, "completed"); + let payload = tr.result.unwrap(); + // Top-level keys remain just output + task_type + assert!(payload.get("0").is_none()); + assert_eq!(payload["task_type"], "recon"); + } + + #[test] + fn build_task_result_agent_reported_error_marks_failed_and_keeps_partial_output() { + let ar = AgentResult { + output: "ran 3 of 5 steps".into(), + error: Some("one or more tools had errors".into()), + usage: None, + discoveries: None, + }; + let (tr, status) = + build_task_result_for_agent_outcome("t1", "pod-0", "ares-recon", "recon", Ok(ar)); + assert_eq!(status, "failed"); + assert!(!tr.success); + assert_eq!(tr.error.as_deref(), Some("one or more tools had errors")); + let payload = tr.result.expect("partial output preserved"); + assert_eq!(payload["output"], "ran 3 of 5 steps"); + assert_eq!(payload["task_type"], "recon"); + } + + #[test] + fn build_task_result_dispatch_error_marks_failed_with_no_partial_payload() { + let err: anyhow::Result = Err(anyhow::anyhow!("tool spawn failed")); + let (tr, status) = + build_task_result_for_agent_outcome("t1", "pod-0", "ares-recon", "recon", err); + assert_eq!(status, "failed"); + assert!(!tr.success); + assert_eq!(tr.error.as_deref(), Some("tool spawn failed")); + // No partial output preserved on dispatch failure + assert!(tr.result.is_none()); + assert_eq!(tr.worker_pod.as_deref(), Some("pod-0")); + assert_eq!(tr.agent_name.as_deref(), Some("ares-recon")); + } + + 
#[test] + fn build_task_result_passes_through_pod_and_agent_metadata() { + let ar = agent_ok("hi"); + let (tr, _) = build_task_result_for_agent_outcome( + "task-42", + "pod-xyz", + "ares-credential-access", + "credential_access", + Ok(ar), + ); + assert_eq!(tr.task_id, "task-42"); + assert_eq!(tr.worker_pod.as_deref(), Some("pod-xyz")); + assert_eq!(tr.agent_name.as_deref(), Some("ares-credential-access")); + assert!(tr.completed_at.is_some()); + } + #[tokio::test] async fn set_task_status_writes_status_and_timestamps() { let mut conn = MockRedisConnection::new(); diff --git a/ares-cli/src/worker/tool_executor.rs b/ares-cli/src/worker/tool_executor.rs index 8e4f6e31..0f8da9dc 100644 --- a/ares-cli/src/worker/tool_executor.rs +++ b/ares-cli/src/worker/tool_executor.rs @@ -188,6 +188,45 @@ fn discoveries_or_none(parsed: serde_json::Value) -> Option { } } +/// Render the error string for a tool that exited with a non-zero status. +fn tool_exit_error(exit_code: Option) -> String { + format!("tool exited with code {exit_code:?}") +} + +/// Build the success-path [`ToolExecResponse`] (output + discoveries + error +/// derived from the process exit status). Pulled out so the response shape +/// can be unit-tested without spawning a tool subprocess. +fn build_success_response( + call_id: &str, + success: bool, + exit_code: Option, + combined: String, + discoveries: Option, +) -> ToolExecResponse { + let error = if success { + None + } else { + Some(tool_exit_error(exit_code)) + }; + ToolExecResponse { + call_id: call_id.to_string(), + output: combined, + error, + discoveries, + } +} + +/// Build the error-path [`ToolExecResponse`] (dispatch failed before the +/// tool produced any output). +fn build_error_response(call_id: &str, err_str: String) -> ToolExecResponse { + ToolExecResponse { + call_id: call_id.to_string(), + output: String::new(), + error: Some(err_str), + discoveries: None, + } +} + /// Execute a tool call and reply on the NATS inbox. 
async fn execute_and_respond( client: async_nats::Client, @@ -220,11 +259,8 @@ async fn execute_and_respond( Ok(output) => { let raw = output.combined_raw(); let combined = output.combined(); - let error = if output.success { - None - } else { - Some(format!("tool exited with code {:?}", output.exit_code)) - }; + let success = output.success; + let exit_code = output.exit_code; let discoveries = discoveries_or_none(ares_tools::parsers::parse_tool_output( &request.tool_name, @@ -253,12 +289,7 @@ async fn execute_and_respond( } } - ToolExecResponse { - call_id: request.call_id.clone(), - output: combined, - error, - discoveries, - } + build_success_response(&request.call_id, success, exit_code, combined, discoveries) } Err(e) => { let err_str = e.to_string(); @@ -275,12 +306,7 @@ async fn execute_and_respond( err = %e, "Tool execution failed" ); - ToolExecResponse { - call_id: request.call_id.clone(), - output: String::new(), - error: Some(err_str), - discoveries: None, - } + build_error_response(&request.call_id, err_str) } }; @@ -631,4 +657,85 @@ mod tests { let kept = discoveries_or_none(v.clone()); assert_eq!(kept, Some(v)); } + + #[test] + fn tool_exit_error_renders_exit_code() { + assert_eq!(tool_exit_error(Some(0)), "tool exited with code Some(0)"); + assert_eq!(tool_exit_error(Some(1)), "tool exited with code Some(1)"); + assert_eq!(tool_exit_error(None), "tool exited with code None"); + } + + #[test] + fn build_success_response_success_omits_error() { + let resp = build_success_response("call-1", true, Some(0), "ok\n".into(), None); + assert_eq!(resp.call_id, "call-1"); + assert_eq!(resp.output, "ok\n"); + assert!(resp.error.is_none()); + assert!(resp.discoveries.is_none()); + } + + #[test] + fn build_success_response_failure_records_exit_code() { + let resp = build_success_response("call-2", false, Some(2), "err\n".into(), None); + assert!(!resp.error.as_deref().unwrap().is_empty()); + assert!(resp.error.as_deref().unwrap().contains("Some(2)")); + 
assert_eq!(resp.output, "err\n"); + } + + #[test] + fn build_success_response_failure_with_no_exit_code() { + // Tool was killed without an exit code (signal, etc.) + let resp = build_success_response("call-3", false, None, String::new(), None); + let err = resp.error.as_deref().unwrap(); + assert!(err.contains("None")); + } + + #[test] + fn build_success_response_carries_discoveries_when_present() { + let disc = serde_json::json!({"hosts": [{"ip": "10.0.0.1"}]}); + let resp = build_success_response( + "call-4", + true, + Some(0), + "scan output".into(), + Some(disc.clone()), + ); + assert_eq!(resp.discoveries.as_ref().unwrap()["hosts"], disc["hosts"]); + assert!(resp.error.is_none()); + } + + #[test] + fn build_success_response_serializes_with_omitted_discoveries_when_none() { + let resp = build_success_response("call-5", true, Some(0), "ok".into(), None); + let json = serde_json::to_string(&resp).unwrap(); + // discoveries field skipped when None + assert!(!json.contains("discoveries")); + } + + #[test] + fn build_error_response_zeroes_output_and_no_discoveries() { + let resp = build_error_response("call-6", "spawn failure".into()); + assert_eq!(resp.call_id, "call-6"); + assert!(resp.output.is_empty()); + assert!(resp.discoveries.is_none()); + assert_eq!(resp.error.as_deref(), Some("spawn failure")); + } + + #[test] + fn build_error_response_serializes_without_discoveries_field() { + let resp = build_error_response("call-7", "bad".into()); + let json = serde_json::to_string(&resp).unwrap(); + assert!(!json.contains("discoveries")); + assert!(json.contains("bad")); + } + + #[test] + fn build_success_and_error_responses_share_call_id_field() { + let s = build_success_response("xyz", true, Some(0), "ok".into(), None); + let e = build_error_response("xyz", "bad".into()); + let sj: serde_json::Value = serde_json::to_value(&s).unwrap(); + let ej: serde_json::Value = serde_json::to_value(&e).unwrap(); + assert_eq!(sj["call_id"], "xyz"); + assert_eq!(ej["call_id"], 
"xyz"); + } } diff --git a/ares-core/src/state/blue_task_queue.rs b/ares-core/src/state/blue_task_queue.rs index 9b58f9b9..8e4d69d9 100644 --- a/ares-core/src/state/blue_task_queue.rs +++ b/ares-core/src/state/blue_task_queue.rs @@ -133,6 +133,35 @@ impl BlueTaskQueue { } } +/// Serialize a [`BlueTaskMessage`] into the `(subject, payload)` pair that +/// [`BlueTaskQueueCore::submit_task`] hands to JetStream. Pulled out as a +/// free function so the wire shape can be unit-tested without a broker. +pub(crate) fn prepare_blue_task_publish(task: &BlueTaskMessage) -> Result<(String, Bytes)> { + let subject = nats::blue_task_subject(&task.role); + let bytes = Bytes::from(serde_json::to_vec(task).context("serialize BlueTaskMessage")?); + Ok((subject, bytes)) +} + +/// Serialize a [`BlueTaskResult`] into the `(subject, payload)` pair that +/// [`BlueTaskQueueCore::send_result`] hands to JetStream. +pub(crate) fn prepare_blue_result_publish(result: &BlueTaskResult) -> Result<(String, Bytes)> { + let subject = nats::blue_task_result_subject(&result.task_id); + let bytes = Bytes::from(serde_json::to_vec(result).context("serialize BlueTaskResult")?); + Ok((subject, bytes)) +} + +/// Parse a JetStream message payload into a [`BlueTaskMessage`]. +pub(crate) fn parse_blue_task_payload(payload: &[u8], subject: &str) -> Result { + serde_json::from_slice(payload) + .with_context(|| format!("Bad BlueTaskMessage JSON on {subject}")) +} + +/// Parse a JetStream message payload into a [`BlueTaskResult`]. +pub(crate) fn parse_blue_result_payload(payload: &[u8], task_id: &str) -> Result { + serde_json::from_slice(payload) + .with_context(|| format!("Bad BlueTaskResult JSON for {task_id}")) +} + impl BlueTaskQueueCore { /// Construct from a Redis backend only — used by unit tests that don't /// exercise queue methods. Queue methods will return an error. @@ -150,8 +179,7 @@ impl BlueTaskQueueCore { /// Submit a task to the global role queue. 
pub async fn submit_task(&mut self, task: &BlueTaskMessage) -> anyhow::Result<()> { - let subject = nats::blue_task_subject(&task.role); - let bytes = Bytes::from(serde_json::to_vec(task).context("serialize BlueTaskMessage")?); + let (subject, bytes) = prepare_blue_task_publish(task)?; debug!( task_id = %task.task_id, @@ -200,8 +228,7 @@ impl BlueTaskQueueCore { match fetch.next().await { Some(Ok(m)) => { - let task: BlueTaskMessage = serde_json::from_slice(&m.payload) - .with_context(|| format!("Bad BlueTaskMessage JSON on {subject}"))?; + let task = parse_blue_task_payload(&m.payload, &subject)?; m.ack().await.map_err(|e| anyhow::anyhow!("ack: {e}")).ok(); Ok(Some(task)) } @@ -212,8 +239,7 @@ impl BlueTaskQueueCore { /// Send a task result to its dedicated result subject. pub async fn send_result(&mut self, result: &BlueTaskResult) -> anyhow::Result<()> { - let subject = nats::blue_task_result_subject(&result.task_id); - let bytes = Bytes::from(serde_json::to_vec(result).context("serialize BlueTaskResult")?); + let (subject, bytes) = prepare_blue_result_publish(result)?; let ack = self .nats()? 
@@ -274,8 +300,7 @@ impl BlueTaskQueueCore { match fetch.next().await { Some(Ok(m)) => { - let parsed: BlueTaskResult = serde_json::from_slice(&m.payload) - .with_context(|| format!("Bad BlueTaskResult JSON for {task_id}"))?; + let parsed = parse_blue_result_payload(&m.payload, task_id)?; m.ack().await.map_err(|e| anyhow::anyhow!("ack: {e}")).ok(); Ok(Some(parsed)) } @@ -713,4 +738,94 @@ mod tests { let err = q.queue_length("triage").await.unwrap_err(); assert!(err.to_string().contains("NATS")); } + + fn sample_task() -> BlueTaskMessage { + BlueTaskMessage { + task_id: "btask-1".into(), + investigation_id: "inv-1".into(), + task_type: "log_search".into(), + role: "triage".into(), + params: serde_json::json!({"q": "alertname=Foo"}), + created_at: "2026-04-29T20:00:00Z".into(), + } + } + + #[test] + fn prepare_blue_task_publish_uses_role_subject_and_full_message() { + let task = sample_task(); + let (subject, bytes) = prepare_blue_task_publish(&task).unwrap(); + assert_eq!(subject, "ares.blue.tasks.triage"); + let parsed: BlueTaskMessage = serde_json::from_slice(&bytes).unwrap(); + assert_eq!(parsed.task_id, "btask-1"); + assert_eq!(parsed.investigation_id, "inv-1"); + assert_eq!(parsed.role, "triage"); + assert_eq!(parsed.params["q"], "alertname=Foo"); + } + + #[test] + fn prepare_blue_task_publish_subject_changes_with_role() { + let mut t = sample_task(); + t.role = "log_analyst".into(); + let (subject, _) = prepare_blue_task_publish(&t).unwrap(); + assert_eq!(subject, "ares.blue.tasks.log_analyst"); + } + + #[test] + fn prepare_blue_result_publish_uses_task_result_subject() { + let r = BlueTaskResult::success( + "btask-9", + "inv-1", + serde_json::json!({"hits": 3}), + "agent-x", + ); + let (subject, bytes) = prepare_blue_result_publish(&r).unwrap(); + assert_eq!(subject, "ares.blue.tasks.results.btask-9"); + let parsed: BlueTaskResult = serde_json::from_slice(&bytes).unwrap(); + assert!(parsed.success); + assert_eq!(parsed.task_id, "btask-9"); + 
assert_eq!(parsed.result.unwrap()["hits"], 3); + } + + #[test] + fn prepare_blue_result_publish_uses_distinct_subject_per_task_id() { + let a = BlueTaskResult::failure("a", "inv-1", "err".into(), "agent"); + let b = BlueTaskResult::failure("b", "inv-1", "err".into(), "agent"); + let (sa, _) = prepare_blue_result_publish(&a).unwrap(); + let (sb, _) = prepare_blue_result_publish(&b).unwrap(); + assert_ne!(sa, sb); + assert!(sa.ends_with(".a")); + assert!(sb.ends_with(".b")); + } + + #[test] + fn parse_blue_task_payload_round_trips_a_published_message() { + let task = sample_task(); + let (subject, bytes) = prepare_blue_task_publish(&task).unwrap(); + let parsed = parse_blue_task_payload(&bytes, &subject).unwrap(); + assert_eq!(parsed.task_id, task.task_id); + assert_eq!(parsed.role, task.role); + } + + #[test] + fn parse_blue_task_payload_surfaces_subject_in_context_on_error() { + let err = parse_blue_task_payload(b"not json", "ares.blue.tasks.triage").unwrap_err(); + let msg = format!("{err:#}"); + assert!(msg.contains("ares.blue.tasks.triage"), "got {msg}"); + } + + #[test] + fn parse_blue_result_payload_round_trips_a_published_result() { + let r = BlueTaskResult::success("btask-1", "inv-1", serde_json::json!({"x": 1}), "agent"); + let (_, bytes) = prepare_blue_result_publish(&r).unwrap(); + let parsed = parse_blue_result_payload(&bytes, "btask-1").unwrap(); + assert!(parsed.success); + assert_eq!(parsed.task_id, "btask-1"); + } + + #[test] + fn parse_blue_result_payload_surfaces_task_id_in_context_on_error() { + let err = parse_blue_result_payload(b"garbage", "btask-7").unwrap_err(); + let msg = format!("{err:#}"); + assert!(msg.contains("btask-7"), "got {msg}"); + } } From 608cc880cc715d16fe2f7311f958d9a28871243b Mon Sep 17 00:00:00 2001 From: Jayson Grace Date: Wed, 29 Apr 2026 17:07:51 -0600 Subject: [PATCH 5/5] refactor: extract status and discovery helpers for clarity and testability **Added:** - Introduced `build_running_status_extra` and 
`build_final_status_extra` helper functions to encapsulate construction of status "extra_fields" payloads and ensure field consistency between producer and consumer - Added `busy_current_task` function to standardize formatting of `WorkerStatus.current_task` field - Added `count_discovery_entries` function to count non-empty discovery arrays per type, supporting clearer and unit-testable discovery reporting logic - Implemented comprehensive unit tests for new helper functions to verify payload structure, metadata consistency, and edge case handling **Changed:** - Refactored `process_task` in `result_handler.rs` to use new helper functions for status "extra_fields" payloads, improving maintainability and reducing field duplication - Updated `run_tool_exec_loop` in `tool_executor.rs` to use the new `busy_current_task` helper, enforcing consistent task status formatting - Modified discovery trace emission to use `count_discovery_entries`, replacing inline logic with reusable, tested function for clarity and correctness **Removed:** - Eliminated repeated manual construction of status payload objects in favor of the new helper functions, reducing code duplication and risk of inconsistency --- .../src/worker/task_loop/result_handler.rs | 136 +++++++++++++++--- ares-cli/src/worker/tool_executor.rs | 115 ++++++++++++--- 2 files changed, 216 insertions(+), 35 deletions(-) diff --git a/ares-cli/src/worker/task_loop/result_handler.rs b/ares-cli/src/worker/task_loop/result_handler.rs index 677c28ad..9ea651ac 100644 --- a/ares-cli/src/worker/task_loop/result_handler.rs +++ b/ares-cli/src/worker/task_loop/result_handler.rs @@ -37,15 +37,7 @@ pub async fn process_task( conn, &task.task_id, "running", - &serde_json::json!({ - "operation_id": config.operation_id, - "role": config.worker_role, - "agent_name": config.agent_name, - "pod_name": config.pod_name, - "task_type": task.task_type, - "payload": task.payload, - "started_at": started_at, - }), + 
&build_running_status_extra(config, &task.task_type, &task.payload, &started_at), ) .await { @@ -111,14 +103,7 @@ pub async fn process_task( conn, &task.task_id, final_status, - &serde_json::json!({ - "operation_id": config.operation_id, - "role": config.worker_role, - "agent_name": config.agent_name, - "pod_name": config.pod_name, - "task_type": task.task_type, - "ended_at": Utc::now().to_rfc3339(), - }), + &build_final_status_extra(config, &task.task_type, &Utc::now().to_rfc3339()), ) .await { @@ -131,6 +116,43 @@ pub async fn process_task( } } +/// Build the `extra_fields` payload written alongside the `running` status +/// when a task starts executing. Pulled out so callers don't have to keep +/// the field set in lock-step with the consumer side. +pub(super) fn build_running_status_extra( + config: &WorkerConfig, + task_type: &str, + payload: &serde_json::Value, + started_at: &str, +) -> serde_json::Value { + serde_json::json!({ + "operation_id": config.operation_id, + "role": config.worker_role, + "agent_name": config.agent_name, + "pod_name": config.pod_name, + "task_type": task_type, + "payload": payload, + "started_at": started_at, + }) +} + +/// Build the `extra_fields` payload written when a task transitions to its +/// final status (`completed` / `failed`). +pub(super) fn build_final_status_extra( + config: &WorkerConfig, + task_type: &str, + ended_at: &str, +) -> serde_json::Value { + serde_json::json!({ + "operation_id": config.operation_id, + "role": config.worker_role, + "agent_name": config.agent_name, + "pod_name": config.pod_name, + "task_type": task_type, + "ended_at": ended_at, + }) +} + /// Build the final `TaskResult` and lifecycle status string from a single /// agent execution. 
Pulled out as a free function so the branching logic /// (success / agent-reported error / dispatch error) can be unit tested @@ -389,6 +411,86 @@ mod tests { assert_eq!(v["task_type"], "recon"); } + fn worker_config_for_test() -> WorkerConfig { + WorkerConfig { + redis_url: "redis://localhost".into(), + nats_url: "nats://localhost".into(), + worker_role: "recon".into(), + agent_name: "ares-recon-0".into(), + pod_name: "pod-0".into(), + operation_id: Some("op-2026".into()), + mode: crate::worker::config::WorkerMode::Task, + poll_timeout: std::time::Duration::from_secs(1), + task_timeout: std::time::Duration::from_secs(60), + heartbeat_interval: std::time::Duration::from_secs(15), + heartbeat_ttl: std::time::Duration::from_secs(60), + } + } + + #[test] + fn build_running_status_extra_includes_all_metadata() { + let cfg = worker_config_for_test(); + let payload = serde_json::json!({"target": "10.0.0.1"}); + let extra = build_running_status_extra(&cfg, "recon", &payload, "2026-04-29T20:00:00Z"); + assert_eq!(extra["operation_id"], "op-2026"); + assert_eq!(extra["role"], "recon"); + assert_eq!(extra["agent_name"], "ares-recon-0"); + assert_eq!(extra["pod_name"], "pod-0"); + assert_eq!(extra["task_type"], "recon"); + assert_eq!(extra["payload"]["target"], "10.0.0.1"); + assert_eq!(extra["started_at"], "2026-04-29T20:00:00Z"); + assert!(extra.get("ended_at").is_none()); + } + + #[test] + fn build_running_status_extra_handles_missing_operation_id() { + let mut cfg = worker_config_for_test(); + cfg.operation_id = None; + let extra = build_running_status_extra( + &cfg, + "lateral", + &serde_json::json!({}), + "2026-04-29T20:00:00Z", + ); + assert!(extra["operation_id"].is_null()); + assert_eq!(extra["task_type"], "lateral"); + } + + #[test] + fn build_final_status_extra_omits_payload_and_started_at() { + let cfg = worker_config_for_test(); + let extra = build_final_status_extra(&cfg, "recon", "2026-04-29T20:05:00Z"); + assert_eq!(extra["operation_id"], "op-2026"); + 
assert_eq!(extra["role"], "recon"); + assert_eq!(extra["agent_name"], "ares-recon-0"); + assert_eq!(extra["pod_name"], "pod-0"); + assert_eq!(extra["task_type"], "recon"); + assert_eq!(extra["ended_at"], "2026-04-29T20:05:00Z"); + assert!(extra.get("payload").is_none()); + assert!(extra.get("started_at").is_none()); + } + + #[test] + fn running_and_final_extra_share_metadata_keys() { + let cfg = worker_config_for_test(); + let r = build_running_status_extra( + &cfg, + "recon", + &serde_json::json!({}), + "2026-04-29T20:00:00Z", + ); + let f = build_final_status_extra(&cfg, "recon", "2026-04-29T20:05:00Z"); + for k in [ + "operation_id", + "role", + "agent_name", + "pod_name", + "task_type", + ] { + assert_eq!(r[k], f[k], "key {k} should match between running and final"); + } + } + #[tokio::test] async fn set_task_status_handles_non_object_extra() { let mut conn = MockRedisConnection::new(); diff --git a/ares-cli/src/worker/tool_executor.rs b/ares-cli/src/worker/tool_executor.rs index 0f8da9dc..0c3a27bb 100644 --- a/ares-cli/src/worker/tool_executor.rs +++ b/ares-cli/src/worker/tool_executor.rs @@ -115,7 +115,7 @@ pub async fn run_tool_exec_loop( let _ = status_tx.send(WorkerStatus { status: "busy".to_string(), - current_task: Some(format!("{}:{}", request.tool_name, request.call_id)), + current_task: Some(busy_current_task(&request.tool_name, &request.call_id)), }); let ti = extract_target_info(&request.arguments); @@ -193,6 +193,29 @@ fn tool_exit_error(exit_code: Option) -> String { format!("tool exited with code {exit_code:?}") } +/// Build the `WorkerStatus.current_task` string used while a tool call is in +/// flight. Pulled out so the field shape stays in lock-step with consumers +/// that key off `tool_name:call_id`. +fn busy_current_task(tool_name: &str, call_id: &str) -> String { + format!("{tool_name}:{call_id}") +} + +/// Iterate a `discoveries` value and return `(disc_type, count)` for each +/// non-empty array. 
Used by the executor to emit one `trace_discovery` span +/// per non-empty discovery type. Pulled out as a free function so the +/// counting logic can be unit-tested without spinning up a tracer. +fn count_discovery_entries(discoveries: &serde_json::Value) -> Vec<(String, usize)> { + let Some(obj) = discoveries.as_object() else { + return Vec::new(); + }; + obj.iter() + .filter_map(|(disc_type, items)| { + let count = items.as_array().map(|a| a.len()).unwrap_or(0); + (count > 0).then(|| (disc_type.clone(), count)) + }) + .collect() +} + /// Build the success-path [`ToolExecResponse`] (output + discoveries + error /// derived from the process exit status). Pulled out so the response shape /// can be unit-tested without spawning a tool subprocess. @@ -269,23 +292,18 @@ async fn execute_and_respond( )); if let Some(ref disc) = discoveries { - if let Some(obj) = disc.as_object() { - for (disc_type, items) in obj { - let count = items.as_array().map(|a| a.len()).unwrap_or(0); - if count > 0 { - let span = trace_discovery( - disc_type, - &request.tool_name, - di.target_user.as_deref(), - None, - di.target_ip.as_deref(), - di.target_fqdn.as_deref(), - dt, - request.operation_id.as_deref(), - ); - let _guard = span.enter(); - } - } + for (disc_type, _count) in count_discovery_entries(disc) { + let span = trace_discovery( + &disc_type, + &request.tool_name, + di.target_user.as_deref(), + None, + di.target_ip.as_deref(), + di.target_fqdn.as_deref(), + dt, + request.operation_id.as_deref(), + ); + let _guard = span.enter(); } } @@ -729,6 +747,67 @@ mod tests { assert!(json.contains("bad")); } + #[test] + fn busy_current_task_uses_colon_delimiter() { + assert_eq!( + busy_current_task("nmap_scan", "nmap_scan_abc123"), + "nmap_scan:nmap_scan_abc123" + ); + } + + #[test] + fn busy_current_task_handles_empty_call_id() { + // We never expect an empty call_id, but the format should be defensive + assert_eq!(busy_current_task("whoami", ""), "whoami:"); + } + + #[test] + fn 
count_discovery_entries_returns_per_type_counts() { + let discoveries = serde_json::json!({ + "hosts": [{"ip": "10.0.0.1"}, {"ip": "10.0.0.2"}], + "credentials": [{"username": "alice"}], + }); + let mut entries = count_discovery_entries(&discoveries); + entries.sort_by(|a, b| a.0.cmp(&b.0)); + assert_eq!( + entries, + vec![("credentials".to_string(), 1), ("hosts".to_string(), 2)], + ); + } + + #[test] + fn count_discovery_entries_skips_empty_arrays() { + let discoveries = serde_json::json!({ + "hosts": [], + "credentials": [{"username": "alice"}], + }); + let entries = count_discovery_entries(&discoveries); + assert_eq!(entries, vec![("credentials".to_string(), 1)]); + } + + #[test] + fn count_discovery_entries_skips_non_array_fields() { + let discoveries = serde_json::json!({ + "hosts": "not-an-array", + "credentials": [{"username": "alice"}], + }); + let entries = count_discovery_entries(&discoveries); + assert_eq!(entries, vec![("credentials".to_string(), 1)]); + } + + #[test] + fn count_discovery_entries_returns_empty_for_non_object() { + assert!(count_discovery_entries(&serde_json::json!([])).is_empty()); + assert!(count_discovery_entries(&serde_json::json!("hi")).is_empty()); + assert!(count_discovery_entries(&serde_json::json!(42)).is_empty()); + assert!(count_discovery_entries(&serde_json::json!(null)).is_empty()); + } + + #[test] + fn count_discovery_entries_returns_empty_for_empty_object() { + assert!(count_discovery_entries(&serde_json::json!({})).is_empty()); + } + #[test] fn build_success_and_error_responses_share_call_id_field() { let s = build_success_response("xyz", true, Some(0), "ok".into(), None);