From 7172ee6d70dc35823473c96e9fd7a256d9342832 Mon Sep 17 00:00:00 2001
From: sch_chun <2967725930@qq.com>
Date: Mon, 22 Jun 2026 20:27:56 +0800
Subject: [PATCH] feat(provider): add DashScope STT (Qwen3-ASR) adapter

---
 astrbot/core/config/default.py                |  27 +++
 astrbot/core/provider/manager.py              |   4 +
 .../sources/dashscope_stt_api_source.py       | 188 ++++++++++++++++
 .../en-US/features/config-metadata.json       |  10 +
 .../ru-RU/features/config-metadata.json       |  10 +
 .../zh-CN/features/config-metadata.json       |  10 +
 tests/test_dashscope_stt_api_source.py        | 213 ++++++++++++++++++
 7 files changed, 462 insertions(+)
 create mode 100644 astrbot/core/provider/sources/dashscope_stt_api_source.py
 create mode 100644 tests/test_dashscope_stt_api_source.py

diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py
index 7fb847dccd..44ef8e86fe 100644
--- a/astrbot/core/config/default.py
+++ b/astrbot/core/config/default.py
@@ -1742,6 +1742,20 @@
                         "dashscope_tts_voice": "loongstella",
                         "timeout": "20",
                     },
+                    "阿里云百炼 STT (Qwen3-ASR)": {
+                        "id": "dashscope_stt",
+                        "provider": "dashscope",
+                        "type": "dashscope_stt",
+                        "provider_type": "speech_to_text",
+                        "enable": False,
+                        "api_key": "",
+                        "api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1",
+                        "model": "qwen3-asr-flash",
+                        "language": "",
+                        "enable_itn": False,
+                        "timeout": "20",
+                        "proxy": "",
+                    },
                     "Azure TTS": {
                         "id": "azure_tts",
                         "type": "azure_tts",
@@ -2665,6 +2679,19 @@
                         "description": "启用",
                         "type": "bool",
                     },
