diff --git a/src/art/langgraph/llm_wrapper.py b/src/art/langgraph/llm_wrapper.py index 36b5314b3..8c5d819ab 100644 --- a/src/art/langgraph/llm_wrapper.py +++ b/src/art/langgraph/llm_wrapper.py @@ -108,7 +108,7 @@ async def wrapper(*args, **kwargs): def init_chat_model( - model: Literal[None] = None, + model: str | None = None, *, model_provider: str | None = None, configurable_fields: Literal[None] = None, @@ -116,23 +116,35 @@ def init_chat_model( **kwargs: Any, ): config = CURRENT_CONFIG.get() + timeout = kwargs.pop("timeout", 10 * 60) + chat_model_kwargs: dict[str, Any] = { + "base_url": config["base_url"], + "api_key": config["api_key"], + "model": model or config["model"], + "temperature": 1.0, + } + chat_model_kwargs.update(kwargs) return LoggingLLM( - ChatOpenAI( - base_url=config["base_url"], # ty:ignore[unknown-argument] - api_key=config["api_key"], # ty:ignore[unknown-argument] - model=config["model"], # ty:ignore[unknown-argument] - temperature=1.0, - ), + ChatOpenAI(**chat_model_kwargs), config["logger"], + timeout=timeout, ) class LoggingLLM(Runnable): - def __init__(self, llm, logger, structured_output=None, tools=None): + def __init__( + self, + llm, + logger, + structured_output=None, + tools=None, + timeout: float = 10 * 60, + ): self.llm = llm self.logger = logger self.structured_output = structured_output self.tools = [convert_to_openai_tool(t) for t in tools] if tools else None + self.timeout = timeout def _log(self, completion_id, input, output): if self.logger: @@ -167,7 +179,7 @@ async def ainvoke(self, input, config=None, **kwargs): async def execute(): try: result = await asyncio.wait_for( - self.llm.ainvoke(input, config=config), timeout=10 * 60 + self.llm.ainvoke(input, config=config), timeout=self.timeout ) self._log(completion_id, input, result) except asyncio.TimeoutError as e: @@ -194,10 +206,16 @@ def with_structured_output(self, tools): self.logger, structured_output=tools, tools=[tools], + timeout=self.timeout, ) def bind_tools(self, tools): - return LoggingLLM(self.llm.bind_tools(tools), self.logger, tools=tools) + return LoggingLLM( + self.llm.bind_tools(tools), + self.logger, + tools=tools, + timeout=self.timeout, + ) def with_retry( self,