Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ sql_app.db
uploads
pytorch_connectomics
server_api/chatbot/faiss_index/
.logs/
2 changes: 1 addition & 1 deletion pytorch_connectomics
Submodule pytorch_connectomics updated from 0a0dce to 20ccfd
132 changes: 132 additions & 0 deletions server_api/chatbot/agent_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""
Standalone Agent Test Script
Tests the multi-agent chatbot system without starting the full app.
Shows all RAG retrievals, tool calls, and final responses.

Usage:
python agent_cli.py # interactive mode
python agent_cli.py -b # batch: run every prompt in BATCH_QUESTIONS
python agent_cli.py "your question" # single question
"""

import os
import sys
import time
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from server_api.chatbot.chatbot import build_chain


# ── Failed questions from 40-question test ──────────────────────────────────

# Prompts replayed by the batch runner. The comment above each entry records
# the original test number and the failure that was observed for it.
BATCH_QUESTIONS = [
    # Test #10 - Fabricated CLI flags --batch-size, --checkpoint-interval
    (
        "Give me the command to train on CREMI with batch size 2 "
        "and save checkpoints every 5000 iterations"
    ),
    # Test #14 - Didn't override scheduler explicitly
    (
        "Train on MitoEM with the WarmupCosineLR scheduler "
        "and a base learning rate of 0.002"
    ),
    # Test #17 - Wrong override format --inference.AUG_NUM=8
    (
        "Generate an inference command for CREMI. "
        "Use configs/CREMI/CREMI-Base.yaml and checkpoint "
        "outputs/CREMI/checkpoint_100000.pth.tar with 8 TTA augmented views"
    ),
    # Test #32 - Fabricated scripts/evaluate.py
    "How do I evaluate synapse detection results for the CREMI challenge?",
]


def run_batch():
    """Run every prompt in ``BATCH_QUESTIONS`` and print each response.

    The agent is built once with ``build_chain()`` and reused for all
    questions. The search counter returned by ``build_chain()`` is reset
    before each question. An exception raised while answering one question
    is printed as ``[ERROR] ...`` and the batch continues with the next one.
    """
    print("Building agent (one-time)...")
    agent, reset_search_counter = build_chain()
    total = len(BATCH_QUESTIONS)
    print(f"Running {total} tests...\n")

    for i, q in enumerate(BATCH_QUESTIONS, 1):
        reset_search_counter()

        print(f"\n{'='*80}")
        print(f"TEST {i}/{total}")
        print(f"Q: {q}")
        print(f"{'='*80}\n")

        # perf_counter is monotonic, so the reported duration cannot be
        # skewed by system clock adjustments during a long agent call.
        t0 = time.perf_counter()
        try:
            result = agent.invoke({"messages": [("user", q)]})
            response = result["messages"][-1].content
        except Exception as e:
            # Keep going: one failing question should not abort the batch.
            response = f"[ERROR] {e}"
        elapsed = time.perf_counter() - t0

        print(f"\n{'─'*80}")
        print("RESPONSE:")
        print(f"{'─'*80}")
        print(response)
        print(f"\n({elapsed:.1f}s)")

    print(f"\n{'#'*80}")
    print(f"BATCH COMPLETE — {total} questions answered")
    print(f"{'#'*80}")


def run_single(question: str):
    """Send one question to a freshly built agent and print its final reply.

    Args:
        question: The user question forwarded to the agent.
    """
    banner = "=" * 80
    rule = "─" * 80

    # Echo the question before the (potentially slow) agent build starts.
    print(f"\n{banner}")
    print(f"QUESTION: {question}")
    print(f"{banner}\n")

    agent, reset_search_counter = build_chain()
    reset_search_counter()
    outcome = agent.invoke({"messages": [("user", question)]})
    # The last message in the returned list carries the final answer.
    answer = outcome["messages"][-1].content

    print(f"\n{rule}")
    print("FINAL RESPONSE:")
    print(rule)
    print(answer)
    print(f"\n{banner}\n")


def interactive_mode():
    """Read questions from stdin in a loop and print the agent's answers.

    The agent is built once with ``build_chain()``. The loop ends when the
    user types ``quit``, ``exit`` or ``q``, presses Ctrl-C, or stdin reaches
    end-of-file (Ctrl-D or exhausted piped input). Blank lines are ignored.
    An error while answering a question is printed with its traceback and
    the loop continues.
    """
    import traceback

    print("\n" + "="*80)
    print("INTERACTIVE AGENT TEST MODE")
    print("="*80)
    print("Type your questions to test the agent.")
    print("Type 'quit' or 'exit' to stop.\n")
    agent, reset_search_counter = build_chain()

    while True:
        try:
            question = input("\nYour question: ").strip()
        except (KeyboardInterrupt, EOFError):
            # EOF must end the session: once stdin is exhausted every further
            # input() call raises EOFError again, so retrying would spin
            # forever printing tracebacks.
            break

        if question.lower() in ('quit', 'exit', 'q'):
            break
        if not question:
            continue

        try:
            reset_search_counter()
            result = agent.invoke({"messages": [("user", question)]})
            response = result["messages"][-1].content
            print(f"\n{'─'*60}")
            print(response)
            print(f"{'─'*60}")
        except KeyboardInterrupt:
            # Ctrl-C while the agent is working also ends the session.
            break
        except Exception as e:
            # Report the failure but keep the session alive for the next
            # question.
            print(f"\nError: {e}")
            traceback.print_exc()


if __name__ == "__main__":
    import argparse

    # CLI entry point. Precedence: --batch, then --interactive, then a
    # positional question; with no arguments at all, fall back to the
    # interactive prompt.
    parser = argparse.ArgumentParser(description="Test the chatbot agent")
    parser.add_argument(
        "-b",
        "--batch",
        action="store_true",
        help="Run the batch test over every prompt in BATCH_QUESTIONS",
    )
    parser.add_argument(
        "-i",
        "--interactive",
        action="store_true",
        help="Interactive mode",
    )
    parser.add_argument("question", nargs="*", help="Single question to test")
    args = parser.parse_args()

    if args.batch:
        run_batch()
    elif args.interactive:
        interactive_mode()
    elif args.question:
        # Unquoted multi-word questions arrive as separate argv entries.
        run_single(" ".join(args.question))
    else:
        interactive_mode()
145 changes: 75 additions & 70 deletions server_api/chatbot/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,69 +23,62 @@

TRAINING_AGENT_PROMPT = """You are a **Training Agent** for PyTorch Connectomics.

