diff --git a/astrbot/core/knowledge_base/kb_helper.py b/astrbot/core/knowledge_base/kb_helper.py
index c29e45876d..f8cbe740b4 100644
--- a/astrbot/core/knowledge_base/kb_helper.py
+++ b/astrbot/core/knowledge_base/kb_helper.py
@@ -25,7 +25,10 @@
 from .chunking.recursive import RecursiveCharacterChunker
 from .kb_db_sqlite import KBSQLiteDatabase
 from .models import KBDocument, KBMedia, KnowledgeBase
-from .parsers.url_parser import extract_text_from_url
+from .parsers.url_parser import (
+    SUPPORTED_URL_EXTRACT_PROVIDERS,
+    extract_text_from_url,
+)
 from .parsers.util import select_parser
 from .prompts import TEXT_REPAIR_SYSTEM_PROMPT
 
@@ -616,14 +619,27 @@ async def upload_from_url(
             ValueError: 如果 URL 为空或无法提取内容
             IOError: 如果网络请求失败
         """
-        # 获取 Tavily API 密钥
-        config = self.prov_mgr.acm.default_conf
-        tavily_keys = config.get("provider_settings", {}).get(
-            "websearch_tavily_key", []
+        # 根据配置的网页搜索提供商选择 URL 内容提取后端。
+        # 仅 Tavily 与 Firecrawl 支持单页内容提取;其余提供商回退到 Tavily。
+        provider_settings = (
+            self.prov_mgr.acm.default_conf.get("provider_settings") or {}
+        )
+        websearch_provider = (provider_settings.get("websearch_provider") or "tavily").lower()
+        url_extract_provider = (
+            websearch_provider
+            if websearch_provider in SUPPORTED_URL_EXTRACT_PROVIDERS
+            else "tavily"
         )
-        if not tavily_keys:
+        tavily_keys = provider_settings.get("websearch_tavily_key", [])
+        firecrawl_keys = provider_settings.get("websearch_firecrawl_key", [])
+
+        provider_keys = (
+            firecrawl_keys if url_extract_provider == "firecrawl" else tavily_keys
+        )
+        if not provider_keys:
             raise ValueError(
-                "Error: Tavily API key is not configured in provider_settings."
+                f"Error: {url_extract_provider.capitalize()} API key is not "
+                "configured in provider_settings."
             )
 
         # 阶段1: 从 URL 提取内容
@@ -631,7 +647,12 @@ async def upload_from_url(
             await progress_callback("extracting", 0, 100)
 
         try:
-            text_content = await extract_text_from_url(url, tavily_keys)
+            text_content = await extract_text_from_url(
+                url,
+                tavily_keys,
+                provider=url_extract_provider,
+                firecrawl_keys=firecrawl_keys,
+            )
         except Exception as e:
             logger.error(f"Failed to extract content from URL {url}: {e}")
             raise OSError(f"Failed to extract content from URL {url}: {e}") from e
diff --git a/astrbot/core/knowledge_base/parsers/url_parser.py b/astrbot/core/knowledge_base/parsers/url_parser.py
index 2867164a96..eed945f279 100644
--- a/astrbot/core/knowledge_base/parsers/url_parser.py
+++ b/astrbot/core/knowledge_base/parsers/url_parser.py
@@ -2,36 +2,74 @@
 
 import aiohttp
 
+# 支持从 URL 提取正文内容的网页搜索提供商。
+# 其余提供商(bocha、brave、baidu_ai_search 等)暂不支持单页内容提取。
+SUPPORTED_URL_EXTRACT_PROVIDERS = ("tavily", "firecrawl")
+
+
+def _normalize_keys(keys: str | list[str] | None) -> list[str]:
+    """将密钥配置规范化为列表。
+
+    兼容历史配置中将单个密钥存为字符串的情况,避免 list("key") 把字符串
+    拆成单个字符。
+    """
+    if isinstance(keys, str):
+        return [keys] if keys else []
+    return list(keys or [])
+
 
 class URLExtractor:
-    """URL 内容提取器,封装了 Tavily API 调用和密钥管理"""
+    """URL 内容提取器,封装 Tavily / Firecrawl API 调用和密钥轮换。
 
-    def __init__(self, tavily_keys: list[str]) -> None:
+    与 web_searcher 内置工具保持一致地支持多个网页搜索提供商,但这里是
+    专门为知识库模块设计的简化版本,不依赖 AstrMessageEvent。
+    """
+
+    def __init__(
+        self,
+        tavily_keys: str | list[str] | None = None,
+        *,
+        provider: str = "tavily",
+        firecrawl_keys: str | list[str] | None = None,
+    ) -> None:
         """
         初始化 URL 提取器
 
         Args:
             tavily_keys: Tavily API 密钥列表
+            provider: URL 内容提取所用的提供商("tavily" 或 "firecrawl")
+            firecrawl_keys: Firecrawl API 密钥列表
         """
-        if not tavily_keys:
-            raise ValueError("Error: Tavily API keys are not configured.")
-
-        self.tavily_keys = tavily_keys
-        self.tavily_key_index = 0
-        self.tavily_key_lock = asyncio.Lock()
-
-    async def _get_tavily_key(self) -> str:
-        """并发安全的从列表中获取并轮换Tavily API密钥。"""
-        async with self.tavily_key_lock:
-            key = self.tavily_keys[self.tavily_key_index]
-            self.tavily_key_index = (self.tavily_key_index + 1) % len(self.tavily_keys)
+        self.provider = (provider or "tavily").lower()
+        if self.provider not in SUPPORTED_URL_EXTRACT_PROVIDERS:
+            raise ValueError(
+                f"Error: Unsupported URL extraction provider '{self.provider}'. "
+                f"Supported providers: {', '.join(SUPPORTED_URL_EXTRACT_PROVIDERS)}."
+            )
+
+        self._keys: dict[str, list[str]] = {
+            "tavily": _normalize_keys(tavily_keys),
+            "firecrawl": _normalize_keys(firecrawl_keys),
+        }
+        if not self._keys[self.provider]:
+            raise ValueError(
+                f"Error: {self.provider.capitalize()} API keys are not configured."
+            )
+
+        self._key_index = 0
+        self._key_lock = asyncio.Lock()
+
+    async def _get_key(self) -> str:
+        """并发安全地从当前提供商的密钥列表中获取并轮换 API 密钥。"""
+        keys = self._keys[self.provider]
+        async with self._key_lock:
+            key = keys[self._key_index]
+            self._key_index = (self._key_index + 1) % len(keys)
             return key
 
     async def extract_text_from_url(self, url: str) -> str:
         """
-        使用 Tavily API 从 URL 提取主要文本内容。
-        这是 web_searcher 插件中 tavily_extract_web_page 方法的简化版本,
-        专门为知识库模块设计,不依赖 AstrMessageEvent。
+        使用已配置的提供商从 URL 提取主要文本内容。
 
         Args:
             url: 要提取内容的网页 URL
@@ -40,13 +78,19 @@ async def extract_text_from_url(self, url: str) -> str:
             提取的文本内容
 
         Raises:
-            ValueError: 如果 URL 为空或 API 密钥未配置
+            ValueError: 如果 URL 为空或未提取到内容
             IOError: 如果请求失败或返回错误
         """
         if not url:
             raise ValueError("Error: url must be a non-empty string.")
 
-        tavily_key = await self._get_tavily_key()
+        if self.provider == "firecrawl":
+            return await self._extract_with_firecrawl(url)
+        return await self._extract_with_tavily(url)
+
+    async def _extract_with_tavily(self, url: str) -> str:
+        """使用 Tavily API 从 URL 提取主要文本内容。"""
+        tavily_key = await self._get_key()
         api_url = "https://api.tavily.com/extract"
         headers = {
             "Authorization": f"Bearer {tavily_key}",
@@ -83,21 +127,80 @@ async def extract_text_from_url(self, url: str) -> str:
 
         except aiohttp.ClientError as e:
             raise OSError(f"Failed to fetch URL {url}: {e}") from e
+        except (ValueError, OSError):
+            raise
+        except Exception as e:
+            raise OSError(f"Failed to extract content from URL {url}: {e}") from e
+
+    async def _extract_with_firecrawl(self, url: str) -> str:
+        """使用 Firecrawl Scrape API 从 URL 提取主要文本内容(Markdown)。"""
+        firecrawl_key = await self._get_key()
+        api_url = "https://api.firecrawl.dev/v2/scrape"
+        headers = {
+            "Authorization": f"Bearer {firecrawl_key}",
+            "Content-Type": "application/json",
+        }
+
+        payload = {
+            "url": url,
+            "formats": ["markdown"],
+            "onlyMainContent": True,
+        }
+
+        try:
+            async with aiohttp.ClientSession(trust_env=True) as session:
+                async with session.post(
+                    api_url,
+                    json=payload,
+                    headers=headers,
+                    timeout=30.0,
+                ) as response:
+                    if response.status != 200:
+                        reason = await response.text()
+                        raise OSError(
+                            f"Firecrawl web extraction failed: {reason}, status: {response.status}"
+                        )
+
+                    data = await response.json()
+                    result = data.get("data", {})
+                    content = result.get("markdown", "") if result else ""
+
+                    if not content:
+                        raise ValueError(f"No content extracted from URL: {url}")
+
+                    return content
+
+        except aiohttp.ClientError as e:
+            raise OSError(f"Failed to fetch URL {url}: {e}") from e
+        except (ValueError, OSError):
+            raise
         except Exception as e:
             raise OSError(f"Failed to extract content from URL {url}: {e}") from e
 
 
 # 为了向后兼容,提供一个简单的函数接口
-async def extract_text_from_url(url: str, tavily_keys: list[str]) -> str:
+async def extract_text_from_url(
+    url: str,
+    tavily_keys: str | list[str] | None = None,
+    *,
+    provider: str = "tavily",
+    firecrawl_keys: str | list[str] | None = None,
+) -> str:
     """
     简单的函数接口,用于从 URL 提取文本内容
 
     Args:
         url: 要提取内容的网页 URL
         tavily_keys: Tavily API 密钥列表
+        provider: URL 内容提取所用的提供商("tavily" 或 "firecrawl")
+        firecrawl_keys: Firecrawl API 密钥列表
 
     Returns:
         提取的文本内容
     """
-    extractor = URLExtractor(tavily_keys)
+    extractor = URLExtractor(
+        tavily_keys,
+        provider=provider,
+        firecrawl_keys=firecrawl_keys,
+    )
     return await extractor.extract_text_from_url(url)
diff --git a/tests/unit/test_url_parser.py b/tests/unit/test_url_parser.py
new file mode 100644
index 0000000000..08b4c48ef8
--- /dev/null
+++ b/tests/unit/test_url_parser.py
@@ -0,0 +1,142 @@
+import pytest
+
+from astrbot.core.knowledge_base.parsers import url_parser
+
+
+class _FakeResponse:
+    def __init__(self, status: int, json_data: dict):
+        self.status = status
+        self._json_data = json_data
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, *args):
+        return False
+
+    async def json(self):
+        return self._json_data
+
+    async def text(self):
+        return "error body"
+
+
+class _FakeSession:
+    """Captures the request and returns a canned response."""
+
+    def __init__(self, response: _FakeResponse, recorder: dict):
+        self._response = response
+        self._recorder = recorder
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, *args):
+        return False
+
+    def post(self, url, json=None, headers=None, timeout=None):
+        self._recorder["url"] = url
+        self._recorder["json"] = json
+        self._recorder["headers"] = headers
+        return self._response
+
+
+def _patch_session(monkeypatch, response: _FakeResponse) -> dict:
+    recorder: dict = {}
+
+    def fake_client_session(*args, **kwargs):
+        return _FakeSession(response, recorder)
+
+    monkeypatch.setattr(url_parser.aiohttp, "ClientSession", fake_client_session)
+    return recorder
+
+
+def test_unsupported_provider_raises():
+    with pytest.raises(ValueError):
+        url_parser.URLExtractor(["k"], provider="bocha")
+
+
+def test_missing_keys_for_selected_provider_raises():
+    # Firecrawl selected but only Tavily keys supplied.
+    with pytest.raises(ValueError):
+        url_parser.URLExtractor(["tavily-key"], provider="firecrawl")
+
+
+@pytest.mark.asyncio
+async def test_tavily_extraction_hits_tavily(monkeypatch):
+    response = _FakeResponse(200, {"results": [{"raw_content": "tavily body"}]})
+    recorder = _patch_session(monkeypatch, response)
+
+    content = await url_parser.extract_text_from_url(
+        "https://example.com", ["tavily-key"], provider="tavily"
+    )
+
+    assert content == "tavily body"
+    assert recorder["url"] == "https://api.tavily.com/extract"
+    assert recorder["headers"]["Authorization"] == "Bearer tavily-key"
+    assert recorder["json"]["urls"] == ["https://example.com"]
+
+
+@pytest.mark.asyncio
+async def test_firecrawl_extraction_hits_firecrawl(monkeypatch):
+    response = _FakeResponse(200, {"data": {"markdown": "# firecrawl body"}})
+    recorder = _patch_session(monkeypatch, response)
+
+    content = await url_parser.extract_text_from_url(
+        "https://example.com",
+        tavily_keys=[],
+        provider="firecrawl",
+        firecrawl_keys=["firecrawl-key"],
+    )
+
+    assert content == "# firecrawl body"
+    assert recorder["url"] == "https://api.firecrawl.dev/v2/scrape"
+    assert recorder["headers"]["Authorization"] == "Bearer firecrawl-key"
+    assert recorder["json"] == {
+        "url": "https://example.com",
+        "formats": ["markdown"],
+        "onlyMainContent": True,
+    }
+
+
+@pytest.mark.asyncio
+async def test_firecrawl_empty_content_raises(monkeypatch):
+    response = _FakeResponse(200, {"data": {"markdown": ""}})
+    _patch_session(monkeypatch, response)
+
+    with pytest.raises(ValueError):
+        await url_parser.extract_text_from_url(
+            "https://example.com",
+            provider="firecrawl",
+            firecrawl_keys=["firecrawl-key"],
+        )
+
+
+@pytest.mark.asyncio
+async def test_single_string_key_is_not_split_into_chars(monkeypatch):
+    # Legacy configs may store a single key as a bare string; it must be
+    # treated as one key, not split into individual characters.
+    response = _FakeResponse(200, {"data": {"markdown": "body"}})
+    recorder = _patch_session(monkeypatch, response)
+
+    await url_parser.extract_text_from_url(
+        "https://example.com",
+        provider="firecrawl",
+        firecrawl_keys="firecrawl-key",
+    )
+
+    assert recorder["headers"]["Authorization"] == "Bearer firecrawl-key"
+
+
+@pytest.mark.asyncio
+async def test_default_provider_is_tavily_backward_compatible(monkeypatch):
+    response = _FakeResponse(200, {"results": [{"raw_content": "legacy body"}]})
+    recorder = _patch_session(monkeypatch, response)
+
+    # Legacy positional call signature must keep working.
+    content = await url_parser.extract_text_from_url(
+        "https://example.com", ["tavily-key"]
+    )
+
+    assert content == "legacy body"
+    assert recorder["url"] == "https://api.tavily.com/extract"