diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index 16ebac7a8b..1eae0f17e3 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -29,6 +29,7 @@ TOOL_CALL_PROMPT_SKILLS_LIKE_MODE, ) from astrbot.core.conversation_mgr import Conversation +from astrbot.core.db import BaseDatabase from astrbot.core.message.components import File, Image, Record, Reply, Video from astrbot.core.persona_error_reply import ( extract_persona_custom_error_message_from_persona, @@ -73,7 +74,6 @@ RollbackSkillReleaseTool, RunBrowserSkillTool, SyncSkillReleaseTool, - normalize_umo_for_workspace, ) from astrbot.core.tools.cron_tools import FutureTaskTool from astrbot.core.tools.knowledge_base_tools import ( @@ -115,6 +115,10 @@ extract_quoted_message_text, ) from astrbot.core.utils.string_utils import normalize_and_dedupe_strings +from astrbot.core.workspace import ( + normalize_umo_for_workspace, + resolve_workspace_root_for_umo, +) LLM_ERROR_MESSAGE_EXTRA_KEY = "_llm_error_message" WEEKDAY_NAMES = ( @@ -357,41 +361,63 @@ def _apply_prompt_prefix(req: ProviderRequest, cfg: dict) -> None: req.prompt = f"{prefix}{req.prompt}" -def _get_workspace_path_for_umo(umo: str) -> Path: - normalized_umo = normalize_umo_for_workspace(umo) - return Path(get_astrbot_workspaces_path()) / normalized_umo +async def _get_workspace_path_for_umo(umo: str, plugin_context: Context) -> Path: + """Resolve the workspace path for the current request. + + Args: + umo: Unified message origin. + plugin_context: Star context containing the database instance. + + Returns: + Workspace path used as cwd. + """ + fallback_root = ( + Path(get_astrbot_workspaces_path()) / normalize_umo_for_workspace(umo) + ).resolve(strict=False) + db = getattr(plugin_context, "_db", None) + if not isinstance(db, BaseDatabase): + return fallback_root + try: + return await resolve_workspace_root_for_umo(umo, db) + except Exception: + return fallback_root -def _apply_workspace_extra_prompt( +async def _apply_workspace_extra_prompt( event: AstrMessageEvent, req: ProviderRequest, + plugin_context: Context, ) -> None: - extra_prompt_path = _get_workspace_path_for_umo(event.unified_msg_origin) / ( - "EXTRA_PROMPT.md" + workspace_root = await _get_workspace_path_for_umo( + event.unified_msg_origin, + plugin_context, ) - if not extra_prompt_path.is_file(): - return - - try: - extra_prompt = extra_prompt_path.read_text(encoding="utf-8").strip() - except Exception as exc: # noqa: BLE001 - logger.warning( - "Failed to read workspace extra prompt for umo=%s from %s: %s", - event.unified_msg_origin, - extra_prompt_path, - exc, - ) - return + extra_prompts: list[str] = [] + extra_prompt_path = workspace_root / "EXTRA_PROMPT.md" + if extra_prompt_path.is_file(): + try: + extra_prompt = extra_prompt_path.read_text(encoding="utf-8").strip() + except Exception as exc: # noqa: BLE001 + logger.warning( + "Failed to read workspace extra prompt for umo=%s from %s: %s", + event.unified_msg_origin, + extra_prompt_path, + exc, + ) + else: + if extra_prompt: + extra_prompts.append(f"From `{extra_prompt_path}`:\n{extra_prompt}") - if not extra_prompt: + if not extra_prompts: return + extra_prompt_text = "\n\n".join(extra_prompts) req.system_prompt = ( f"{req.system_prompt or ''}\n" "[Workspace Extra Prompt]\n" "The following instructions are loaded from the current workspace " "`EXTRA_PROMPT.md` file.\n" - f"{extra_prompt}\n" + f"{extra_prompt_text}\n" ) @@ -498,13 +524,13 @@ async def _ensure_persona_and_skills( skill_manager = SkillManager() skills = skill_manager.list_skills(active_only=True, runtime=runtime) skills = _filter_skills_for_current_config(skills, cfg) - workspace_skills = ( - skill_manager.list_workspace_skills( - _get_workspace_path_for_umo(event.unified_msg_origin) + workspace_skills: list[SkillInfo] = [] + if runtime == "local": + workspace_root = await _get_workspace_path_for_umo( + event.unified_msg_origin, + plugin_context, ) - if runtime == "local" - else [] - ) + workspace_skills.extend(skill_manager.list_workspace_skills(workspace_root)) if skills or workspace_skills: if persona and persona.get("skills") is not None: @@ -989,7 +1015,7 @@ async def _decorate_llm_request( if tz is None: tz = plugin_context.get_config().get("timezone") _append_system_reminders(event, req, cfg, tz) - _apply_workspace_extra_prompt(event, req) + await _apply_workspace_extra_prompt(event, req, plugin_context) def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None: @@ -1590,10 +1616,14 @@ async def build_main_agent( ) if config.computer_use_runtime == "local": + workspace_root = await _get_workspace_path_for_umo( + event.unified_msg_origin, + plugin_context, + ) + workspace_prompt = f"\nCurrent workspace you can use: `{workspace_root}`\n" tool_prompt += ( - f"\nCurrent workspace you can use: " - f"`{_get_workspace_path_for_umo(event.unified_msg_origin)}`\n" - "Unless the user explicitly specifies a different directory, " + workspace_prompt + + "Unless the user explicitly specifies a different directory, " "perform all file-related operations in this workspace.\n" ) diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index cf37c48663..8e319bc529 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -851,6 +851,8 @@ async def create_chatui_project( title: str, emoji: str | None = "📁", description: str | None = None, + workspace_type: str = "session", + workspace_path: str | None = None, ) -> ChatUIProject: """Create a new ChatUI project.""" ... @@ -877,6 +879,8 @@ async def update_chatui_project( title: str | None = None, emoji: str | None = None, description: str | None = None, + workspace_type: str | None = None, + workspace_path: str | None = None, ) -> None: """Update a ChatUI project.""" ... diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 9a297b34da..5bfbc7d9e6 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -447,6 +447,10 @@ class ChatUIProject(TimestampMixin, SQLModel, table=True): """Title of the project""" description: str | None = Field(default=None, max_length=1000) """Description of the project""" + workspace_type: str = Field(default="session", nullable=False, max_length=32) + """Workspace mode: session, project, or custom""" + workspace_path: str | None = Field(default=None, max_length=1024) + """Custom workspace path""" __table_args__ = ( UniqueConstraint( diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index b7706cc513..f59f234192 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -64,6 +64,7 @@ async def initialize(self) -> None: await self._ensure_persona_skills_column(conn) await self._ensure_persona_custom_error_message_column(conn) await self._ensure_platform_message_history_checkpoint_column(conn) + await self._ensure_chatui_project_workspace_columns(conn) await conn.commit() async def _ensure_persona_folder_columns(self, conn) -> None: @@ -128,6 +129,23 @@ async def _ensure_platform_message_history_checkpoint_column(self, conn) -> None ) ) + async def _ensure_chatui_project_workspace_columns(self, conn) -> None: + """Ensure chatui_projects has workspace configuration columns.""" + result = await conn.execute(text("PRAGMA table_info(chatui_projects)")) + columns = {row[1] for row in result.fetchall()} + + if "workspace_type" not in columns: + await conn.execute( + text( + "ALTER TABLE chatui_projects " + "ADD COLUMN workspace_type VARCHAR(32) NOT NULL DEFAULT 'session'" + ) + ) + if "workspace_path" not in columns: + await conn.execute( + text("ALTER TABLE chatui_projects ADD COLUMN workspace_path VARCHAR") + ) + # ==== # Platform Statistics # ==== @@ -1877,6 +1895,8 @@ async def create_chatui_project( title: str, emoji: str | None = "📁", description: str | None = None, + workspace_type: str = "session", + workspace_path: str | None = None, ) -> ChatUIProject: """Create a new ChatUI project.""" async with self.get_db() as session: @@ -1887,6 +1907,8 @@ async def create_chatui_project( title=title, emoji=emoji, description=description, + workspace_type=workspace_type, + workspace_path=workspace_path, ) session.add(project) await session.flush() @@ -1929,6 +1951,8 @@ async def update_chatui_project( title: str | None = None, emoji: str | None = None, description: str | None = None, + workspace_type: str | None = None, + workspace_path: str | None = None, ) -> None: """Update a ChatUI project.""" async with self.get_db() as session: @@ -1941,6 +1965,9 @@ async def update_chatui_project( values["emoji"] = emoji if description is not None: values["description"] = description + if workspace_type is not None: + values["workspace_type"] = workspace_type + values["workspace_path"] = workspace_path await session.execute( update(ChatUIProject) diff --git a/astrbot/core/tools/computer_tools/fs.py b/astrbot/core/tools/computer_tools/fs.py index c3236c8a18..cac1131f86 100644 --- a/astrbot/core/tools/computer_tools/fs.py +++ b/astrbot/core/tools/computer_tools/fs.py @@ -14,7 +14,7 @@ implement them and the main agent does not expose them in local mode. - Member + local: read/grep are restricted to `data/skills`, plugin-provided `data/plugins/*/skills`, - `data/workspaces/{normalized_umo}`, and `/tmp/.astrbot`; write/edit are + the current session or project workspace, and `/tmp/.astrbot`; write/edit are restricted to the same local roots except plugin-provided Skills, which are read-only. Upload/download are denied by `check_admin_permission` if invoked. - Admin + sandbox: read/write/edit/grep are not path-restricted by this @@ -28,8 +28,7 @@ admin behavior. Local path resolution rule: -- In local runtime, relative paths are resolved under - `data/workspaces/{normalized_umo}`. +- In local runtime, relative paths are resolved under the primary workspace. - In sandbox runtime, relative paths are passed through unchanged. """ @@ -60,6 +59,7 @@ check_admin_permission, is_local_runtime, normalize_umo_for_workspace, + workspace_root_for_context, ) _COMPUTER_RUNTIME_TOOL_CONFIG = { @@ -77,9 +77,13 @@ def _remote_basename(path: str) -> str: return path.replace("\\", "/").rstrip("/").split("/")[-1] -def _restricted_env_path_labels(umo: str, *, include_plugin_skills: bool) -> list[str]: +def _restricted_env_path_labels( + umo: str, + *, + include_plugin_skills: bool, + current_workspace_root: Path | None = None, +) -> list[str]: """Labels for the allowed directories in a local(not sandbox) and restricted(not admin) environment""" - normalized_umo = normalize_umo_for_workspace(umo) labels = [ "data/skills", ] @@ -87,7 +91,7 @@ def _restricted_env_path_labels(umo: str, *, include_plugin_skills: bool) -> lis labels.append("data/plugins/*/skills") labels.extend( [ - f"data/workspaces/{normalized_umo}", + str(current_workspace_root or _workspace_root(umo)), get_astrbot_system_tmp_path(), get_astrbot_temp_path(), ] @@ -117,22 +121,28 @@ def _plugin_skill_roots() -> tuple[Path, ...]: ) -def _read_allowed_roots(umo: str) -> tuple[Path, ...]: +def _read_allowed_roots( + umo: str, + current_workspace_root: Path | None = None, +) -> tuple[Path, ...]: """Non-admin users can only read files within these directories (and their subdirectories)""" return ( Path(get_astrbot_skills_path()).resolve(strict=False), *_plugin_skill_roots(), - _workspace_root(umo), + current_workspace_root or _workspace_root(umo), Path(get_astrbot_system_tmp_path()).resolve(strict=False), Path(get_astrbot_temp_path()).resolve(strict=False), ) -def _write_allowed_roots(umo: str) -> tuple[Path, ...]: +def _write_allowed_roots( + umo: str, + current_workspace_root: Path | None = None, +) -> tuple[Path, ...]: """Non-admin users cannot modify plugin-provided Skills.""" return ( Path(get_astrbot_skills_path()).resolve(strict=False), - _workspace_root(umo), + current_workspace_root or _workspace_root(umo), Path(get_astrbot_system_tmp_path()).resolve(strict=False), Path(get_astrbot_temp_path()).resolve(strict=False), ) @@ -149,7 +159,13 @@ def _is_restricted_env(context: ContextWrapper[AstrAgentContext]) -> bool: return require_admin and context.context.event.role != "admin" -def _resolve_tool_path(path: str, *, local_env: bool, umo: str) -> str: +def _resolve_tool_path( + path: str, + *, + local_env: bool, + umo: str, + current_workspace_root: Path | None = None, +) -> str: normalized_path = path.strip() if not normalized_path: return normalized_path @@ -157,16 +173,28 @@ def _resolve_tool_path(path: str, *, local_env: bool, umo: str) -> str: if candidate.is_absolute(): return str(candidate.resolve(strict=False)) if local_env: - return str((_workspace_root(umo) / candidate).resolve(strict=False)) + return str( + ((current_workspace_root or _workspace_root(umo)) / candidate).resolve( + strict=False + ) + ) return normalized_path -def _resolve_user_path(path: str, *, local_env: bool, umo: str) -> Path: +def _resolve_user_path( + path: str, + *, + local_env: bool, + umo: str, + current_workspace_root: Path | None = None, +) -> Path: candidate = Path(path).expanduser() if candidate.is_absolute(): return candidate.resolve(strict=False) if local_env: - return (_workspace_root(umo) / candidate).resolve(strict=False) + return ((current_workspace_root or _workspace_root(umo)) / candidate).resolve( + strict=False + ) return (Path.cwd() / candidate).resolve(strict=False) @@ -175,8 +203,14 @@ def _is_path_within_allowed_roots( *, umo: str, allowed_roots: tuple[Path, ...], + current_workspace_root: Path | None = None, ) -> bool: - resolved = _resolve_user_path(path, local_env=True, umo=umo) + resolved = _resolve_user_path( + path, + local_env=True, + umo=umo, + current_workspace_root=current_workspace_root, + ) return any( resolved == allowed_root or resolved.is_relative_to(allowed_root) for allowed_root in allowed_roots @@ -209,19 +243,34 @@ def _normalize_rw_path( local_env: bool, umo: str, write: bool = False, + current_workspace_root: Path | None = None, ) -> str: - normalized_path = _resolve_tool_path(path, local_env=local_env, umo=umo) + normalized_path = _resolve_tool_path( + path, + local_env=local_env, + umo=umo, + current_workspace_root=current_workspace_root, + ) if not normalized_path: raise ValueError("`path` must be a non-empty string.") if restricted: - allowed_roots = _write_allowed_roots(umo) if write else _read_allowed_roots(umo) + allowed_roots = ( + _write_allowed_roots(umo, current_workspace_root) + if write + else _read_allowed_roots(umo, current_workspace_root) + ) if restricted and not _is_path_within_allowed_roots( normalized_path, umo=umo, allowed_roots=allowed_roots, + current_workspace_root=current_workspace_root, ): allowed = ", ".join( - _restricted_env_path_labels(umo, include_plugin_skills=not write) + _restricted_env_path_labels( + umo, + include_plugin_skills=not write, + current_workspace_root=current_workspace_root, + ) ) access = "Write" if write else "Read" raise PermissionError( @@ -291,6 +340,9 @@ async def call( ) -> ToolExecResult: local_env = is_local_runtime(context) restricted = _is_restricted_env(context) + current_workspace_root = ( + await workspace_root_for_context(context) if local_env else None + ) try: normalized_path = ( _normalize_rw_path( @@ -298,6 +350,7 @@ async def call( restricted=restricted, local_env=local_env, umo=context.context.event.unified_msg_origin, + current_workspace_root=current_workspace_root, ) if local_env else path.strip() @@ -316,7 +369,10 @@ async def call( offset=offset, limit=limit, workspace_dir=( - str(_workspace_root(context.context.event.unified_msg_origin)) + str( + current_workspace_root + or _workspace_root(context.context.event.unified_msg_origin) + ) if local_env else None ), @@ -358,6 +414,9 @@ async def call( ) -> ToolExecResult: local_env = is_local_runtime(context) restricted = _is_restricted_env(context) + current_workspace_root = ( + await workspace_root_for_context(context) if local_env else None + ) try: normalized_path = ( _normalize_rw_path( @@ -366,6 +425,7 @@ async def call( local_env=local_env, umo=context.context.event.unified_msg_origin, write=True, + current_workspace_root=current_workspace_root, ) if local_env else path.strip() @@ -437,6 +497,9 @@ async def call( umo = str(context.context.event.unified_msg_origin) local_env = is_local_runtime(context) restricted = _is_restricted_env(context) + current_workspace_root = ( + await workspace_root_for_context(context) if local_env else None + ) try: normalized_path = ( _normalize_rw_path( @@ -445,6 +508,7 @@ async def call( local_env=local_env, umo=umo, write=True, + current_workspace_root=current_workspace_root, ) if local_env else path.strip() @@ -594,15 +658,28 @@ def _normalize_search_paths( restricted: bool, local_env: bool, umo: str, + current_workspace_root: Path | None = None, ) -> list[str]: normalized = ( - [_resolve_tool_path(path, local_env=local_env, umo=umo)] if path else [] + [ + _resolve_tool_path( + path, + local_env=local_env, + umo=umo, + current_workspace_root=current_workspace_root, + ) + ] + if path + else [] ) if not normalized: if restricted: - return [str(root) for root in _read_allowed_roots(umo)] + return [ + str(root) + for root in _read_allowed_roots(umo, current_workspace_root) + ] if local_env: - return [str(_workspace_root(umo))] + return [str(current_workspace_root or _workspace_root(umo))] return ["."] if restricted: @@ -612,12 +689,17 @@ def _normalize_search_paths( if not _is_path_within_allowed_roots( path, umo=umo, - allowed_roots=_read_allowed_roots(umo), + allowed_roots=_read_allowed_roots(umo, current_workspace_root), + current_workspace_root=current_workspace_root, ) ] if disallowed: allowed = ", ".join( - _restricted_env_path_labels(umo, include_plugin_skills=True) + _restricted_env_path_labels( + umo, + include_plugin_skills=True, + current_workspace_root=current_workspace_root, + ) ) blocked = ", ".join(disallowed) raise PermissionError( @@ -644,6 +726,9 @@ async def call( local_env = is_local_runtime(context) restricted = _is_restricted_env(context) + current_workspace_root = ( + await workspace_root_for_context(context) if local_env else None + ) try: search_paths = ( self._normalize_search_paths( @@ -651,6 +736,7 @@ async def call( restricted=restricted, local_env=local_env, umo=context.context.event.unified_msg_origin, + current_workspace_root=current_workspace_root, ) if local_env else ([path.strip()] if path and path.strip() else ["."]) diff --git a/astrbot/core/tools/computer_tools/python.py b/astrbot/core/tools/computer_tools/python.py index f9500ff7e8..b51c225d97 100644 --- a/astrbot/core/tools/computer_tools/python.py +++ b/astrbot/core/tools/computer_tools/python.py @@ -11,7 +11,10 @@ from astrbot.core.message.message_event_result import MessageChain from ..registry import builtin_tool -from .util import check_admin_permission, workspace_root +from .util import ( + check_admin_permission, + workspace_root_for_context, +) _OS_NAME = platform.system() _SANDBOX_PYTHON_TOOL_CONFIG = { @@ -137,9 +140,7 @@ async def call( else context.tool_call_timeout ) try: - current_workspace_root = workspace_root( - context.context.event.unified_msg_origin - ) + current_workspace_root = await workspace_root_for_context(context) current_workspace_root.mkdir(parents=True, exist_ok=True) result = await sb.python.exec( code, diff --git a/astrbot/core/tools/computer_tools/shell.py b/astrbot/core/tools/computer_tools/shell.py index 1e1acfbf9a..88d1d69e4c 100644 --- a/astrbot/core/tools/computer_tools/shell.py +++ b/astrbot/core/tools/computer_tools/shell.py @@ -14,7 +14,11 @@ from astrbot.core.utils.astrbot_path import get_astrbot_system_tmp_path from ..registry import builtin_tool -from .util import check_admin_permission, is_local_runtime, workspace_root +from .util import ( + check_admin_permission, + is_local_runtime, + workspace_root_for_context, +) _COMPUTER_RUNTIME_TOOL_CONFIG = { "provider_settings.computer_use_runtime": ("local", "sandbox"), @@ -99,9 +103,7 @@ async def call( try: cwd: str | None = None if is_local_runtime(context): - current_workspace_root = workspace_root( - context.context.event.unified_msg_origin - ) + current_workspace_root = await workspace_root_for_context(context) current_workspace_root.mkdir(parents=True, exist_ok=True) cwd = str(current_workspace_root) diff --git a/astrbot/core/tools/computer_tools/util.py b/astrbot/core/tools/computer_tools/util.py index a3930b4c6a..9bb71298bd 100644 --- a/astrbot/core/tools/computer_tools/util.py +++ b/astrbot/core/tools/computer_tools/util.py @@ -1,20 +1,48 @@ -import re from pathlib import Path from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core.db import BaseDatabase from astrbot.core.utils.astrbot_path import get_astrbot_workspaces_path +from astrbot.core.workspace import ( + normalize_umo_for_workspace, + resolve_workspace_root_for_umo, +) -def normalize_umo_for_workspace(umo: str) -> str: - normalized = re.sub(r"[^A-Za-z0-9._-]+", "_", umo.strip()) - return normalized or "unknown" +def workspace_root(umo: str) -> Path: + """Return the legacy workspace root for compatibility. + Args: + umo: Unified message origin. -def workspace_root(umo: str) -> Path: - """Root directory for relative paths in local runtime""" - normalized_umo = normalize_umo_for_workspace(umo) - return (Path(get_astrbot_workspaces_path()) / normalized_umo).resolve(strict=False) + Returns: + Legacy per-session workspace root. + """ + return ( + Path(get_astrbot_workspaces_path()) / normalize_umo_for_workspace(umo) + ).resolve(strict=False) + + +async def workspace_root_for_context( + context: ContextWrapper[AstrAgentContext], +) -> Path: + """Resolve the workspace root for a tool call context. + + Args: + context: Tool call context. + + Returns: + Workspace root used as cwd. + """ + umo = context.context.event.unified_msg_origin + db = getattr(context.context.context, "_db", None) + if not isinstance(db, BaseDatabase): + return workspace_root(umo) + try: + return await resolve_workspace_root_for_umo(umo, db) + except Exception: + return workspace_root(umo) def is_local_runtime(context: ContextWrapper[AstrAgentContext]) -> bool: diff --git a/astrbot/core/tools/message_tools.py b/astrbot/core/tools/message_tools.py index 0f5ac7e5d3..7bb7eda26e 100644 --- a/astrbot/core/tools/message_tools.py +++ b/astrbot/core/tools/message_tools.py @@ -20,6 +20,7 @@ check_admin_permission, is_local_runtime, workspace_root, + workspace_root_for_context, ) from astrbot.core.tools.registry import builtin_tool from astrbot.core.utils.astrbot_path import ( @@ -28,10 +29,13 @@ ) -def _file_send_allowed_roots(umo: str | None) -> tuple[Path, ...]: +def _file_send_allowed_roots( + umo: str | None, + current_workspace_root: Path | None = None, +) -> tuple[Path, ...]: roots = [] if umo: - roots.append(workspace_root(umo)) + roots.append(current_workspace_root or workspace_root(umo)) roots.extend( [ Path(get_astrbot_temp_path()).resolve(strict=False), @@ -59,9 +63,10 @@ def _is_restricted_local_env(context: ContextWrapper[AstrAgentContext]) -> bool: def _can_send_local_file( context: ContextWrapper[AstrAgentContext], local_path: Path, + current_workspace_root: Path | None = None, ) -> bool: umo = context.context.event.unified_msg_origin - allowed_roots = _file_send_allowed_roots(umo) + allowed_roots = _file_send_allowed_roots(umo, current_workspace_root) if _is_path_within(local_path, allowed_roots): return True return is_local_runtime(context) and not _is_restricted_local_env(context) @@ -137,12 +142,18 @@ async def _resolve_path_from_sandbox( if not path: raise FileNotFoundError(f"{component_type} path is empty") + current_workspace_root = ( + await workspace_root_for_context(context) + if is_local_runtime(context) + else None + ) + # Relative host paths are resolved only inside the user's workspace. if not os.path.isabs(path): unified_msg_origin = context.context.event.unified_msg_origin if unified_msg_origin: + ws_path = current_workspace_root or workspace_root(unified_msg_origin) try: - ws_path = workspace_root(unified_msg_origin) ws_candidate = (ws_path / path).resolve(strict=False) if ws_candidate.is_file() and ws_candidate.is_relative_to(ws_path): return str(ws_candidate), False @@ -151,13 +162,16 @@ async def _resolve_path_from_sandbox( else: local_candidate = Path(path).expanduser().resolve(strict=False) if local_candidate.is_file(): - if _can_send_local_file(context, local_candidate): + if _can_send_local_file( + context, local_candidate, current_workspace_root + ): return str(local_candidate), False if is_local_runtime(context): allowed = ", ".join( str(root) for root in _file_send_allowed_roots( - context.context.event.unified_msg_origin + context.context.event.unified_msg_origin, + current_workspace_root, ) ) raise PermissionError( diff --git a/astrbot/core/workspace.py b/astrbot/core/workspace.py new file mode 100644 index 0000000000..645a0e3d66 --- /dev/null +++ b/astrbot/core/workspace.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +import re +from pathlib import Path +from typing import Any + +from astrbot.core.db import BaseDatabase +from astrbot.core.platform.message_session import MessageSession +from astrbot.core.utils.astrbot_path import get_astrbot_workspaces_path + +WORKSPACE_TYPE_SESSION = "session" +WORKSPACE_TYPE_PROJECT = "project" +WORKSPACE_TYPE_CUSTOM = "custom" +WORKSPACE_TYPES = { + WORKSPACE_TYPE_SESSION, + WORKSPACE_TYPE_PROJECT, + WORKSPACE_TYPE_CUSTOM, +} + + +def normalize_umo_for_workspace(umo: str) -> str: + """Normalize a unified message origin into a filesystem-safe name. + + Args: + umo: Unified message origin. + + Returns: + Filesystem-safe workspace directory name. + """ + normalized = re.sub(r"[^A-Za-z0-9._-]+", "_", umo.strip()) + return normalized or "unknown" + + +def normalize_project_workspace_type(value: Any) -> str: + """Normalize stored or incoming project workspace type. + + Args: + value: Raw workspace type value. + + Returns: + A known workspace type. + """ + workspace_type = str(value or WORKSPACE_TYPE_SESSION).strip().lower() + return ( + workspace_type if workspace_type in WORKSPACE_TYPES else WORKSPACE_TYPE_SESSION + ) + + +def normalize_workspace_path(path: Any) -> str | None: + """Normalize a custom workspace path value for storage. + + Args: + path: Raw path value from API or database. + + Returns: + Normalized path string, or None when empty. + """ + if not isinstance(path, str): + return None + value = path.strip() + return value or None + + +def default_workspace_root(umo: str) -> Path: + """Return the legacy per-session workspace root. + + Args: + umo: Unified message origin. + + Returns: + The legacy workspace directory path. + """ + return ( + Path(get_astrbot_workspaces_path()) / normalize_umo_for_workspace(umo) + ).resolve(strict=False) + + +def project_workspace_root(project_id: str) -> Path: + """Return the default shared workspace root for a ChatUI project. + + Args: + project_id: ChatUI project ID. + + Returns: + The project workspace directory path. + """ + safe_project_id = re.sub(r"[^A-Za-z0-9._-]+", "_", project_id.strip()) + return (Path(get_astrbot_workspaces_path()) / f"project_{safe_project_id}").resolve( + strict=False + ) + + +def workspace_path_to_root(path: str) -> Path: + """Resolve a custom workspace path. + + Args: + path: Stored workspace path. Relative values are rooted under AstrBot + workspaces. Absolute values must also remain within AstrBot + workspaces. + + Returns: + Absolute resolved path. + + Raises: + ValueError: If the path escapes or targets the AstrBot workspaces root. + """ + workspaces_root = Path(get_astrbot_workspaces_path()).resolve(strict=False) + candidate = Path(path).expanduser() + if not candidate.is_absolute(): + candidate = workspaces_root / candidate + resolved = candidate.resolve(strict=False) + if resolved == workspaces_root or not resolved.is_relative_to(workspaces_root): + raise ValueError( + "Workspace path must stay within a subdirectory of AstrBot workspaces" + ) + return resolved + + +def resolve_project_workspace_root(project: Any, *, fallback_umo: str) -> Path: + """Resolve the workspace root from a project record. + + Args: + project: Project-like object with workspace fields. + fallback_umo: UMO used when the project keeps legacy session workspaces. + + Returns: + Workspace root used as cwd. + """ + fallback = default_workspace_root(fallback_umo) + workspace_type = normalize_project_workspace_type( + getattr(project, "workspace_type", WORKSPACE_TYPE_SESSION) + ) + if workspace_type == WORKSPACE_TYPE_SESSION: + return fallback + if workspace_type == WORKSPACE_TYPE_PROJECT: + return project_workspace_root(str(project.project_id)) + if workspace_type == WORKSPACE_TYPE_CUSTOM: + workspace_path = normalize_workspace_path( + getattr(project, "workspace_path", None) + ) + if workspace_path: + return workspace_path_to_root(workspace_path) + return fallback + + +def parse_webchat_umo(umo: str) -> tuple[str, str] | None: + """Extract creator and session ID from a webchat UMO. + + Args: + umo: Unified message origin. + + Returns: + Tuple of creator and ChatUI session ID, or None for non-webchat UMO. + """ + try: + message_session = MessageSession.from_str(umo) + except Exception: + return None + + if message_session.platform_name != "webchat": + return None + + parts = message_session.session_id.split("!", 2) + if len(parts) != 3 or parts[0] != "webchat": + return None + return parts[1], parts[2] + + +async def resolve_workspace_root_for_umo( + umo: str, + db: BaseDatabase | None = None, +) -> Path: + """Resolve the workspace root for a UMO. + + Args: + umo: Unified message origin. + db: Optional database instance. Falls back to the global database helper. + + Returns: + Workspace root used as cwd. + """ + parsed = parse_webchat_umo(umo) + if not parsed: + return default_workspace_root(umo) + + creator, session_id = parsed + if db is None: + from astrbot.core import db_helper + + db = db_helper + + project = await db.get_project_by_session(session_id=session_id, creator=creator) + if not project: + return default_workspace_root(umo) + return resolve_project_workspace_root(project, fallback_umo=umo) diff --git a/astrbot/dashboard/schemas.py b/astrbot/dashboard/schemas.py index f37449c532..bd72e1e4b3 100644 --- a/astrbot/dashboard/schemas.py +++ b/astrbot/dashboard/schemas.py @@ -95,6 +95,8 @@ class ChatProjectRequest(OpenModel): title: str | None = None emoji: str | None = None description: str | None = None + workspace_type: str | None = None + workspace_path: str | None = None class ChatProjectSessionRequest(OpenModel): diff --git a/astrbot/dashboard/services/chatui_project_service.py b/astrbot/dashboard/services/chatui_project_service.py index 4a928751ef..34e711d744 100644 --- a/astrbot/dashboard/services/chatui_project_service.py +++ b/astrbot/dashboard/services/chatui_project_service.py @@ -1,7 +1,17 @@ from __future__ import annotations +import os + from astrbot.core.db import BaseDatabase from astrbot.core.utils.datetime_utils import to_utc_isoformat +from astrbot.core.workspace import ( + WORKSPACE_TYPE_CUSTOM, + WORKSPACE_TYPE_SESSION, + normalize_project_workspace_type, + normalize_workspace_path, + resolve_project_workspace_root, + workspace_path_to_root, +) class ChatUIProjectServiceError(Exception): @@ -17,6 +27,7 @@ async def create_project(self, username: str, data: object) -> dict: title = payload.get("title") emoji = payload.get("emoji", "📁") description = payload.get("description") + workspace_type, workspace_path = self._normalize_workspace_config(payload) if not title: raise ChatUIProjectServiceError("Missing key: title") @@ -26,6 +37,8 @@ async def create_project(self, username: str, data: object) -> dict: title=title, emoji=emoji, description=description, + workspace_type=workspace_type, + workspace_path=workspace_path, ) return self._serialize_project(project) @@ -53,12 +66,22 @@ async def update_project(self, username: str, data: object) -> None: if not project_id: raise ChatUIProjectServiceError("Missing key: project_id") - await self._get_owned_project(username, project_id) + project = await self._get_owned_project(username, project_id) + workspace_type = None + workspace_path = None + if "workspace_type" in payload or "workspace_path" in payload: + workspace_type, workspace_path = self._normalize_workspace_config( + payload, + fallback_type=project.workspace_type, + fallback_path=project.workspace_path, + ) await self.db.update_chatui_project( project_id=project_id, title=payload.get("title"), emoji=payload.get("emoji"), description=payload.get("description"), + workspace_type=workspace_type, + workspace_path=workspace_path, ) async def delete_project(self, username: str, project_id: str | None) -> None: @@ -136,11 +159,32 @@ async def _get_owned_session(self, username: str, session_id: str): @staticmethod def _serialize_project(project) -> dict: + workspace_type = normalize_project_workspace_type( + getattr(project, "workspace_type", WORKSPACE_TYPE_SESSION) + ) + workspace_path = normalize_workspace_path( + getattr(project, "workspace_path", None) + ) + resolved_workspace_path = None + if workspace_type != WORKSPACE_TYPE_SESSION: + fallback_umo = f"webchat:FriendMessage:webchat!{project.creator}!default" + try: + resolved_workspace_path = str( + resolve_project_workspace_root( + project, + fallback_umo=fallback_umo, + ) + ) + except ValueError: + resolved_workspace_path = None return { "project_id": project.project_id, "title": project.title, "emoji": project.emoji, "description": project.description, + "workspace_type": workspace_type, + "workspace_path": workspace_path, + "resolved_workspace_path": resolved_workspace_path, "created_at": to_utc_isoformat(project.created_at), "updated_at": to_utc_isoformat(project.updated_at), } @@ -160,3 +204,49 @@ def _serialize_session(session) -> dict: @staticmethod def _as_payload(data: object) -> dict: return data if isinstance(data, dict) else {} + + @staticmethod + def _normalize_workspace_config( + payload: dict, + *, + fallback_type: str | None = None, + fallback_path: str | None = None, + ) -> tuple[str, str | None]: + """Normalize project workspace config from request payload. + + Args: + payload: Request payload. + fallback_type: Existing workspace type used when omitted. + fallback_path: Existing workspace path used when omitted. + + Returns: + Normalized workspace type and path. + + Raises: + ChatUIProjectServiceError: If a custom workspace has no usable path. + """ + workspace_type = normalize_project_workspace_type( + payload.get("workspace_type", fallback_type or WORKSPACE_TYPE_SESSION) + ) + raw_path = payload.get("workspace_path", fallback_path) + workspace_path = normalize_workspace_path(raw_path) + if workspace_type != WORKSPACE_TYPE_CUSTOM: + workspace_path = None + return workspace_type, workspace_path + + if not workspace_path: + raise ChatUIProjectServiceError("Custom workspace requires a path") + + try: + workspace_root = workspace_path_to_root(workspace_path) + except ValueError as exc: + raise ChatUIProjectServiceError(str(exc)) from exc + if not workspace_root.exists(): + raise ChatUIProjectServiceError("Custom workspace path does not exist") + if not workspace_root.is_dir(): + raise ChatUIProjectServiceError("Custom workspace path must be a directory") + if not os.access(workspace_root, os.R_OK | os.W_OK | os.X_OK): + raise ChatUIProjectServiceError( + "Custom workspace path requires read, write, and enter permissions" + ) + return workspace_type, workspace_path diff --git a/dashboard/src/api/generated/openapi-v1/types.gen.ts b/dashboard/src/api/generated/openapi-v1/types.gen.ts index 49e94b2bb1..4e3d94ffc8 100644 --- a/dashboard/src/api/generated/openapi-v1/types.gen.ts +++ b/dashboard/src/api/generated/openapi-v1/types.gen.ts @@ -83,8 +83,12 @@ export type ChatProjectRequest = { title?: string; emoji?: string; description?: string; + workspace_type?: 'session' | 'project' | 'custom'; + workspace_path?: string; }; +export type workspace_type = 'session' | 'project' | 'custom'; + export type ChatRequest = { /** * Caller-declared WebChat sender/session owner. This value is used as the message sender identity and may participate in sender-ID-based command permission checks. Treat chat-scoped API keys as trusted backend credentials and map or validate usernames before accepting end-user input. diff --git a/dashboard/src/assets/mdi-subset/materialdesignicons-subset.css b/dashboard/src/assets/mdi-subset/materialdesignicons-subset.css index 9ada8cf9c9..9f3c1e4c78 100644 --- a/dashboard/src/assets/mdi-subset/materialdesignicons-subset.css +++ b/dashboard/src/assets/mdi-subset/materialdesignicons-subset.css @@ -1,4 +1,4 @@ -/* Auto-generated MDI subset – 279 icons */ +/* Auto-generated MDI subset – 280 icons */ /* Do not edit manually. Run: pnpm run subset-icons */ @font-face { @@ -496,6 +496,10 @@ content: "\F024B"; } +.mdi-folder-cog-outline::before { + content: "\F1080"; +} + .mdi-folder-move::before { content: "\F0252"; } diff --git a/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff b/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff index 6401c409e3..7dc440f448 100644 Binary files a/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff and b/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff differ diff --git a/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff2 b/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff2 index d1a9fd21c8..2f42327fc6 100644 Binary files a/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff2 and b/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff2 differ diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue index d9a63e7f77..497bafd39e 100644 --- a/dashboard/src/components/chat/Chat.vue +++ b/dashboard/src/components/chat/Chat.vue @@ -10,25 +10,31 @@ :class="{ collapsed: isSidebarCollapsed }" :permanent="lgAndUp" :temporary="!lgAndUp" - :rail="lgAndUp && sidebarCollapsed" + :rail="lgAndUp && customizer.chatSidebarCollapsed" :width="280" - :rail-width="68" + :rail-width="56" location="left" floating >