You help users set up and configure training jobs for biomedical image segmentation.

CRITICAL RULES:
1. **Only report values that your tools return.** Do NOT invent hyperparameter values, config names, or file paths.
2. **Always use tools before answering.** Call list_training_configs or read_config first — never guess.
3. **Be concise.** Report the facts, generate the command, and stop.
RULES:
1. Only report values that your tools return. Do NOT invent config names, paths, or settings.
2. Never tell the user to write a YAML from scratch. Always start from an existing config.
3. If the task is unsupported, say so. PyTC only does segmentation.
4. Be concise. State the facts, generate the command, stop.

Tools:
- list_training_configs: List available config files with descriptions
- read_config: Read a config file to see its hyperparameters
WORKFLOW: The available configs are provided in your task message. Pick the best match, then:
1. Call read_config on the chosen config path to see its YAML overrides.
2. For common parameters (learning rate, batch size, iterations, optimizer, checkpoint interval), ALWAYS use the keys listed below. DO NOT search for these.
3. For specialized parameters (augmentation settings, loss functions, architecture details), call search_documentation.
4. Build the command with overrides using the SECTION.KEY=value format.

Workflow:
1. Use list_training_configs to find configs matching user's task
2. Use read_config to examine the config's current settings
3. Compare user requirements with config defaults
4. Generate the training command with appropriate overrides
IMPORTANT: YAML configs only show overrides — many valid keys exist in the defaults but are not shown in read_config output.

Command Format:
```
python scripts/main.py --config <config_path> [OVERRIDES]
```
Common override keys (ALWAYS use these exact keys, never search for alternatives):
- SOLVER.BASE_LR, SOLVER.SAMPLES_PER_BATCH, SOLVER.ITERATION_TOTAL
- SOLVER.ITERATION_SAVE (checkpoint save interval), SOLVER.ITERATION_STEP (LR decay steps)
- SOLVER.NAME (values: SGD, Adam, AdamW)
- SOLVER.LR_SCHEDULER_NAME (values: WarmupMultiStepLR, WarmupCosineLR)
- SOLVER.CLIP_GRADIENTS.ENABLED (True/False), SOLVER.CLIP_GRADIENTS.CLIP_VALUE
- MODEL.ARCHITECTURE, MODEL.BLOCK_TYPE, MODEL.FILTERS

