diff --git a/src/google/adk/artifacts/file_artifact_service.py b/src/google/adk/artifacts/file_artifact_service.py index b0078e27ce..faf253c97e 100644 --- a/src/google/adk/artifacts/file_artifact_service.py +++ b/src/google/adk/artifacts/file_artifact_service.py @@ -62,6 +62,23 @@ def _file_uri_to_path(uri: str) -> Optional[Path]: return Path(unquote(parsed.path)) +def _validate_path_segment(value: str, name: str) -> None: + """Validates that a value is safe for use as a single path segment. + + Args: + value: The string to validate (e.g. a user_id or session_id). + name: A human-readable label used in error messages. + + Raises: + InputValidationError: If the value contains path separators or traversal + sequences. + """ + if not value: + raise InputValidationError(f"{name} must not be empty.") + if any(sep in value for sep in ("/", "\\", "..")): + raise InputValidationError(f"{name} contains invalid characters: {value!r}") + + _USER_NAMESPACE_PREFIX = "user:" @@ -145,6 +162,7 @@ def _user_artifacts_dir(base_root: Path) -> Path: def _session_artifacts_dir(base_root: Path, session_id: str) -> Path: """Returns the path that stores session-scoped artifacts.""" + _validate_path_segment(session_id, "session_id") return base_root / "sessions" / session_id / "artifacts" @@ -220,6 +238,7 @@ def __init__(self, root_dir: Path | str): def _base_root(self, user_id: str, /) -> Path: """Returns the artifacts root directory for a user.""" + _validate_path_segment(user_id, "user_id") return self.root_dir / "users" / user_id def _scope_root( diff --git a/tests/unittests/artifacts/test_artifact_service.py b/tests/unittests/artifacts/test_artifact_service.py index 25294d4909..8c288d58b4 100644 --- a/tests/unittests/artifacts/test_artifact_service.py +++ b/tests/unittests/artifacts/test_artifact_service.py @@ -769,6 +769,34 @@ async def test_file_save_artifact_rejects_absolute_path_within_scope(tmp_path): ) +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("user_id", "session_id"), + [ + ("../escape", "sess123"), + ("user/../../etc", "sess123"), + ("user\\\\..\\\\secret", "sess123"), + ("valid_user", "../escape"), + ("valid_user", "sess/../../etc"), + ("valid_user", "sess\\\\..\\\\secret"), + ], +) +async def test_file_save_artifact_rejects_path_traversal_in_ids( + tmp_path, user_id, session_id +) -> None: + """FileArtifactService rejects user_id/session_id with path traversal.""" + artifact_service = FileArtifactService(root_dir=tmp_path / "artifacts") + part = types.Part(text="content") + with pytest.raises(InputValidationError): + await artifact_service.save_artifact( + app_name="myapp", + user_id=user_id, + session_id=session_id, + filename="safe.txt", + artifact=part, + ) + + class TestEnsurePart: """Tests for the ensure_part normalization helper."""