Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion astrbot/core/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -1294,7 +1294,7 @@
"DeepSeek": {
"id": "deepseek",
"provider": "deepseek",
"type": "openai_chat_completion",
"type": "deepseek_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
Expand Down
4 changes: 4 additions & 0 deletions astrbot/core/provider/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,10 @@ def dynamic_import_provider(self, type: str) -> None:
from .sources.openrouter_source import (
ProviderOpenRouter as ProviderOpenRouter,
)
case "deepseek_chat_completion":
from .sources.deepseek_source import (
ProviderDeepSeek as ProviderDeepSeek,
)
case "anthropic_chat_completion":
from .sources.anthropic_source import (
ProviderAnthropic as ProviderAnthropic,
Expand Down
265 changes: 265 additions & 0 deletions astrbot/core/provider/sources/deepseek_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
from collections.abc import AsyncGenerator
from typing import Any

from openai.lib.streaming.chat._completions import ChatCompletionStreamState
from openai.types.chat.chat_completion import ChatCompletion

import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.core.agent.tool import ToolSet
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse

from ..register import register_provider_adapter
from .openai_source import ProviderOpenAIOfficial


@register_provider_adapter(
"deepseek_chat_completion",
"DeepSeek Chat Completion 提供商适配器",
)
class ProviderDeepSeek(ProviderOpenAIOfficial):
_FORCE_OMIT_TOOL_CHOICE_KEY = "_deepseek_force_omit_tool_choice"

@staticmethod
def _extract_thinking_type(source: Any) -> str | None:
if not isinstance(source, dict):
return None
thinking = source.get("thinking")
if not isinstance(thinking, dict):
return None
thinking_type = thinking.get("type")
if not isinstance(thinking_type, str):
return None
normalized = thinking_type.strip().lower()
return normalized or None

def _is_thinking_enabled(
self,
payloads: dict,
extra_body: dict[str, Any] | None = None,
) -> bool:
for source in (
payloads,
extra_body,
self.provider_config.get("custom_extra_body", {}),
):
thinking_type = self._extract_thinking_type(source)
if thinking_type == "enabled":
return True
if thinking_type == "disabled":
return False
# DeepSeek documents thinking mode as enabled by default.
return True

def _is_thinking_tool_choice_error(self, error: Exception) -> bool:
for candidate in self._extract_error_text_candidates(error):
lowered = candidate.lower()
if "tool_choice" in lowered and (
"thinking" in lowered or "reasoning" in lowered
):
return True
return False

def _normalize_tool_choice(
self,
payloads: dict,
extra_body: dict[str, Any],
*,
thinking_enabled: bool,
force_omit: bool = False,
) -> None:
if not thinking_enabled and not force_omit:
return

payload_tool_choice = payloads.pop("tool_choice", None)
extra_tool_choice = extra_body.pop("tool_choice", None)
removed_tool_choice = (
payload_tool_choice
if payload_tool_choice is not None
else extra_tool_choice
)
if removed_tool_choice and removed_tool_choice != "auto":
logger.warning(
f"{self.get_model()} 思考模式不支持 tool_choice={removed_tool_choice!r},"
"已改为 DeepSeek 默认工具选择策略。"
)

def _prepare_request(
self,
payloads: dict,
tools: ToolSet | None,
) -> tuple[dict, dict[str, Any], ToolSet | None]:
if tools:
tool_list = tools.get_func_desc_openai_style(
omit_empty_parameter_field=False,
)
if tool_list:
payloads["tools"] = tool_list

extra_body: dict[str, Any] = {}
to_del = []
for key in payloads:
if key not in self.default_params:
extra_body[key] = payloads[key]
to_del.append(key)
for key in to_del:
del payloads[key]

custom_extra_body = self.provider_config.get("custom_extra_body", {})
if isinstance(custom_extra_body, dict):
extra_body.update(custom_extra_body)
self._apply_provider_specific_extra_body_overrides(extra_body)

force_omit = bool(payloads.pop(self._FORCE_OMIT_TOOL_CHOICE_KEY, False))
thinking_enabled = self._is_thinking_enabled(payloads, extra_body)
self._normalize_tool_choice(
payloads,
extra_body,
thinking_enabled=thinking_enabled,
force_omit=force_omit,
)
self._sanitize_assistant_messages(payloads)
return payloads, extra_body, tools

def _finally_convert_payload(self, payloads: dict) -> None:
    """Run the base payload conversion without leaving empty reasoning behind.

    When thinking mode is disabled, the positions of assistant messages that
    carry no ``reasoning_content`` key are recorded before delegating to the
    parent implementation. Afterwards, any of those messages that now holds
    an empty-string ``reasoning_content`` (presumably injected by the parent
    conversion) has the key removed again.

    Args:
        payloads: Request payload; its ``messages`` are adjusted in place.
    """
    bare_positions: set[int] = set()
    if not self._is_thinking_enabled(payloads):
        bare_positions = {
            position
            for position, entry in enumerate(payloads.get("messages", []))
            if isinstance(entry, dict)
            and entry.get("role") == "assistant"
            and "reasoning_content" not in entry
        }

    super()._finally_convert_payload(payloads)

    # Only strip the placeholder from messages that had no reasoning before.
    for position in bare_positions:
        entry = payloads["messages"][position]
        if entry.get("reasoning_content") == "":
            entry.pop("reasoning_content", None)

async def _handle_api_error(
    self,
    e: Exception,
    payloads: dict,
    context_query: list,
    func_tool: ToolSet | None,
    chosen_key: str,
    available_api_keys: list[str],
    retry_cnt: int,
    max_retries: int,
    image_fallback_used: bool = False,
) -> tuple:
    """Recover from a thinking-mode ``tool_choice`` rejection, else delegate.

    Errors that are not a thinking/``tool_choice`` conflict are handed to the
    parent implementation unchanged. For a conflict, ``tool_choice`` is
    removed from *payloads*, the force-omit flag is set on it, and a tuple
    starting with ``False`` is returned together with the otherwise untouched
    request state (the leading ``False`` presumably asks the caller to retry;
    confirm against the base class contract).

    Args:
        e: The exception raised by the API call.
        payloads: Request payload; mutated in place on a conflict.
        context_query: Conversation context, passed through.
        func_tool: Tool set in use, passed through.
        chosen_key: API key used for the failed request.
        available_api_keys: Remaining API keys, passed through.
        retry_cnt: Current retry count, forwarded to the parent handler.
        max_retries: Retry limit, forwarded to the parent handler.
        image_fallback_used: Passed through unchanged.

    Returns:
        A tuple in the shape produced by the parent handler.
    """
    if not self._is_thinking_tool_choice_error(e):
        return await super()._handle_api_error(
            e,
            payloads,
            context_query,
            func_tool,
            chosen_key,
            available_api_keys,
            retry_cnt,
            max_retries,
            image_fallback_used=image_fallback_used,
        )

    logger.warning(
        f"{self.get_model()} 思考模式不支持当前 tool_choice,已移除该参数并重试。"
    )
    payloads.pop("tool_choice", None)
    payloads[self._FORCE_OMIT_TOOL_CHOICE_KEY] = True
    retry_state = (
        False,
        chosen_key,
        available_api_keys,
        payloads,
        context_query,
        func_tool,
        image_fallback_used,
    )
    return retry_state
Comment thread
murphys7017 marked this conversation as resolved.

async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (complexity): Consider extracting shared request-building and chunk-handling helpers and making streaming responses stateless to reduce duplication and local complexity without changing behavior.

You can reduce the new complexity without changing behavior by extracting a few small helpers and avoiding the shared mutable llm_response in streaming.

1. Factor out shared request-building logic

Both _query and _query_stream do the same steps (tools, tool_choice, extra_body extraction, sanitization). Centralizing this will make future changes safer.

def _build_deepseek_request(
    self,
    payloads: dict,
    tools: ToolSet | None,
    *,
    stream: bool,
) -> tuple[dict, dict]:
    # 1) tools
    if tools:
        tool_list = tools.get_func_desc_openai_style(
            omit_empty_parameter_field=False,
        )
        if tool_list:
            payloads["tools"] = tool_list

    # 2) tool_choice normalization
    self._normalize_deepseek_tool_choice(payloads)

    # 3) extra_body construction
    extra_body = {}
    # keep stream/non-stream behavior differences if needed
    custom_extra_body = self.provider_config.get("custom_extra_body", {})
    if isinstance(custom_extra_body, dict):
        extra_body.update(custom_extra_body)

    # move non-default params into extra_body
    base_payloads, payload_extra_body = self._extract_extra_body(payloads)
    extra_body.update(payload_extra_body)

    # provider-specific overrides
    self._apply_provider_specific_extra_body_overrides(extra_body)

    # sanitize assistant messages on payloads we actually send
    self._sanitize_assistant_messages(base_payloads)

    return base_payloads, extra_body

Usage in _query and _query_stream stays very local:

async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
    payloads, extra_body = self._build_deepseek_request(
        payloads,
        tools,
        stream=False,
    )
    completion = await self.client.chat.completions.create(
        **payloads,
        stream=False,
        extra_body=extra_body,
    )
    ...
async def _query_stream(
    self,
    payloads: dict,
    tools: ToolSet | None,
) -> AsyncGenerator[LLMResponse, None]:
    payloads, extra_body = self._build_deepseek_request(
        payloads,
        tools,
        stream=True,
    )
    stream = await self.client.chat.completions.create(
        **payloads,
        stream=True,
        extra_body=extra_body,
        stream_options={"include_usage": True},
    )
    ...

2. Simplify extra_body extraction

The to_del pattern can be replaced with a small helper that does the pruning once and clearly:

def _extract_extra_body(self, payloads: dict) -> tuple[dict, dict]:
    base_payloads = {}
    extra_body = {}
    for key, value in payloads.items():
        if key in self.default_params:
            base_payloads[key] = value
        else:
            extra_body[key] = value
    return base_payloads, extra_body

This keeps your behavior (non-default params moved to extra_body) but removes the two-step delete loop.

3. Extract tool-call normalization

Move the DeepSeek-specific normalization out of the streaming loop to reduce noise and make intent clear:

def _normalize_tool_calls(self, tool_calls) -> None:
    if not tool_calls:
        return
    for idx, tc in enumerate(tool_calls):
        if getattr(tc, "function", None) and getattr(tc.function, "arguments", None):
            tc.type = "function"
        if not hasattr(tc, "index") or tc.index is None:
            tc.index = idx

Then in _query_stream:

async for chunk in stream:
    choice = chunk.choices[0] if chunk.choices else None
    delta = choice.delta if choice else None

    if delta and delta.tool_calls:
        self._normalize_tool_calls(delta.tool_calls)
    ...

4. Make per-chunk streaming responses stateless

Instead of mutating one LLMResponse instance for every chunk, construct a fresh one per chunk and put the merging logic in a helper. This keeps the stream state (ChatCompletionStreamState) separate from chunk delivery.

async def _chunk_to_llm_response(
    self,
    chunk,
    choice,
    delta,
    tools: ToolSet | None,
) -> LLMResponse | None:
    reasoning = self._extract_reasoning_content(chunk)
    has_delta = bool(reasoning or (delta and delta.content))

    if not has_delta and not chunk.usage and not getattr(choice, "usage", None):
        return None

    resp = LLMResponse("assistant", is_chunk=True)
    resp.id = chunk.id

    if reasoning is not None:
        resp.reasoning_content = reasoning

    if delta and delta.content:
        completion_text = self._normalize_content(delta.content, strip=False)
        resp.result_chain = MessageChain(
            chain=[Comp.Plain(completion_text)],
        )

    if chunk.usage:
        resp.usage = self._extract_usage(chunk.usage)
    elif choice and (choice_usage := getattr(choice, "usage", None)):
        resp.usage = self._extract_usage(choice_usage)

    return resp

And _query_stream becomes easier to follow:

async def _query_stream(...):
    ...
    state = ChatCompletionStreamState()

    async for chunk in stream:
        choice = chunk.choices[0] if chunk.choices else None
        delta = choice.delta if choice else None

        if delta and delta.tool_calls:
            self._normalize_tool_calls(delta.tool_calls)

        if delta is not None or chunk.usage:
            try:
                state.handle_chunk(chunk)
            except Exception as e:
                logger.error("Saving chunk state error: " + str(e))

        resp = await self._chunk_to_llm_response(chunk, choice, delta, tools)
        if resp:
            # if you still need to keep usage snapshot in state:
            if choice and getattr(choice, "usage", None):
                state.current_completion_snapshot.usage = choice.usage
            yield resp

    try:
        final_completion = state.get_final_completion()
        final_response = await self._parse_openai_completion(final_completion, tools)
        yield final_response
    except Exception as e:
        logger.error("get_final_completion error: " + str(e))
        return

This keeps all existing functionality (reasoning extraction, usage handling, ChatCompletionStreamState reconstruction) but removes the shared mutable llm_response and localizes the per-chunk logic.

# Body of the non-streaming _query method.
# Partition the payload into SDK keyword arguments and extra_body; this also
# applies the DeepSeek tool_choice handling done by _prepare_request.
payloads, extra_body, tools = self._prepare_request(payloads, tools)

# Single non-streaming chat-completion request.
completion = await self.client.chat.completions.create(
**payloads,
stream=False,
extra_body=extra_body,
)

# Reject anything that is not a ChatCompletion before parsing it.
if not isinstance(completion, ChatCompletion):
raise Exception(
f"API 返回的 completion 类型错误:{type(completion)}: {completion}。",
)

logger.debug(f"completion: {completion}")

# Convert the raw completion into an LLMResponse.
return await self._parse_openai_completion(completion, tools)

# Stream a chat completion: yields an LLMResponse for each chunk that carries
# reasoning or text content, then one final LLMResponse parsed from the
# completion accumulated over the whole stream.
# NOTE(review): the signature accepts only (payloads, tools); confirm that the
# base-class callers do not pass extra keywords such as request_max_retries.
async def _query_stream(
self,
payloads: dict,
tools: ToolSet | None,
) -> AsyncGenerator[LLMResponse, None]:
payloads, extra_body, tools = self._prepare_request(payloads, tools)

# stream_options asks the API to include usage data in the stream.
stream = await self.client.chat.completions.create(
**payloads,
stream=True,
extra_body=extra_body,
stream_options={"include_usage": True},
)

# One response object is reused for every chunk; state accumulates the chunks
# so the full completion can be rebuilt after the stream ends.
llm_response = LLMResponse("assistant", is_chunk=True)
state = ChatCompletionStreamState()

async for chunk in stream:
choice = chunk.choices[0] if chunk.choices else None
delta = choice.delta if choice else None

# Fill in the type and index fields on tool-call deltas before the chunk is
# accumulated (presumably because the stream state needs them; confirm).
if delta and (dtcs := delta.tool_calls):
for idx, tc in enumerate(dtcs):
if tc.function and tc.function.arguments:
tc.type = "function"
if not hasattr(tc, "index") or tc.index is None:
tc.index = idx

# Accumulation failures are logged and do not interrupt the stream.
if delta is not None or chunk.usage:
try:
state.handle_chunk(chunk)
except Exception as e:
logger.error("Saving chunk state error: " + str(e))

# Reset the per-chunk fields of the shared response before filling them.
reasoning = self._extract_reasoning_content(chunk)
has_delta = False
llm_response.id = chunk.id
llm_response.reasoning_content = None
llm_response.completion_text = ""
if reasoning is not None:
llm_response.reasoning_content = reasoning
has_delta = True
if delta and delta.content:
completion_text = self._normalize_content(delta.content, strip=False)
llm_response.result_chain = MessageChain(
chain=[Comp.Plain(completion_text)],
)
has_delta = True
# Usage is read from the chunk itself, or else from a usage attribute on the
# choice when one is present.
if chunk.usage:
llm_response.usage = self._extract_usage(chunk.usage)
elif choice and (choice_usage := getattr(choice, "usage", None)):
llm_response.usage = self._extract_usage(choice_usage)
state.current_completion_snapshot.usage = choice_usage
# Only chunks that carried reasoning or text are yielded.
if has_delta:
yield llm_response

# Build the final response from the accumulated state; a failure here is
# logged and the generator ends without yielding a final item.
try:
final_completion = state.get_final_completion()
llm_response = await self._parse_openai_completion(final_completion, tools)
yield llm_response
except Exception as e:
logger.error("get_final_completion error: " + str(e))
return
Comment on lines +185 to +265

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

🔴 严重缺陷:缺少 request_max_retries 参数导致 TypeError 崩溃 & 丢失重试机制

当前重写的 `_query` 和 `_query_stream` 方法存在以下严重问题:

  1. 运行时崩溃 (TypeError):基类 ProviderOpenAIOfficial 中的 `text_chat` 和 `text_chat_stream` 在调用 `_query` / `_query_stream` 时会传入 `request_max_retries` 关键字参数。由于当前重写的方法签名中缺少该参数,运行时会直接抛出 TypeError: _query() got an unexpected keyword argument 'request_max_retries',导致请求崩溃。
  2. 丢失重试机制:当前实现中直接调用了 self.client.chat.completions.create,而没有使用 retry_provider_request 包装,导致 DeepSeek 适配器完全失去了配置中指定的请求重试能力。
  3. 代码冗余:复制了基类中大量的流式处理和错误处理逻辑,不利于后续维护。

💡 优雅的解决方案

我们可以通过继承基类的 `_query` 和 `_query_stream`,并使用一个自定义的 dict 子类(DeepSeekPayloadDict)来动态拦截并阻止 tool_choice 的写入。为了避免代码重复,应将 DeepSeekPayloadDict 提取为共享的辅助类,而不是在每个方法中重复定义。

class DeepSeekPayloadDict(dict):
    def __setitem__(self, key, value):
        if key == "tool_choice":
            return
        super().__setitem__(key, value)

    async def _query(
        self,
        payloads: dict,
        tools: ToolSet | None,
        *,
        request_max_retries: int | None = None,
    ) -> LLMResponse:
        if self._deepseek_omits_tool_choice(payloads):
            self._normalize_deepseek_tool_choice(payloads)
            payloads = DeepSeekPayloadDict(payloads)
        return await super()._query(payloads, tools, request_max_retries=request_max_retries)

    async def _query_stream(
        self,
        payloads: dict,
        tools: ToolSet | None,
        *,
        request_max_retries: int | None = None,
    ) -> AsyncGenerator[LLMResponse, None]:
        if self._deepseek_omits_tool_choice(payloads):
            self._normalize_deepseek_tool_choice(payloads)
            payloads = DeepSeekPayloadDict(payloads)
        async for response in super()._query_stream(payloads, tools, request_max_retries=request_max_retries):
            yield response
References
  1. When implementing similar functionality for different cases, refactor the logic into a shared helper function to avoid code duplication.

106 changes: 106 additions & 0 deletions tests/test_deepseek_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import asyncio
from types import SimpleNamespace

from astrbot.core.provider.sources.deepseek_source import ProviderDeepSeek


def _make_provider(overrides: dict | None = None) -> ProviderDeepSeek:
    """Build a ``ProviderDeepSeek`` from a minimal test configuration.

    Args:
        overrides: Optional entries merged over the default provider config.

    Returns:
        A provider created with the merged config and empty settings.
    """
    config = {
        "id": "test-deepseek",
        "type": "deepseek_chat_completion",
        "model": "deepseek-v4-flash",
        "key": ["test-key"],
        "custom_extra_body": {},
    }
    if overrides:
        config |= overrides
    return ProviderDeepSeek(provider_config=config, provider_settings={})


def test_deepseek_thinking_mode_removes_tool_choice_from_payload_and_extra_body():
    """With thinking enabled, ``tool_choice`` is dropped from both outputs."""
    custom_extra_body = {
        "thinking": {"type": "enabled"},
        "tool_choice": "required",
    }
    provider = _make_provider({"custom_extra_body": custom_extra_body})
    request = {
        "model": "deepseek-v4-flash",
        "messages": [{"role": "user", "content": "hello"}],
        "tool_choice": "required",
    }
    try:
        sent_payloads, sent_extra_body, _ = provider._prepare_request(request, None)

        assert "tool_choice" not in sent_payloads
        assert "tool_choice" not in sent_extra_body
        assert sent_extra_body["thinking"]["type"] == "enabled"
    finally:
        # Always shut the provider down, even when an assertion fails.
        asyncio.run(provider.terminate())


def test_deepseek_non_thinking_mode_keeps_tool_choice():
    """With thinking disabled, the caller's ``tool_choice`` is preserved."""
    custom_extra_body = {"thinking": {"type": "disabled"}}
    provider = _make_provider({"custom_extra_body": custom_extra_body})
    request = {
        "model": "deepseek-v4-flash",
        "messages": [{"role": "user", "content": "hello"}],
        "tool_choice": "required",
    }
    try:
        sent_payloads, sent_extra_body, _ = provider._prepare_request(request, None)

        assert sent_payloads["tool_choice"] == "required"
        assert sent_extra_body["thinking"]["type"] == "disabled"
    finally:
        # Always shut the provider down, even when an assertion fails.
        asyncio.run(provider.terminate())


def test_deepseek_non_thinking_payload_does_not_inject_empty_reasoning_content():
    """With thinking disabled, history messages gain no ``reasoning_content``."""
    # Bypass __init__ and set only the attributes this test provides.
    provider = ProviderDeepSeek.__new__(ProviderDeepSeek)
    provider.client = SimpleNamespace(base_url=SimpleNamespace(host="api.deepseek.com"))
    provider.provider_config = {
        "custom_extra_body": {"thinking": {"type": "disabled"}},
    }

    history_entry = {"role": "assistant", "content": "previous reply"}
    payloads = {"model": "deepseek-v4-flash", "messages": [history_entry]}

    provider._finally_convert_payload(payloads)

    assert "reasoning_content" not in payloads["messages"][0]


def test_deepseek_thinking_payload_keeps_empty_reasoning_content_for_history():
    """With thinking enabled, the empty ``reasoning_content`` is kept."""
    # Bypass __init__ and set only the attributes this test provides.
    provider = ProviderDeepSeek.__new__(ProviderDeepSeek)
    provider.client = SimpleNamespace(base_url=SimpleNamespace(host="api.deepseek.com"))
    provider.provider_config = {
        "custom_extra_body": {"thinking": {"type": "enabled"}},
    }

    history_entry = {"role": "assistant", "content": "previous reply"}
    payloads = {"model": "deepseek-v4-flash", "messages": [history_entry]}

    provider._finally_convert_payload(payloads)

    assert payloads["messages"][0]["reasoning_content"] == ""