Overrides use YAML key paths appended to the command: SECTION.KEY=value
Example:
```
python scripts/main.py --config configs/Lucchi-Mitochondria.yaml SOLVER.BASE_LR=0.001 SOLVER.SAMPLES_PER_BATCH=16
```
Use read_config output to determine the correct key paths for any parameter.
NEVER invent keys like TRAIN.MAX_ITER, TRAINING.BATCH_SIZE, or CLI flags like --batch-size, --checkpoint-interval — these do not exist.

Always generate commands for the user to run - never execute directly."""
Command format: `python scripts/main.py --config-file <path> [SECTION.KEY=value ...]`
Always generate commands for the user to run — never execute directly."""


INFERENCE_AGENT_PROMPT = """You are an **Inference Agent** for PyTorch Connectomics.

You help users run inference and evaluation with trained segmentation models.

CRITICAL RULES:
1. **Only report values that your tools return.** Do NOT invent checkpoint paths, config names, or settings.
2. **Always use tools before answering.** Call list_checkpoints or read_config first — never guess.
3. **Be concise.** Report the facts, generate the command, and stop.
RULES:
1. Only report values that your tools return. Do NOT invent checkpoint paths, config names, or settings.
2. Be concise. State the facts, generate the command, stop.

Tools:
- list_checkpoints: Find available trained model checkpoints
- read_config: Read config to find default inference settings
WORKFLOW:
1. If the user did NOT provide a checkpoint path, call list_checkpoints first to see available checkpoints.
2. If the user DID provide a checkpoint path (e.g., outputs/model/checkpoint.pth.tar), skip list_checkpoints.
3. Call read_config to see the INFERENCE section keys.
4. For specialized inference parameters, call search_documentation if needed.

Workflow:
1. Use list_checkpoints to find available models
2. Use read_config to check inference settings (INFERENCE section)
3. Generate the inference command
Here is the correct override key mapping (use these exact keys):
- Output path → INFERENCE.OUTPUT_PATH
- TTA augmentation count → INFERENCE.AUG_NUM
- TTA mode → INFERENCE.AUG_MODE (values: mean, max)
- Blending → INFERENCE.BLENDING (values: gaussian, bump)
- Stride → INFERENCE.STRIDE
- Process volumes one at a time → INFERENCE.DO_SINGLY
- Batch size → INFERENCE.SAMPLES_PER_BATCH

Command Format:
```
python scripts/main.py --config <config_path> --checkpoint <checkpoint_path> --inference [OVERRIDES]
```
Command format: `python scripts/main.py --config-file <path> --inference --checkpoint <ckpt> [SECTION.KEY=value ...]`

Overrides use YAML key paths appended to the command: SECTION.KEY=value
Example:
```
python scripts/main.py --config configs/Lucchi-Mitochondria.yaml --checkpoint outputs/Lucchi/checkpoint_100000.pth --inference INFERENCE.OUTPUT_PATH=/path/to/output
```
Use read_config output to determine the correct key paths for any parameter.
IMPORTANT: Overrides use SECTION.KEY=value format (NO -- prefix). Example:
✅ CORRECT: INFERENCE.AUG_NUM=8
❌ WRONG: --inference.AUG_NUM=8

Always generate commands for the user to run - never execute directly."""
Always generate commands for the user to run never execute directly."""


SUPERVISOR_PROMPT = """You are the **Supervisor Agent** for PyTorch Connectomics (PyTC Client).
Expand All @@ -94,6 +87,7 @@

ROUTING — decide which tool to use BEFORE calling anything:
- **UI, navigation, features, shortcuts, workflows, "how do I..." questions** → search_documentation
- **General PyTC questions** (what architectures are supported, what augmentations exist, what loss functions are available, etc.) → search_documentation
- **Generate a specific training/inference command** → delegate_to_training_agent or delegate_to_inference_agent
- **General/greeting/off-topic** → answer directly, no tool needed