+                    "api_key": {
+                        "description": "API key",
+                        "type": "string",
+                    },
+                    "language": {
+                        "description": "语种",
+                        "type": "string",
+                    },
+                    "enable_itn": {
+                        "description": "开启 ITN",
+                        "type": "bool",
+                        "hint": "将数字文本转换为阿拉伯数字。",
+                    },
                     "key": {
                         "description": "API Key",
                         "type": "list",
diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py
index ae4001fcd6..65a9de3143 100644
--- a/astrbot/core/provider/manager.py
+++ b/astrbot/core/provider/manager.py
@@ -415,6 +415,10 @@ def dynamic_import_provider(self, type: str) -> None:
                 from .sources.whisper_selfhosted_source import (
                     ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost,
                 )
+            case "dashscope_stt":
+                from .sources.dashscope_stt_api_source import (
+                    ProviderDashScopeSTT as ProviderDashScopeSTT,
+                )
             case "xinference_stt":
                 from .sources.xinference_stt_provider import (
                     ProviderXinferenceSTT as ProviderXinferenceSTT,
diff --git a/astrbot/core/provider/sources/dashscope_stt_api_source.py b/astrbot/core/provider/sources/dashscope_stt_api_source.py
new file mode 100644
index 0000000000..a9df64b4b4
--- /dev/null
+++ b/astrbot/core/provider/sources/dashscope_stt_api_source.py
@@ -0,0 +1,188 @@
+"""DashScope (Alibaba Cloud Bailian) speech-to-text adapter for Qwen3-ASR."""
+
+import httpx
+
+from astrbot.core.utils.media_utils import MediaResolver, describe_media_ref
+
+from ..entities import ProviderType
+from ..provider import STTProvider
+from ..register import register_provider_adapter
+
+DEFAULT_DASHSCOPE_STT_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1"
+DEFAULT_DASHSCOPE_STT_MODEL = "qwen3-asr-flash"
+
+
+class DashScopeAPIError(Exception):
+    """Raised when DashScope rejects a request or returns an unusable response."""
+
+
+def normalize_timeout(timeout: int | str | None) -> int | None:
+    """Return the timeout as an int, or None when it is unset or blank."""
+    if timeout is None:
+        return None
+    if isinstance(timeout, str):
+        timeout = timeout.strip()
+        return int(timeout) if timeout else None
+    return timeout
+
+
+def build_headers(api_key: str) -> dict:
+    """Return the JSON request headers carrying the bearer API key."""
+    return {
+        "Content-Type": "application/json",
+        "Authorization": f"Bearer {api_key}",
+    }
+
+
+def build_api_url(base_url: str) -> str:
+    """Return the chat-completions endpoint for ``base_url``.
+
+    A base URL that already ends in ``/chat/completions`` is returned as-is
+    (minus trailing slashes) so a fully qualified endpoint can be configured.
+    """
+    normalized = base_url.rstrip("/")
+    if normalized.endswith("/chat/completions"):
+        return normalized
+    return normalized + "/chat/completions"
+
+
+@register_provider_adapter(
+    "dashscope_stt",
+    "阿里云百炼 STT (Qwen3-ASR)",
+    provider_type=ProviderType.SPEECH_TO_TEXT,
+)
+class ProviderDashScopeSTT(STTProvider):
+    """Speech-to-text provider using DashScope's OpenAI-compatible endpoint."""
+
+    def __init__(self, provider_config: dict, provider_settings: dict):
+        super().__init__(provider_config, provider_settings)
+        self.api_key = provider_config.get("api_key", "")
+        if not self.api_key:
+            raise ValueError("DashScope STT requires an API key.")
+        self.base_url = (
+            provider_config.get("api_base")
+            or provider_config.get("base_url")
+            or DEFAULT_DASHSCOPE_STT_BASE
+        )
+        # Use ``or`` so an empty "model" value also falls back to the default.
+        self.model = provider_config.get("model") or DEFAULT_DASHSCOPE_STT_MODEL
+        self.language = provider_config.get("language", "") or None
+        self.enable_itn = provider_config.get("enable_itn", False)
+        self.timeout = normalize_timeout(provider_config.get("timeout", 20))
+        self.proxy = provider_config.get("proxy", "")
+        self._client: httpx.AsyncClient | None = None
+
+    def _get_client(self) -> httpx.AsyncClient:
+        """Return the shared HTTP client, creating it on first use."""
+        if self._client is None:
+            kwargs = {"timeout": self.timeout, "follow_redirects": True}
+            if self.proxy:
+                kwargs["proxy"] = self.proxy
+            self._client = httpx.AsyncClient(**kwargs)
+        return self._client
+
+    @staticmethod
+    def _parse_transcript(resp: httpx.Response) -> str:
+        """Extract the transcript from a chat-completions response.
+
+        Raises:
+            DashScopeAPIError: If the body is not JSON, has no choices, or
+                carries an empty transcript.
+        """
+        try:
+            data = resp.json()
+        except ValueError as e:  # json.JSONDecodeError is a ValueError
+            raise DashScopeAPIError(
+                f"DashScope STT returned a non-JSON response: {resp.text[:1024]}"
+            ) from e
+
+        choices = data.get("choices") if isinstance(data, dict) else None
+        if not isinstance(choices, list) or not choices:
+            raise DashScopeAPIError(f"No choices in response: {data}")
+
+        first = choices[0]
+        message = first.get("message") if isinstance(first, dict) else None
+        content_text = message.get("content", "") if isinstance(message, dict) else ""
+        if not isinstance(content_text, str) or not content_text.strip():
+            raise DashScopeAPIError(f"The recognition result is empty: {data}")
+        return content_text.strip()
+
+    async def get_text(self, audio_url: str) -> str:
+        """Transcribe audio with DashScope Qwen3-ASR.
+
+        The audio reference is resolved by ``MediaResolver`` into a base64 WAV
+        data URL and posted, in non-streaming mode, as an ``input_audio``
+        message to the OpenAI-compatible ``/chat/completions`` endpoint.
+        ``asr_options`` carries the optional language and the ITN (Inverse
+        Text Normalization) switch.
+
+        Args:
+            audio_url: A media reference accepted by ``MediaResolver``, such
+                as an HTTP(S) URL, a local file path or a base64 data URI.
+
+        Returns:
+            The transcribed text with surrounding whitespace removed.
+
+        Raises:
+            ValueError: If the audio source cannot be resolved or converted.
+            DashScopeAPIError: If the API answers with an HTTP error status,
+                a malformed body, or an empty transcription.
+            httpx.RequestError: If the request fails at the transport level
+                (connection error, timeout); these are not wrapped.
+
+        Note:
+            The audio travels inline as base64, which is about 33% larger
+            than the source file, so the source must stay well below the
+            DashScope request size limit. For long recordings consider the
+            asynchronous file-transcription model
+            ``qwen3-asr-flash-filetrans``.
+        """
+        audio_data = await MediaResolver(
+            audio_url,
+            media_type="audio",
+            default_suffix=".wav",
+        ).to_base64_data(strict=True, target_format="wav")
+        if audio_data is None:
+            raise ValueError(
+                f"Failed to parse audio source: {describe_media_ref(audio_url)}"
+            )
+
+        asr_options = {}
+        if self.language:
+            asr_options["language"] = self.language
+        # enable_itn is only omitted when the config sets it to null explicitly.
+        if self.enable_itn is not None:
+            asr_options["enable_itn"] = self.enable_itn
+
+        audio_part = {
+            "type": "input_audio",
+            "input_audio": {"data": audio_data.to_data_url()},
+        }
+        payload = {
+            "model": self.model,
+            "messages": [{"role": "user", "content": [audio_part]}],
+            "stream": False,
+        }
+        if asr_options:
+            payload["asr_options"] = asr_options
+
+        resp = await self._get_client().post(
+            build_api_url(self.base_url),
+            headers=build_headers(self.api_key),
+            json=payload,
+        )
+        try:
+            resp.raise_for_status()
+        except httpx.HTTPStatusError as e:
+            raise DashScopeAPIError(
+                f"DashScope STT request failed (HTTP {resp.status_code}): "
+                f"{resp.text[:1024]}"
+            ) from e
+
+        return self._parse_transcript(resp)
+
+    async def terminate(self):
+        """Close the shared HTTP client if one was created."""
+        if self._client is not None:
+            await self._client.aclose()
+            self._client = None
diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json
index 31d1362b26..35da7ea48d 100644
--- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json
+++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json
@@ -1662,6 +1662,16 @@
       "enable": {
         "description": "Enable"
       },
+      "api_key": {
+        "description": "API Key"
+      },
+      "language": {
+        "description": "Language"
+      },
+      "enable_itn": {
+        "description": "Enable ITN",
+        "hint": "Convert digit words to Arabic numerals"
+      },
       "key": {
         "description": "API Key"
       },
diff --git a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json
index f0a1294d7c..99a6a329e5 100644
--- a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json
+++ b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json
@@ -1659,6 +1659,16 @@
       "enable": {
         "description": "Включить"
       },
+      "api_key": {
+        "description": "API ключ"
+      },
+      "language": {
+        "description": "Язык"
+      },
+      "enable_itn": {
+        "description": "Включить ITN",
+        "hint": "Преобразовывать числа, записанные словами, в арабские цифры"
+      },
       "key": {
         "description": "Ключ Coze API для доступа к сервисам Coze."
       },
diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json
index 6ae988383f..66de9f10e9 100644
--- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json
+++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json
@@ -1664,6 +1664,16 @@
       "enable": {
         "description": "启用"
       },
+      "api_key": {
+        "description": "API Key"
+      },
+      "language": {
+        "description": "语种"
+      },
+      "enable_itn": {
+        "description": "开启 ITN",
+        "hint": "将数字文本转换为阿拉伯数字"
+      },
       "key": {
         "description": "API Key"
       },
diff --git a/tests/test_dashscope_stt_api_source.py b/tests/test_dashscope_stt_api_source.py
new file mode 100644
index 0000000000..a236a6a342
--- /dev/null
+++ b/tests/test_dashscope_stt_api_source.py
@@ -0,0 +1,213 @@
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import httpx
+import pytest
+
+from astrbot.core.provider.sources.dashscope_stt_api_source import (
+    DashScopeAPIError,
+    ProviderDashScopeSTT,
+    build_api_url,
+    build_headers,
+    normalize_timeout,
+)
+
+
+@pytest.fixture
+def mock_media_resolver():
+    """Mock MediaResolver to return a fake base64 data URI."""
+    with patch("astrbot.core.provider.sources.dashscope_stt_api_source.MediaResolver") as mock:
+        instance = mock.return_value
+        resolved_mock = MagicMock()
+        resolved_mock.base64_data = "SUQzBAAAAA..."
+        resolved_mock.mime_type = "audio/wav"
+        resolved_mock.to_data_url.return_value = "data:audio/wav;base64,SUQzBAAAAA..."
+        instance.to_base64_data = AsyncMock(return_value=resolved_mock)
+        yield mock
+
+
+@pytest.fixture
+def provider_config():
+    return {
+        "id": "dashscope_stt_test",
+        "api_key": "sk-test123",
+        "api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1",
+        "model": "qwen3-asr-flash",
+        "language": "zh",
+        "enable_itn": True,
+        "timeout": 30,
+        "proxy": "",
+    }
+
+
+@pytest.fixture
+def provider(provider_config):
+    return ProviderDashScopeSTT(provider_config, {})
+
+
+def test_build_headers():
+    headers = build_headers("sk-abc")
+    assert headers["Authorization"] == "Bearer sk-abc"
+    assert headers["Content-Type"] == "application/json"
+
+
+def test_build_api_url():
+    # With trailing slash
+    assert build_api_url("https://api.example.com/v1/") == "https://api.example.com/v1/chat/completions"
+    # Without trailing slash
+    assert build_api_url("https://api.example.com/v1") == "https://api.example.com/v1/chat/completions"
+    # Already ends with /chat/completions
+    assert build_api_url("https://api.example.com/v1/chat/completions") == "https://api.example.com/v1/chat/completions"
+
+
+def test_normalize_timeout():
+    assert normalize_timeout(None) is None
+    assert normalize_timeout("") is None
+    assert normalize_timeout("  ") is None
+    assert normalize_timeout(30) == 30
+    assert normalize_timeout("30") == 30
+
+
+@pytest.mark.asyncio
+async def test_get_text_success(provider, mock_media_resolver):
+    """Test successful transcription."""
+    # Mock HTTP response
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.json.return_value = {
+        "choices": [
+            {
+                "message": {
+                    "content": "欢迎使用阿里云。"
+                }
+            }
+        ]
+    }
+    mock_response.raise_for_status = MagicMock()
+
+    with patch.object(provider, "_get_client") as mock_client:
+        client = AsyncMock()
+        client.post = AsyncMock(return_value=mock_response)
+        mock_client.return_value = client
+
+        result = await provider.get_text("https://example.com/audio.mp3")
+
+        assert result == "欢迎使用阿里云。"
+        # Verify correct payload was sent
+        call_args = client.post.call_args
+        assert call_args[0][0] == "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions"
+        payload = call_args[1]["json"]
+        assert payload["model"] == "qwen3-asr-flash"
+        assert payload["messages"][0]["content"][0]["type"] == "input_audio"
+        assert payload["messages"][0]["content"][0]["input_audio"]["data"].startswith("data:audio/wav;base64,")
+        assert payload["asr_options"]["language"] == "zh"
+        assert payload["asr_options"]["enable_itn"] is True
+        assert payload["stream"] is False
+
+
+@pytest.mark.asyncio
+async def test_get_text_with_empty_language(provider_config, mock_media_resolver):
+    """Test that language is omitted when not provided."""
+    provider_config["language"] = ""
+    provider = ProviderDashScopeSTT(provider_config, {})
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.json.return_value = {
+        "choices": [{"message": {"content": "Hello world."}}]
+    }
+    mock_response.raise_for_status = MagicMock()
+
+    with patch.object(provider, "_get_client") as mock_client:
+        client = AsyncMock()
+        client.post = AsyncMock(return_value=mock_response)
+        mock_client.return_value = client
+
+        await provider.get_text("https://example.com/audio.mp3")
+        payload = client.post.call_args[1]["json"]
+        assert "language" not in payload["asr_options"]
+
+
+@pytest.mark.asyncio
+async def test_get_text_http_error(provider, mock_media_resolver):
+    """Test HTTP error handling (e.g., 400 Bad Request)."""
+    mock_response = MagicMock()
+    mock_response.status_code = 400
+    mock_response.text = '{"error":{"message":"Invalid parameter"}}'
+    mock_response.raise_for_status = MagicMock(side_effect=httpx.HTTPStatusError(
+        "Bad Request", request=MagicMock(), response=mock_response
+    ))
+
+    with patch.object(provider, "_get_client") as mock_client:
+        client = AsyncMock()
+        client.post = AsyncMock(return_value=mock_response)
+        mock_client.return_value = client
+
+        with pytest.raises(DashScopeAPIError) as excinfo:
+            await provider.get_text("https://example.com/audio.mp3")
+        assert "DashScope STT request failed" in str(excinfo.value)
+        assert "400" in str(excinfo.value)
+
+
+@pytest.mark.asyncio
+async def test_get_text_empty_response(provider, mock_media_resolver):
+    """Test when API returns empty choices or content."""
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.json.return_value = {"choices": []}
+    mock_response.raise_for_status = MagicMock()
+
+    with patch.object(provider, "_get_client") as mock_client:
+        client = AsyncMock()
+        client.post = AsyncMock(return_value=mock_response)
+        mock_client.return_value = client
+
+        with pytest.raises(DashScopeAPIError) as excinfo:
+            await provider.get_text("https://example.com/audio.mp3")
+        assert "No choices" in str(excinfo.value)
+
+
+@pytest.mark.asyncio
+async def test_get_text_non_json_response(provider, mock_media_resolver):
+    """Test that a non-JSON body is reported as DashScopeAPIError."""
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.text = "<html>Bad Gateway</html>"
+    mock_response.json.side_effect = ValueError("Expecting value")
+    mock_response.raise_for_status = MagicMock()
+
+    with patch.object(provider, "_get_client") as mock_client:
+        client = AsyncMock()
+        client.post = AsyncMock(return_value=mock_response)
+        mock_client.return_value = client
+
+        with pytest.raises(DashScopeAPIError) as excinfo:
+            await provider.get_text("https://example.com/audio.mp3")
+        assert "non-JSON" in str(excinfo.value)
+
+
+@pytest.mark.asyncio
+async def test_get_text_media_resolver_failure(provider):
+    """Test when MediaResolver fails to convert audio."""
+    with patch("astrbot.core.provider.sources.dashscope_stt_api_source.MediaResolver") as mock:
+        instance = mock.return_value
+        instance.to_base64_data = AsyncMock(return_value=None)
+
+        with pytest.raises(ValueError) as excinfo:
+            await provider.get_text("invalid_audio")
+        assert "Failed to parse audio source" in str(excinfo.value)
+
+
+@pytest.mark.asyncio
+async def test_terminate(provider):
+    """Test terminate closes the client."""
+    client = AsyncMock()
+    provider._client = client
+    await provider.terminate()
+    client.aclose.assert_called_once()
+    assert provider._client is None
+
+
+@pytest.mark.asyncio
+async def test_terminate_no_client(provider):
+    """Test terminate when client is None."""
+    provider._client = None
+    await provider.terminate()  # Should not raise