1010import re
1111from abc import ABC , abstractmethod
1212from collections .abc import AsyncGenerator , Awaitable , Callable
13- from contextlib import asynccontextmanager
13+ from contextlib import asynccontextmanager , suppress
1414from dataclasses import dataclass
1515from http import HTTPStatus
1616from typing import Any
@@ -171,6 +171,7 @@ def __init__(
171171 ] = {}
172172 self ._sse_stream_writers : dict [RequestId , MemoryObjectSendStream [dict [str , str ]]] = {}
173173 self ._terminated = False
174+ self ._session_run_error : BaseException | None = None
174175 # Idle timeout cancel scope; managed by the session manager.
175176 self .idle_scope : anyio .CancelScope | None = None
176177
@@ -179,6 +180,16 @@ def is_terminated(self) -> bool:
179180 """Check if this transport has been explicitly terminated."""
180181 return self ._terminated
181182
183+ def note_session_run_error (self , exc : BaseException ) -> None :
184+ self ._session_run_error = exc
185+
186+ def _post_error_message (self , err : BaseException ) -> str :
187+ display_error = self ._session_run_error or err
188+ display_error_text = str (display_error ) or type (display_error ).__name__
189+ if display_error is not err and str (display_error ):
190+ display_error_text = f"{ type (display_error ).__name__ } : { display_error_text } "
191+ return f"Error handling POST request: { display_error_text } "
192+
182193 def close_sse_stream (self , request_id : RequestId ) -> None :
183194 """Close SSE connection for a specific request without terminating the stream.
184195
@@ -363,6 +374,10 @@ async def _clean_up_memory_streams(self, request_id: RequestId) -> None:
363374 # Remove the request stream from the mapping
364375 self ._request_streams .pop (request_id , None )
365376
377+ async def _clean_up_post_request_stream (self , request_id : RequestId | None ) -> None :
378+ if request_id is not None :
379+ await self ._clean_up_memory_streams (request_id )
380+
366381 async def handle_request (self , scope : Scope , receive : Receive , send : Send ) -> None :
367382 """Application entry point that handles all HTTP requests."""
368383 request = Request (scope , receive )
@@ -443,6 +458,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
443458 writer = self ._read_stream_writer
444459 if writer is None : # pragma: no cover
445460 raise ValueError ("No read stream writer available. Ensure connect() is called first." )
461+ request_id : RequestId | None = None
446462 try :
447463 # Validate Accept header
448464 if not await self ._validate_accept_header (request , scope , send ):
@@ -637,15 +653,16 @@ async def sse_writer():
637653
638654 except Exception as err :
639655 logger .exception ("Error handling POST request" )
656+ await self ._clean_up_post_request_stream (request_id )
640657 response = self ._create_error_response (
641- f"Error handling POST request: { err } " ,
658+ self . _post_error_message ( err ) ,
642659 HTTPStatus .INTERNAL_SERVER_ERROR ,
643660 INTERNAL_ERROR ,
644661 )
645662 await response (scope , receive , send )
646- if writer : # pragma: no cover
663+ with suppress ( anyio . BrokenResourceError , anyio . ClosedResourceError ):
647664 await writer .send (Exception (err ))
648- return # pragma: no cover
665+ return
649666
650667 async def _handle_get_request (self , request : Request , send : Send ) -> None :
651668 """Handle GET request to establish SSE.
0 commit comments