Expand All @@ -103,18 +97,19 @@
3. **For application questions, ground answers in retrieved documentation.** Call search_documentation and base your answer on the returned text. Do NOT invent features, shortcuts, buttons, or workflows.
4. **Do not fabricate specifics.** Never make up keyboard shortcuts, button labels, or step-by-step instructions unless they come from retrieved docs or a sub-agent response.
4a. **NEVER use command-line instructions for UI questions.** The PyTC Client is a desktop GUI application. If the user asks how to do something, explain the UI workflow (buttons, tabs, forms) from the documentation. Do NOT provide Python scripts, bash commands, or CLI examples unless the sub-agent explicitly generates them.
4b. **NEVER fabricate file paths or scripts.** Do NOT invent scripts like `scripts/evaluate.py`, `scripts/resume_training.py`, or any other files that don't exist. If evaluation requires Python code, show inline code using `connectomics.utils.evaluate`, not fake script paths.
5. **Answer every part of the user's question.** If they ask about two things, address both.
6. **Use retrieved content even if wording differs.** If the documentation describes relevant features or workflows, use that information to answer the question. Don't claim something isn't documented just because it uses different terminology than the user's question.
7. **HARD LIMIT: You may call search_documentation EXACTLY 2 times per user question.** After the second call, you MUST answer with the information already retrieved. Do NOT attempt a third search. If the tool returns "Search limit reached", immediately stop and answer based on what you already have.
7. **HARD LIMIT: You may call search_documentation at most 3 times yourself.** Sub-agents also have access to search_documentation. If the tool returns "Search limit reached", immediately stop and answer based on what you already have.

Sub-agents:
- **Training Agent**: Config selection, training job setup, hyperparameter overrides
- **Inference Agent**: Checkpoint management, inference/evaluation commands

Tools:
- search_documentation: Search PyTC docs for UI guides and feature explanations. Use ONLY for questions about the application interface, pages, buttons, or workflows.
- delegate_to_training_agent: Send training-related tasks to training agent
- delegate_to_inference_agent: Send inference-related tasks to inference agent"""
- search_documentation: Search PyTC docs for UI guides, feature explanations, training/inference config references, model architectures, augmentation options, and bundled configs.
- delegate_to_training_agent: Send training-related tasks to training agent (config selection, command generation, hyperparameter tuning)
- delegate_to_inference_agent: Send inference-related tasks to inference agent (checkpoint listing, inference commands, evaluation setup)"""


def build_chain():
Expand Down Expand Up @@ -152,23 +147,11 @@ def build_chain():
def reset_search_counter():
_search_call_count[0] = 0

training_agent = create_agent(
model=llm,
tools=[list_training_configs, read_config],
system_prompt=TRAINING_AGENT_PROMPT,
)

inference_agent = create_agent(
model=llm,
tools=[list_checkpoints, read_config],
system_prompt=INFERENCE_AGENT_PROMPT,
)

@tool
def search_documentation(query: str) -> str:
"""
Search PyTC documentation for how-to guides, UI explanations, and feature descriptions.
Use this for questions about the application interface or general usage.
Search PyTC documentation for UI guides, feature descriptions, training/inference
configuration references, model architectures, augmentation options, and bundled configs.

Args:
query: The user's question
Expand All @@ -180,8 +163,8 @@ def search_documentation(query: str) -> str:
print(
f"[TOOL] search_documentation(query={query!r}) [call {_search_call_count[0]}]"
)
if _search_call_count[0] > 2:
print("[TOOL] search limit reached (max 2 per question)")
if _search_call_count[0] > 6:
print("[TOOL] search limit reached (max 6 per question)")
return "Search limit reached. Please answer based on the documentation already retrieved."

# Primary: FAISS semantic search (chunked embeddings)
Expand Down Expand Up @@ -213,6 +196,18 @@ def search_documentation(query: str) -> str:
print("[TOOL] search_documentation → no results")
return "No relevant documentation found."

training_agent = create_agent(
model=llm,
tools=[list_training_configs, read_config, search_documentation],
system_prompt=TRAINING_AGENT_PROMPT,
)

inference_agent = create_agent(
model=llm,
tools=[list_checkpoints, read_config, search_documentation],
system_prompt=INFERENCE_AGENT_PROMPT,
)

@tool
def delegate_to_training_agent(task: str) -> str:
"""
Expand All @@ -226,8 +221,18 @@ def delegate_to_training_agent(task: str) -> str:
Response from the training agent
"""
print(f"[TOOL] delegate_to_training_agent(task={task!r})")
# Auto-inject available configs so the agent doesn't need to call list_training_configs
configs = list_training_configs.invoke({})
config_summary = "\n".join(
f"- {c['name']} ({c['model']}) → {c['path']}" for c in configs if isinstance(c, dict) and 'name' in c
)
enriched_task = (
f"{task}\n\n"
f"AVAILABLE CONFIGS (already fetched for you):\n{config_summary}\n\n"
f"Pick the best match and call read_config on its path to see the exact YAML keys before generating the command."
)
result = training_agent.invoke(
{"messages": [{"role": "user", "content": task}]}
{"messages": [{"role": "user", "content": enriched_task}]}
)
messages = result.get("messages", [])
response = (
Expand Down
Loading
Loading