Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 62 additions & 32 deletions astrbot/core/astr_main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
TOOL_CALL_PROMPT_SKILLS_LIKE_MODE,
)
from astrbot.core.conversation_mgr import Conversation
from astrbot.core.db import BaseDatabase
from astrbot.core.message.components import File, Image, Record, Reply, Video
from astrbot.core.persona_error_reply import (
extract_persona_custom_error_message_from_persona,
Expand Down Expand Up @@ -73,7 +74,6 @@
RollbackSkillReleaseTool,
RunBrowserSkillTool,
SyncSkillReleaseTool,
normalize_umo_for_workspace,
)
from astrbot.core.tools.cron_tools import FutureTaskTool
from astrbot.core.tools.knowledge_base_tools import (
Expand Down Expand Up @@ -115,6 +115,10 @@
extract_quoted_message_text,
)
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
from astrbot.core.workspace import (
normalize_umo_for_workspace,
resolve_workspace_root_for_umo,
)

LLM_ERROR_MESSAGE_EXTRA_KEY = "_llm_error_message"
WEEKDAY_NAMES = (
Expand Down Expand Up @@ -357,41 +361,63 @@ def _apply_prompt_prefix(req: ProviderRequest, cfg: dict) -> None:
req.prompt = f"{prefix}{req.prompt}"


def _get_workspace_path_for_umo(umo: str) -> Path:
normalized_umo = normalize_umo_for_workspace(umo)
return Path(get_astrbot_workspaces_path()) / normalized_umo
async def _get_workspace_path_for_umo(umo: str, plugin_context: Context) -> Path:
"""Resolve the workspace path for the current request.

Args:
umo: Unified message origin.
plugin_context: Star context containing the database instance.

Returns:
Workspace path used as cwd.
"""
fallback_root = (
Path(get_astrbot_workspaces_path()) / normalize_umo_for_workspace(umo)
).resolve(strict=False)
db = getattr(plugin_context, "_db", None)
if not isinstance(db, BaseDatabase):
return fallback_root
try:
return await resolve_workspace_root_for_umo(umo, db)
except Exception:
return fallback_root


def _apply_workspace_extra_prompt(
async def _apply_workspace_extra_prompt(
event: AstrMessageEvent,
req: ProviderRequest,
plugin_context: Context,
) -> None:
extra_prompt_path = _get_workspace_path_for_umo(event.unified_msg_origin) / (
"EXTRA_PROMPT.md"
workspace_root = await _get_workspace_path_for_umo(
event.unified_msg_origin,
plugin_context,
)
if not extra_prompt_path.is_file():
return

try:
extra_prompt = extra_prompt_path.read_text(encoding="utf-8").strip()
except Exception as exc: # noqa: BLE001
logger.warning(
"Failed to read workspace extra prompt for umo=%s from %s: %s",
event.unified_msg_origin,
extra_prompt_path,
exc,
)
return
extra_prompts: list[str] = []
extra_prompt_path = workspace_root / "EXTRA_PROMPT.md"
if extra_prompt_path.is_file():
try:
extra_prompt = extra_prompt_path.read_text(encoding="utf-8").strip()
except Exception as exc: # noqa: BLE001
logger.warning(
"Failed to read workspace extra prompt for umo=%s from %s: %s",
event.unified_msg_origin,
extra_prompt_path,
exc,
)
else:
if extra_prompt:
extra_prompts.append(f"From `{extra_prompt_path}`:\n{extra_prompt}")

if not extra_prompt:
if not extra_prompts:
return

extra_prompt_text = "\n\n".join(extra_prompts)
req.system_prompt = (
f"{req.system_prompt or ''}\n"
"[Workspace Extra Prompt]\n"
"The following instructions are loaded from the current workspace "
"`EXTRA_PROMPT.md` file.\n"
f"{extra_prompt}\n"
f"{extra_prompt_text}\n"
)


Expand Down Expand Up @@ -498,13 +524,13 @@ async def _ensure_persona_and_skills(
skill_manager = SkillManager()
skills = skill_manager.list_skills(active_only=True, runtime=runtime)
skills = _filter_skills_for_current_config(skills, cfg)
workspace_skills = (
skill_manager.list_workspace_skills(
_get_workspace_path_for_umo(event.unified_msg_origin)
workspace_skills: list[SkillInfo] = []
if runtime == "local":
workspace_root = await _get_workspace_path_for_umo(
event.unified_msg_origin,
plugin_context,
)
if runtime == "local"
else []
)
workspace_skills.extend(skill_manager.list_workspace_skills(workspace_root))

if skills or workspace_skills:
if persona and persona.get("skills") is not None:
Expand Down Expand Up @@ -989,7 +1015,7 @@ async def _decorate_llm_request(
if tz is None:
tz = plugin_context.get_config().get("timezone")
_append_system_reminders(event, req, cfg, tz)
_apply_workspace_extra_prompt(event, req)
await _apply_workspace_extra_prompt(event, req, plugin_context)


def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
Expand Down Expand Up @@ -1590,10 +1616,14 @@ async def build_main_agent(
)

if config.computer_use_runtime == "local":
workspace_root = await _get_workspace_path_for_umo(
event.unified_msg_origin,
plugin_context,
)
workspace_prompt = f"\nCurrent workspace you can use: `{workspace_root}`\n"
tool_prompt += (
f"\nCurrent workspace you can use: "
f"`{_get_workspace_path_for_umo(event.unified_msg_origin)}`\n"
"Unless the user explicitly specifies a different directory, "
workspace_prompt
+ "Unless the user explicitly specifies a different directory, "
"perform all file-related operations in this workspace.\n"
)

Expand Down
4 changes: 4 additions & 0 deletions astrbot/core/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,8 @@ async def create_chatui_project(
title: str,
emoji: str | None = "📁",
description: str | None = None,
workspace_type: str = "session",
workspace_path: str | None = None,
) -> ChatUIProject:
"""Create a new ChatUI project."""
...
Expand All @@ -877,6 +879,8 @@ async def update_chatui_project(
title: str | None = None,
emoji: str | None = None,
description: str | None = None,
workspace_type: str | None = None,
workspace_path: str | None = None,
) -> None:
"""Update a ChatUI project."""
...
Expand Down
4 changes: 4 additions & 0 deletions astrbot/core/db/po.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,10 @@ class ChatUIProject(TimestampMixin, SQLModel, table=True):
"""Title of the project"""
description: str | None = Field(default=None, max_length=1000)
"""Description of the project"""
workspace_type: str = Field(default="session", nullable=False, max_length=32)
"""Workspace mode: session, project, or custom"""
workspace_path: str | None = Field(default=None, max_length=1024)
"""Custom workspace path"""

__table_args__ = (
UniqueConstraint(
Expand Down
27 changes: 27 additions & 0 deletions astrbot/core/db/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ async def initialize(self) -> None:
await self._ensure_persona_skills_column(conn)
await self._ensure_persona_custom_error_message_column(conn)
await self._ensure_platform_message_history_checkpoint_column(conn)
await self._ensure_chatui_project_workspace_columns(conn)
await conn.commit()

async def _ensure_persona_folder_columns(self, conn) -> None:
Expand Down Expand Up @@ -128,6 +129,23 @@ async def _ensure_platform_message_history_checkpoint_column(self, conn) -> None
)
)

async def _ensure_chatui_project_workspace_columns(self, conn) -> None:
"""Ensure chatui_projects has workspace configuration columns."""
result = await conn.execute(text("PRAGMA table_info(chatui_projects)"))
columns = {row[1] for row in result.fetchall()}

if "workspace_type" not in columns:
await conn.execute(
text(
"ALTER TABLE chatui_projects "
"ADD COLUMN workspace_type VARCHAR(32) NOT NULL DEFAULT 'session'"
)
)
if "workspace_path" not in columns:
await conn.execute(
text("ALTER TABLE chatui_projects ADD COLUMN workspace_path VARCHAR")
)

# ====
# Platform Statistics
# ====
Expand Down Expand Up @@ -1877,6 +1895,8 @@ async def create_chatui_project(
title: str,
emoji: str | None = "📁",
description: str | None = None,
workspace_type: str = "session",
workspace_path: str | None = None,
) -> ChatUIProject:
"""Create a new ChatUI project."""
async with self.get_db() as session:
Expand All @@ -1887,6 +1907,8 @@ async def create_chatui_project(
title=title,
emoji=emoji,
description=description,
workspace_type=workspace_type,
workspace_path=workspace_path,
)
session.add(project)
await session.flush()
Expand Down Expand Up @@ -1929,6 +1951,8 @@ async def update_chatui_project(
title: str | None = None,
emoji: str | None = None,
description: str | None = None,
workspace_type: str | None = None,
workspace_path: str | None = None,
) -> None:
"""Update a ChatUI project."""
async with self.get_db() as session:
Expand All @@ -1941,6 +1965,9 @@ async def update_chatui_project(
values["emoji"] = emoji
if description is not None:
values["description"] = description
if workspace_type is not None:
values["workspace_type"] = workspace_type
values["workspace_path"] = workspace_path

await session.execute(
update(ChatUIProject)
Expand Down
Loading
Loading