Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 28 additions & 10 deletions src/art/langgraph/llm_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,31 +108,43 @@ async def wrapper(*args, **kwargs):


def init_chat_model(
    model: str | None = None,
    *,
    model_provider: str | None = None,
    configurable_fields: Literal[None] = None,
    config_prefix: str | None = None,
    **kwargs: Any,
):
    """Build a LoggingLLM wrapping a ChatOpenAI client from CURRENT_CONFIG.

    The signature mirrors langchain's ``init_chat_model`` so callers can swap
    this in transparently. ``model_provider``, ``configurable_fields`` and
    ``config_prefix`` are accepted for interface compatibility only — they are
    not used here (presumably always the configured OpenAI-compatible
    endpoint; confirm against callers).

    Args:
        model: Model name; falls back to ``config["model"]`` when None.
        **kwargs: Extra keyword arguments forwarded to ``ChatOpenAI``.
            ``timeout`` (default 10 minutes) is forwarded to the client AND,
            when numeric, used as the wrapper's asyncio-level timeout.

    Returns:
        A ``LoggingLLM`` wrapping the configured ``ChatOpenAI`` instance.
    """
    config = CURRENT_CONFIG.get()
    timeout = kwargs.pop("timeout", 10 * 60)
    chat_model_kwargs: dict[str, Any] = {
        "base_url": config["base_url"],
        "api_key": config["api_key"],
        "model": model or config["model"],
        "temperature": 1.0,
        # Forward the timeout to the provider layer instead of silently
        # dropping it after the pop above — callers using
        # init_chat_model(..., timeout=...) must configure the client too.
        "timeout": timeout,
    }
    # Caller-supplied kwargs win over the config-derived defaults.
    chat_model_kwargs.update(kwargs)
    # asyncio.wait_for only accepts a numeric number of seconds; richer
    # timeout objects (e.g. httpx.Timeout) apply at the client layer only,
    # and the wrapper keeps its 10-minute default.
    wrapper_timeout = timeout if isinstance(timeout, (int, float)) else 10 * 60
    return LoggingLLM(
        ChatOpenAI(**chat_model_kwargs),
        config["logger"],
        timeout=wrapper_timeout,
    )


class LoggingLLM(Runnable):
    """Runnable wrapper around a chat model that logs completions and
    bounds each async invocation with a timeout."""

    def __init__(
        self,
        llm,
        logger,
        structured_output=None,
        tools=None,
        timeout: float = 10 * 60,
    ):
        # Wrapped chat model and the sink used by _log().
        self.llm = llm
        self.logger = logger
        # Optional structured-output schema carried for logging purposes.
        self.structured_output = structured_output
        # Normalize any provided tools to the OpenAI tool schema once, up front.
        self.tools = list(map(convert_to_openai_tool, tools)) if tools else None
        # Seconds allowed per ainvoke call before asyncio.wait_for cancels it
        # (defaults to 10 minutes).
        self.timeout = timeout

def _log(self, completion_id, input, output):
if self.logger:
Expand Down Expand Up @@ -167,7 +179,7 @@ async def ainvoke(self, input, config=None, **kwargs):
async def execute():
try:
result = await asyncio.wait_for(
self.llm.ainvoke(input, config=config), timeout=10 * 60
self.llm.ainvoke(input, config=config), timeout=self.timeout
)
self._log(completion_id, input, result)
except asyncio.TimeoutError as e:
Expand All @@ -194,10 +206,16 @@ def with_structured_output(self, tools):
self.logger,
structured_output=tools,
tools=[tools],
timeout=self.timeout,
)

def bind_tools(self, tools):
    """Return a new LoggingLLM whose underlying model has *tools* bound.

    The logger and the configured timeout are carried over to the new
    wrapper so logging and cancellation behavior stay consistent.
    """
    bound_llm = self.llm.bind_tools(tools)
    return LoggingLLM(bound_llm, self.logger, tools=tools, timeout=self.timeout)

def with_retry(
self,
Expand Down
Loading