From 18fae4a00bddcd57506286483b48d4c3584c78ca Mon Sep 17 00:00:00 2001 From: Voronin Sergei Date: Sun, 31 May 2026 17:37:03 +1200 Subject: [PATCH 1/6] Add assistant middleware inheritance option --- slack_bolt/app/app.py | 19 +- slack_bolt/app/async_app.py | 19 +- slack_bolt/middleware/assistant/assistant.py | 50 +++- .../middleware/assistant/async_assistant.py | 50 +++- .../app/test_app_assistant_middleware.py | 249 ++++++++++++++++++ 5 files changed, 367 insertions(+), 20 deletions(-) create mode 100644 tests/slack_bolt/app/test_app_assistant_middleware.py diff --git a/slack_bolt/app/app.py b/slack_bolt/app/app.py index 0af27913c..22735b90d 100644 --- a/slack_bolt/app/app.py +++ b/slack_bolt/app/app.py @@ -683,19 +683,26 @@ def middleware_func(logger, body, next): self._middleware_list.append(middleware) if isinstance(middleware, Assistant) and middleware.thread_context_store is not None: self._assistant_thread_context_store = middleware.thread_context_store + elif not isinstance(middleware, Assistant): + self._inherit_app_middleware_for_assistants(middleware) elif callable(middleware_or_callable): - self._middleware_list.append( - CustomMiddleware( - app_name=self.name, - func=middleware_or_callable, - base_logger=self._base_logger, - ) + middleware = CustomMiddleware( + app_name=self.name, + func=middleware_or_callable, + base_logger=self._base_logger, ) + self._middleware_list.append(middleware) + self._inherit_app_middleware_for_assistants(middleware) return middleware_or_callable else: raise BoltError(f"Unexpected type for a middleware ({type(middleware_or_callable)})") return None + def _inherit_app_middleware_for_assistants(self, middleware: Middleware) -> None: + for registered_middleware in self._middleware_list[:-1]: + if isinstance(registered_middleware, Assistant): + registered_middleware.inherit_app_middleware(middleware) + # ------------------------- # AI Agents & Assistants diff --git a/slack_bolt/app/async_app.py b/slack_bolt/app/async_app.py index cc94f9e15..0f57fc150 100644 --- a/slack_bolt/app/async_app.py +++ b/slack_bolt/app/async_app.py @@ -710,19 +710,26 @@ async def middleware_func(logger, body, next): self._async_middleware_list.append(middleware) if isinstance(middleware, AsyncAssistant) and middleware.thread_context_store is not None: self._assistant_thread_context_store = middleware.thread_context_store + elif not isinstance(middleware, AsyncAssistant): + self._inherit_app_middleware_for_assistants(middleware) elif callable(middleware_or_callable): - self._async_middleware_list.append( - AsyncCustomMiddleware( - app_name=self.name, - func=middleware_or_callable, - base_logger=self._base_logger, - ) + middleware = AsyncCustomMiddleware( + app_name=self.name, + func=middleware_or_callable, + base_logger=self._base_logger, ) + self._async_middleware_list.append(middleware) + self._inherit_app_middleware_for_assistants(middleware) return middleware_or_callable else: raise BoltError(f"Unexpected type for a middleware ({type(middleware_or_callable)})") return None + def _inherit_app_middleware_for_assistants(self, middleware: AsyncMiddleware) -> None: + for registered_middleware in self._async_middleware_list[:-1]: + if isinstance(registered_middleware, AsyncAssistant): + registered_middleware.inherit_app_middleware(middleware) + def assistant(self, assistant: AsyncAssistant) -> Optional[Callable]: return self.middleware(assistant) diff --git a/slack_bolt/middleware/assistant/assistant.py b/slack_bolt/middleware/assistant/assistant.py index ad842f94d..837dfedcd 100644 --- a/slack_bolt/middleware/assistant/assistant.py +++ b/slack_bolt/middleware/assistant/assistant.py @@ -1,7 +1,7 @@ import logging from functools import wraps from logging import Logger -from typing import List, Optional, Union, Callable +from typing import List, Optional, Union, Callable, Tuple from slack_bolt.context.save_thread_context import SaveThreadContext from slack_bolt.context.assistant.thread_context_store.store import AssistantThreadContextStore @@ -42,16 +42,25 @@ def __init__( app_name: str = "assistant", thread_context_store: Optional[AssistantThreadContextStore] = None, logger: Optional[logging.Logger] = None, + auto_inherit_app_middleware: bool = False, ): self.app_name = app_name self.thread_context_store = thread_context_store self.base_logger = logger + self.auto_inherit_app_middleware = auto_inherit_app_middleware + self._inherited_app_middleware: List[Middleware] = [] self._thread_started_listeners = None self._thread_context_changed_listeners = None self._user_message_listeners = None self._bot_message_listeners = None + def inherit_app_middleware(self, middleware: Middleware) -> None: + if self.auto_inherit_app_middleware is False: + return + + self._inherited_app_middleware.append(middleware) + def thread_started( self, *args, @@ -237,7 +246,11 @@ def process( # type: ignore[return] if listeners is not None: for listener in listeners: if listener.matches(req=req, resp=resp): - middleware_resp, next_was_not_called = listener.run_middleware(req=req, resp=resp) + middleware_resp, next_was_not_called = self._run_middleware( + listener=listener, + req=req, + resp=resp, + ) if next_was_not_called: if middleware_resp is not None: return middleware_resp @@ -258,6 +271,33 @@ def process( # type: ignore[return] next() + def _run_middleware( + self, + *, + listener: Listener, + req: BoltRequest, + resp: BoltResponse, + ) -> Tuple[Optional[BoltResponse], bool]: + middleware = list(listener.middleware) + if len(self._inherited_app_middleware) > 0: + insertion_index = 1 if len(middleware) > 0 and isinstance(middleware[0], AttachingConversationKwargs) else 0 + middleware = [ + *middleware[:insertion_index], + *self._inherited_app_middleware, + *middleware[insertion_index:], + ] + + for m in middleware: + middleware_state = {"next_called": False} + + def next_(): + middleware_state["next_called"] = True + + resp = m.process(req=req, resp=resp, next=next_) # type: ignore[assignment] + if not middleware_state["next_called"]: + return resp, True + return resp, False + def build_listener( self, listener_or_functions: Union[Listener, Callable, List[Callable]], @@ -271,8 +311,10 @@ def build_listener( if isinstance(listener_or_functions, Listener): return listener_or_functions elif isinstance(listener_or_functions, list): - middleware = middleware if middleware else [] - middleware.insert(0, AttachingConversationKwargs(self.thread_context_store)) + middleware = [ + AttachingConversationKwargs(self.thread_context_store), + *(middleware if middleware else []), + ] functions = listener_or_functions ack_function = functions.pop(0) diff --git a/slack_bolt/middleware/assistant/async_assistant.py b/slack_bolt/middleware/assistant/async_assistant.py index 588de8b41..14a8c0c75 100644 --- a/slack_bolt/middleware/assistant/async_assistant.py +++ b/slack_bolt/middleware/assistant/async_assistant.py @@ -1,7 +1,7 @@ import logging from functools import wraps from logging import Logger -from typing import List, Optional, Union, Callable, Awaitable +from typing import List, Optional, Union, Callable, Awaitable, Tuple from slack_bolt.context.save_thread_context.async_save_thread_context import AsyncSaveThreadContext from slack_bolt.context.assistant.thread_context_store.async_store import AsyncAssistantThreadContextStore @@ -42,16 +42,25 @@ def __init__( app_name: str = "assistant", thread_context_store: Optional[AsyncAssistantThreadContextStore] = None, logger: Optional[logging.Logger] = None, + auto_inherit_app_middleware: bool = False, ): self.app_name = app_name self.thread_context_store = thread_context_store self.base_logger = logger + self.auto_inherit_app_middleware = auto_inherit_app_middleware + self._inherited_app_middleware: List[AsyncMiddleware] = [] self._thread_started_listeners = None self._thread_context_changed_listeners = None self._user_message_listeners = None self._bot_message_listeners = None + def inherit_app_middleware(self, middleware: AsyncMiddleware) -> None: + if self.auto_inherit_app_middleware is False: + return + + self._inherited_app_middleware.append(middleware) + def thread_started( self, *args, @@ -268,7 +277,11 @@ async def async_process( # type: ignore[return] if listeners is not None: for listener in listeners: if listener is not None and await listener.async_matches(req=req, resp=resp): - middleware_resp, next_was_not_called = await listener.run_async_middleware(req=req, resp=resp) + middleware_resp, next_was_not_called = await self._run_middleware( + listener=listener, + req=req, + resp=resp, + ) if next_was_not_called: if middleware_resp is not None: return middleware_resp @@ -289,6 +302,33 @@ async def async_process( # type: ignore[return] await next() + async def _run_middleware( + self, + *, + listener: AsyncListener, + req: AsyncBoltRequest, + resp: BoltResponse, + ) -> Tuple[Optional[BoltResponse], bool]: + middleware = list(listener.middleware) + if len(self._inherited_app_middleware) > 0: + insertion_index = 1 if len(middleware) > 0 and isinstance(middleware[0], AsyncAttachingConversationKwargs) else 0 + middleware = [ + *middleware[:insertion_index], + *self._inherited_app_middleware, + *middleware[insertion_index:], + ] + + for m in middleware: + middleware_state = {"next_called": False} + + async def next_(): + middleware_state["next_called"] = True + + resp = await m.async_process(req=req, resp=resp, next=next_) # type: ignore[assignment] + if not middleware_state["next_called"]: + return resp, True + return resp, False + def build_listener( self, listener_or_functions: Union[AsyncListener, Callable, List[Callable]], @@ -302,8 +342,10 @@ def build_listener( if isinstance(listener_or_functions, AsyncListener): return listener_or_functions elif isinstance(listener_or_functions, list): - middleware = middleware if middleware else [] - middleware.insert(0, AsyncAttachingConversationKwargs(self.thread_context_store)) + middleware = [ + AsyncAttachingConversationKwargs(self.thread_context_store), + *(middleware if middleware else []), + ] functions = listener_or_functions ack_function = functions.pop(0) diff --git a/tests/slack_bolt/app/test_app_assistant_middleware.py b/tests/slack_bolt/app/test_app_assistant_middleware.py new file mode 100644 index 000000000..9bcfd1980 --- /dev/null +++ b/tests/slack_bolt/app/test_app_assistant_middleware.py @@ -0,0 +1,249 @@ +from typing import Awaitable, Callable, Optional + +import pytest +from slack_sdk import WebClient +from slack_sdk.web.async_client import AsyncWebClient + +from slack_bolt import App, Assistant, BoltRequest +from slack_bolt.async_app import AsyncApp, AsyncAssistant, AsyncBoltRequest +from slack_bolt.authorization import AuthorizeResult +from slack_bolt.middleware import Middleware +from slack_bolt.middleware.async_middleware import AsyncMiddleware +from slack_bolt.request import BoltRequest as BoltRequestType +from slack_bolt.response import BoltResponse +from tests.scenario_tests.test_events_assistant import thread_started_event_body, user_message_event_body + + +def authorize_test_app(context, enterprise_id, team_id, user_id): + return AuthorizeResult( + enterprise_id=enterprise_id, + team_id=team_id, + user_id=user_id, + bot_user_id="W111", + bot_id="B111", + bot_token="xoxb-valid", + ) + + +async def async_authorize_test_app(context, enterprise_id, team_id, user_id): + return authorize_test_app(context, enterprise_id, team_id, user_id) + + +class TestAppAssistantMiddleware: + def test_assistant_inherits_app_middleware_registered_after_assistant(self): + app = App(client=WebClient(token=None), authorize=authorize_test_app, process_before_response=True) + assistant = Assistant(auto_inherit_app_middleware=True) + calls = [] + + class ListenerMiddleware(Middleware): + def process(self, *, req: BoltRequestType, resp: BoltResponse, next: Callable[[], BoltResponse]): + calls.append("listener") + return next() + + @assistant.user_message(middleware=[ListenerMiddleware()]) + def handle_user_message(): + calls.append("handler") + + app.assistant(assistant) + + @app.middleware + def app_middleware(req, next): + calls.append("app") + assert req.context.get("set_status") is not None + assert req.context.get("set_title") is not None + return next() + + request = BoltRequest(body=user_message_event_body, mode="socket_mode") + response = app.dispatch(request) + assert response.status == 200 + assert calls == ["app", "listener", "handler"] + + def test_assistant_does_not_inherit_app_middleware_by_default(self): + app = App(client=WebClient(token=None), authorize=authorize_test_app, process_before_response=True) + assistant = Assistant() + calls = [] + + class AppMiddleware(Middleware): + def process(self, *, req: BoltRequestType, resp: BoltResponse, next: Callable[[], BoltResponse]): + calls.append("app") + return next() + + @assistant.user_message + def handle_user_message(): + calls.append("handler") + + app.assistant(assistant) + app.middleware(AppMiddleware()) + + request = BoltRequest(body=user_message_event_body, mode="socket_mode") + response = app.dispatch(request) + assert response.status == 200 + assert calls == ["handler"] + + def test_assistant_inherits_app_middleware_for_listeners_registered_later(self): + app = App(client=WebClient(token=None), authorize=authorize_test_app, process_before_response=True) + assistant = Assistant(auto_inherit_app_middleware=True) + calls = [] + + class AppMiddleware(Middleware): + def process(self, *, req: BoltRequestType, resp: BoltResponse, next: Callable[[], BoltResponse]): + calls.append("app") + return next() + + app.assistant(assistant) + app.middleware(AppMiddleware()) + + @assistant.user_message + def handle_user_message(): + calls.append("handler") + + request = BoltRequest(body=user_message_event_body, mode="socket_mode") + response = app.dispatch(request) + assert response.status == 200 + assert calls == ["app", "handler"] + + def test_assistant_inherited_app_middleware_can_short_circuit(self): + app = App(client=WebClient(token=None), authorize=authorize_test_app, process_before_response=True) + assistant = Assistant(auto_inherit_app_middleware=True) + calls = [] + + class BlockingMiddleware(Middleware): + def process(self, *, req: BoltRequestType, resp: BoltResponse, next: Callable[[], BoltResponse]): + calls.append("app") + return BoltResponse(status=201) + + @assistant.thread_started + def start_thread(): + calls.append("handler") + + app.assistant(assistant) + app.middleware(BlockingMiddleware()) + + request = BoltRequest(body=thread_started_event_body, mode="socket_mode") + response = app.dispatch(request) + assert response.status == 201 + assert calls == ["app"] + + +class TestAsyncAppAssistantMiddleware: + @pytest.mark.asyncio + async def test_assistant_inherits_app_middleware_registered_after_assistant(self): + app = AsyncApp(client=AsyncWebClient(token=None), authorize=async_authorize_test_app, process_before_response=True) + assistant = AsyncAssistant(auto_inherit_app_middleware=True) + calls = [] + + class ListenerMiddleware(AsyncMiddleware): + async def async_process( + self, + *, + req: AsyncBoltRequest, + resp: BoltResponse, + next: Callable[[], Awaitable[BoltResponse]], + ) -> Optional[BoltResponse]: + calls.append("listener") + return await next() + + @assistant.user_message(middleware=[ListenerMiddleware()]) + async def handle_user_message(): + calls.append("handler") + + app.assistant(assistant) + + @app.middleware + async def app_middleware(req, next): + calls.append("app") + assert req.context.get("set_status") is not None + assert req.context.get("set_title") is not None + return await next() + + request = AsyncBoltRequest(body=user_message_event_body, mode="socket_mode") + response = await app.async_dispatch(request) + assert response.status == 200 + assert calls == ["app", "listener", "handler"] + + @pytest.mark.asyncio + async def test_assistant_does_not_inherit_app_middleware_by_default(self): + app = AsyncApp(client=AsyncWebClient(token=None), authorize=async_authorize_test_app, process_before_response=True) + assistant = AsyncAssistant() + calls = [] + + class AppMiddleware(AsyncMiddleware): + async def async_process( + self, + *, + req: AsyncBoltRequest, + resp: BoltResponse, + next: Callable[[], Awaitable[BoltResponse]], + ) -> Optional[BoltResponse]: + calls.append("app") + return await next() + + @assistant.user_message + async def handle_user_message(): + calls.append("handler") + + app.assistant(assistant) + app.middleware(AppMiddleware()) + + request = AsyncBoltRequest(body=user_message_event_body, mode="socket_mode") + response = await app.async_dispatch(request) + assert response.status == 200 + assert calls == ["handler"] + + @pytest.mark.asyncio + async def test_assistant_inherits_app_middleware_for_listeners_registered_later(self): + app = AsyncApp(client=AsyncWebClient(token=None), authorize=async_authorize_test_app, process_before_response=True) + assistant = AsyncAssistant(auto_inherit_app_middleware=True) + calls = [] + + class AppMiddleware(AsyncMiddleware): + async def async_process( + self, + *, + req: AsyncBoltRequest, + resp: BoltResponse, + next: Callable[[], Awaitable[BoltResponse]], + ) -> Optional[BoltResponse]: + calls.append("app") + return await next() + + app.assistant(assistant) + app.middleware(AppMiddleware()) + + @assistant.user_message + async def handle_user_message(): + calls.append("handler") + + request = AsyncBoltRequest(body=user_message_event_body, mode="socket_mode") + response = await app.async_dispatch(request) + assert response.status == 200 + assert calls == ["app", "handler"] + + @pytest.mark.asyncio + async def test_assistant_inherited_app_middleware_can_short_circuit(self): + app = AsyncApp(client=AsyncWebClient(token=None), authorize=async_authorize_test_app, process_before_response=True) + assistant = AsyncAssistant(auto_inherit_app_middleware=True) + calls = [] + + class BlockingMiddleware(AsyncMiddleware): + async def async_process( + self, + *, + req: AsyncBoltRequest, + resp: BoltResponse, + next: Callable[[], Awaitable[BoltResponse]], + ) -> Optional[BoltResponse]: + calls.append("app") + return BoltResponse(status=201) + + @assistant.thread_started + async def start_thread(): + calls.append("handler") + + app.assistant(assistant) + app.middleware(BlockingMiddleware()) + + request = AsyncBoltRequest(body=thread_started_event_body, mode="socket_mode") + response = await app.async_dispatch(request) + assert response.status == 201 + assert calls == ["app"] From 99553b5c67373613dddeda7688cb31af914b66b5 Mon Sep 17 00:00:00 2001 From: Voronin Sergei Date: Sun, 31 May 2026 19:10:54 +1200 Subject: [PATCH 2/6] Register assistant handlers as app listeners --- slack_bolt/app/app.py | 20 ++- slack_bolt/app/async_app.py | 20 ++- slack_bolt/middleware/assistant/assistant.py | 135 +++++++++------- .../middleware/assistant/async_assistant.py | 146 +++++++++++------- .../app/test_app_assistant_middleware.py | 66 +++++++- 5 files changed, 251 insertions(+), 136 deletions(-) diff --git a/slack_bolt/app/app.py b/slack_bolt/app/app.py index 22735b90d..3a44c9afb 100644 --- a/slack_bolt/app/app.py +++ b/slack_bolt/app/app.py @@ -349,6 +349,7 @@ def message_hello(message, say): self._middleware_list: List[Middleware] = [] self._listeners: List[Listener] = [] + self._assistant_listener_insertion_index = 0 if listener_executor is None: listener_executor = ThreadPoolExecutor(max_workers=5) @@ -680,11 +681,12 @@ def middleware_func(logger, body, next): middleware_or_callable = args[0] if isinstance(middleware_or_callable, Middleware): middleware: Middleware = middleware_or_callable + if isinstance(middleware, Assistant) and middleware.auto_inherit_app_middleware is True: + self._register_assistant_listeners(middleware) + return None self._middleware_list.append(middleware) if isinstance(middleware, Assistant) and middleware.thread_context_store is not None: self._assistant_thread_context_store = middleware.thread_context_store - elif not isinstance(middleware, Assistant): - self._inherit_app_middleware_for_assistants(middleware) elif callable(middleware_or_callable): middleware = CustomMiddleware( app_name=self.name, @@ -692,16 +694,20 @@ def middleware_func(logger, body, next): base_logger=self._base_logger, ) self._middleware_list.append(middleware) - self._inherit_app_middleware_for_assistants(middleware) return middleware_or_callable else: raise BoltError(f"Unexpected type for a middleware ({type(middleware_or_callable)})") return None - def _inherit_app_middleware_for_assistants(self, middleware: Middleware) -> None: - for registered_middleware in self._middleware_list[:-1]: - if isinstance(registered_middleware, Assistant): - registered_middleware.inherit_app_middleware(middleware) + def _register_assistant_listeners(self, assistant: Assistant) -> None: + if assistant.thread_context_store is not None: + self._assistant_thread_context_store = assistant.thread_context_store + + def register_listener(listener: Listener) -> None: + self._listeners.insert(self._assistant_listener_insertion_index, listener) + self._assistant_listener_insertion_index += 1 + + assistant._register_app_listeners(register_listener) # ------------------------- # AI Agents & Assistants diff --git a/slack_bolt/app/async_app.py b/slack_bolt/app/async_app.py index 0f57fc150..93cb70934 100644 --- a/slack_bolt/app/async_app.py +++ b/slack_bolt/app/async_app.py @@ -361,6 +361,7 @@ async def message_hello(message, say): # async function self._async_middleware_list: List[AsyncMiddleware] = [] self._async_listeners: List[AsyncListener] = [] + self._assistant_listener_insertion_index = 0 self._assistant_thread_context_store = assistant_thread_context_store self._attaching_conversation_kwargs_enabled = attaching_conversation_kwargs_enabled @@ -707,11 +708,12 @@ async def middleware_func(logger, body, next): middleware_or_callable = args[0] if isinstance(middleware_or_callable, AsyncMiddleware): middleware: AsyncMiddleware = middleware_or_callable + if isinstance(middleware, AsyncAssistant) and middleware.auto_inherit_app_middleware is True: + self._register_assistant_listeners(middleware) + return None self._async_middleware_list.append(middleware) if isinstance(middleware, AsyncAssistant) and middleware.thread_context_store is not None: self._assistant_thread_context_store = middleware.thread_context_store - elif not isinstance(middleware, AsyncAssistant): - self._inherit_app_middleware_for_assistants(middleware) elif callable(middleware_or_callable): middleware = AsyncCustomMiddleware( app_name=self.name, @@ -719,16 +721,20 @@ async def middleware_func(logger, body, next): base_logger=self._base_logger, ) self._async_middleware_list.append(middleware) - self._inherit_app_middleware_for_assistants(middleware) return middleware_or_callable else: raise BoltError(f"Unexpected type for a middleware ({type(middleware_or_callable)})") return None - def _inherit_app_middleware_for_assistants(self, middleware: AsyncMiddleware) -> None: - for registered_middleware in self._async_middleware_list[:-1]: - if isinstance(registered_middleware, AsyncAssistant): - registered_middleware.inherit_app_middleware(middleware) + def _register_assistant_listeners(self, assistant: AsyncAssistant) -> None: + if assistant.thread_context_store is not None: + self._assistant_thread_context_store = assistant.thread_context_store + + def register_listener(listener: AsyncListener) -> None: + self._async_listeners.insert(self._assistant_listener_insertion_index, listener) + self._assistant_listener_insertion_index += 1 + + assistant._register_app_listeners(register_listener) def assistant(self, assistant: AsyncAssistant) -> Optional[Callable]: return self.middleware(assistant) diff --git a/slack_bolt/middleware/assistant/assistant.py b/slack_bolt/middleware/assistant/assistant.py index 837dfedcd..2387be496 100644 --- a/slack_bolt/middleware/assistant/assistant.py +++ b/slack_bolt/middleware/assistant/assistant.py @@ -1,7 +1,7 @@ import logging from functools import wraps from logging import Logger -from typing import List, Optional, Union, Callable, Tuple +from typing import List, Optional, Union, Callable from slack_bolt.context.save_thread_context import SaveThreadContext from slack_bolt.context.assistant.thread_context_store.store import AssistantThreadContextStore @@ -32,6 +32,8 @@ class Assistant(Middleware): _thread_context_changed_listeners: Optional[List[Listener]] _user_message_listeners: Optional[List[Listener]] _bot_message_listeners: Optional[List[Listener]] + _other_message_sub_event_listeners: Optional[List[Listener]] + _app_listener_registrars: List[Callable[[Listener], None]] thread_context_store: Optional[AssistantThreadContextStore] base_logger: Optional[logging.Logger] @@ -48,18 +50,13 @@ def __init__( self.thread_context_store = thread_context_store self.base_logger = logger self.auto_inherit_app_middleware = auto_inherit_app_middleware - self._inherited_app_middleware: List[Middleware] = [] self._thread_started_listeners = None self._thread_context_changed_listeners = None self._user_message_listeners = None self._bot_message_listeners = None - - def inherit_app_middleware(self, middleware: Middleware) -> None: - if self.auto_inherit_app_middleware is False: - return - - self._inherited_app_middleware.append(middleware) + self._other_message_sub_event_listeners = None + self._app_listener_registrars = [] def thread_started( self, @@ -73,23 +70,25 @@ def thread_started( all_matchers = self._merge_matchers(is_assistant_thread_started_event, matchers) if is_used_without_argument(args): func = args[0] - self._thread_started_listeners.append( + self._append_listener( + self._thread_started_listeners, self.build_listener( listener_or_functions=func, matchers=all_matchers, middleware=middleware, # type: ignore[arg-type] - ) + ), ) return func def _inner(func): functions = [func] + (lazy if lazy is not None else []) - self._thread_started_listeners.append( + self._append_listener( + self._thread_started_listeners, self.build_listener( listener_or_functions=functions, matchers=all_matchers, middleware=middleware, - ) + ), ) @wraps(func) @@ -112,23 +111,25 @@ def user_message( all_matchers = self._merge_matchers(is_user_message_event_in_assistant_thread, matchers) if is_used_without_argument(args): func = args[0] - self._user_message_listeners.append( + self._append_listener( + self._user_message_listeners, self.build_listener( listener_or_functions=func, matchers=all_matchers, middleware=middleware, # type: ignore[arg-type] - ) + ), ) return func def _inner(func): functions = [func] + (lazy if lazy is not None else []) - self._user_message_listeners.append( + self._append_listener( + self._user_message_listeners, self.build_listener( listener_or_functions=functions, matchers=all_matchers, middleware=middleware, - ) + ), ) @wraps(func) @@ -151,23 +152,25 @@ def bot_message( all_matchers = self._merge_matchers(is_bot_message_event_in_assistant_thread, matchers) if is_used_without_argument(args): func = args[0] - self._bot_message_listeners.append( + self._append_listener( + self._bot_message_listeners, self.build_listener( listener_or_functions=func, matchers=all_matchers, middleware=middleware, # type: ignore[arg-type] - ) + ), ) return func def _inner(func): functions = [func] + (lazy if lazy is not None else []) - self._bot_message_listeners.append( + self._append_listener( + self._bot_message_listeners, self.build_listener( listener_or_functions=functions, matchers=all_matchers, middleware=middleware, - ) + ), ) @wraps(func) @@ -190,23 +193,25 @@ def thread_context_changed( all_matchers = self._merge_matchers(is_assistant_thread_context_changed_event, matchers) if is_used_without_argument(args): func = args[0] - self._thread_context_changed_listeners.append( + self._append_listener( + self._thread_context_changed_listeners, self.build_listener( listener_or_functions=func, matchers=all_matchers, middleware=middleware, # type: ignore[arg-type] - ) + ), ) return func def _inner(func): functions = [func] + (lazy if lazy is not None else []) - self._thread_context_changed_listeners.append( + self._append_listener( + self._thread_context_changed_listeners, self.build_listener( listener_or_functions=functions, matchers=all_matchers, middleware=middleware, - ) + ), ) @wraps(func) @@ -230,12 +235,54 @@ def _merge_matchers( def default_thread_context_changed(save_thread_context: SaveThreadContext, payload: dict): save_thread_context(payload["assistant_thread"]["context"]) - def process( # type: ignore[return] - self, *, req: BoltRequest, resp: BoltResponse, next: Callable[[], BoltResponse] - ) -> Optional[BoltResponse]: + @staticmethod + def default_other_message_sub_event(ack): + ack() + + def _register_app_listeners(self, listener_registrar: Callable[[Listener], None]) -> None: + self._ensure_default_thread_context_changed_listener() + self._ensure_other_message_sub_event_listener() + for listener in self._app_listeners: + listener_registrar(listener) + self._app_listener_registrars.append(listener_registrar) + + @property + def _app_listeners(self) -> List[Listener]: + listeners: List[Listener] = [] + for listener_list in [ + self._thread_started_listeners, + self._thread_context_changed_listeners, + self._user_message_listeners, + self._bot_message_listeners, + self._other_message_sub_event_listeners, + ]: + if listener_list is not None: + listeners.extend(listener_list) + return listeners + + def _append_listener(self, listeners: List[Listener], listener: Listener) -> None: + listeners.append(listener) + for registrar in self._app_listener_registrars: + registrar(listener) + + def _ensure_default_thread_context_changed_listener(self) -> None: if self._thread_context_changed_listeners is None: self.thread_context_changed(self.default_thread_context_changed) + def _ensure_other_message_sub_event_listener(self) -> None: + if self._other_message_sub_event_listeners is None: + self._other_message_sub_event_listeners = [] + self._append_listener( + self._other_message_sub_event_listeners, + self.build_listener( + listener_or_functions=self.default_other_message_sub_event, + matchers=[is_other_message_sub_event_in_assistant_thread], + ), + ) + + def process(self, *, req: BoltRequest, resp: BoltResponse, next: Callable[[], BoltResponse]) -> Optional[BoltResponse]: + self._ensure_default_thread_context_changed_listener() + listener_runner: ThreadListenerRunner = req.context.listener_runner for listeners in [ self._thread_started_listeners, @@ -246,11 +293,7 @@ def process( # type: ignore[return] if listeners is not None: for listener in listeners: if listener.matches(req=req, resp=resp): - middleware_resp, next_was_not_called = self._run_middleware( - listener=listener, - req=req, - resp=resp, - ) + middleware_resp, next_was_not_called = listener.run_middleware(req=req, resp=resp) if next_was_not_called: if middleware_resp is not None: return middleware_resp @@ -270,33 +313,7 @@ def process( # type: ignore[return] return req.context.ack() next() - - def _run_middleware( - self, - *, - listener: Listener, - req: BoltRequest, - resp: BoltResponse, - ) -> Tuple[Optional[BoltResponse], bool]: - middleware = list(listener.middleware) - if len(self._inherited_app_middleware) > 0: - insertion_index = 1 if len(middleware) > 0 and isinstance(middleware[0], AttachingConversationKwargs) else 0 - middleware = [ - *middleware[:insertion_index], - *self._inherited_app_middleware, - *middleware[insertion_index:], - ] - - for m in middleware: - middleware_state = {"next_called": False} - - def next_(): - middleware_state["next_called"] = True - - resp = m.process(req=req, resp=resp, next=next_) # type: ignore[assignment] - if not middleware_state["next_called"]: - return resp, True - return resp, False + return None def build_listener( self, diff --git a/slack_bolt/middleware/assistant/async_assistant.py b/slack_bolt/middleware/assistant/async_assistant.py index 14a8c0c75..58d8455f4 100644 --- a/slack_bolt/middleware/assistant/async_assistant.py +++ b/slack_bolt/middleware/assistant/async_assistant.py @@ -1,7 +1,7 @@ import logging from functools import wraps from logging import Logger -from typing import List, Optional, Union, Callable, Awaitable, Tuple +from typing import List, Optional, Union, Callable, Awaitable, cast from slack_bolt.context.save_thread_context.async_save_thread_context import AsyncSaveThreadContext from slack_bolt.context.assistant.thread_context_store.async_store import AsyncAssistantThreadContextStore @@ -32,6 +32,8 @@ class AsyncAssistant(AsyncMiddleware): _user_message_listeners: Optional[List[AsyncListener]] _bot_message_listeners: Optional[List[AsyncListener]] _thread_context_changed_listeners: Optional[List[AsyncListener]] + _other_message_sub_event_listeners: Optional[List[AsyncListener]] + _app_listener_registrars: List[Callable[[AsyncListener], None]] thread_context_store: Optional[AsyncAssistantThreadContextStore] base_logger: Optional[logging.Logger] @@ -48,18 +50,13 @@ def __init__( self.thread_context_store = thread_context_store self.base_logger = logger self.auto_inherit_app_middleware = auto_inherit_app_middleware - self._inherited_app_middleware: List[AsyncMiddleware] = [] self._thread_started_listeners = None self._thread_context_changed_listeners = None self._user_message_listeners = None self._bot_message_listeners = None - - def inherit_app_middleware(self, middleware: AsyncMiddleware) -> None: - if self.auto_inherit_app_middleware is False: - return - - self._inherited_app_middleware.append(middleware) + self._other_message_sub_event_listeners = None + self._app_listener_registrars = [] def thread_started( self, @@ -80,23 +77,25 @@ def thread_started( ) if is_used_without_argument(args): func = args[0] - self._thread_started_listeners.append( + self._append_listener( + self._thread_started_listeners, self.build_listener( listener_or_functions=func, matchers=all_matchers, middleware=middleware, # type: ignore[arg-type] - ) + ), ) return func def _inner(func): functions = [func] + (lazy if lazy is not None else []) - self._thread_started_listeners.append( + self._append_listener( + self._thread_started_listeners, self.build_listener( listener_or_functions=functions, matchers=all_matchers, middleware=middleware, - ) + ), ) @wraps(func) @@ -126,23 +125,25 @@ def user_message( ) if is_used_without_argument(args): func = args[0] - self._user_message_listeners.append( + self._append_listener( + self._user_message_listeners, self.build_listener( listener_or_functions=func, matchers=all_matchers, middleware=middleware, # type: ignore[arg-type] - ) + ), ) return func def _inner(func): functions = [func] + (lazy if lazy is not None else []) - self._user_message_listeners.append( + self._append_listener( + self._user_message_listeners, self.build_listener( listener_or_functions=functions, matchers=all_matchers, middleware=middleware, - ) + ), ) @wraps(func) @@ -172,23 +173,25 @@ def bot_message( ) if is_used_without_argument(args): func = args[0] - self._bot_message_listeners.append( + self._append_listener( + self._bot_message_listeners, self.build_listener( listener_or_functions=func, matchers=all_matchers, middleware=middleware, # type: ignore[arg-type] - ) + ), ) return func def _inner(func): functions = [func] + (lazy if lazy is not None else []) - self._bot_message_listeners.append( + self._append_listener( + self._bot_message_listeners, self.build_listener( listener_or_functions=functions, matchers=all_matchers, middleware=middleware, - ) + ), ) @wraps(func) @@ -218,23 +221,25 @@ def thread_context_changed( ) if is_used_without_argument(args): func = args[0] - self._thread_context_changed_listeners.append( + self._append_listener( + self._thread_context_changed_listeners, self.build_listener( listener_or_functions=func, matchers=all_matchers, middleware=middleware, # type: ignore[arg-type] - ) + ), ) return func def _inner(func): functions = [func] + (lazy if lazy is not None else []) - self._thread_context_changed_listeners.append( + self._append_listener( + self._thread_context_changed_listeners, self.build_listener( listener_or_functions=functions, matchers=all_matchers, middleware=middleware, - ) + ), ) @wraps(func) @@ -257,15 +262,68 @@ async def default_thread_context_changed(save_thread_context: AsyncSaveThreadCon new_context: dict = payload["assistant_thread"]["context"] await save_thread_context(new_context) - async def async_process( # type: ignore[return] + @staticmethod + async def default_other_message_sub_event(ack): + await ack() + + def _register_app_listeners(self, listener_registrar: Callable[[AsyncListener], None]) -> None: + self._ensure_default_thread_context_changed_listener() + self._ensure_other_message_sub_event_listener() + for listener in self._app_listeners: + listener_registrar(listener) + self._app_listener_registrars.append(listener_registrar) + + @property + def _app_listeners(self) -> List[AsyncListener]: + listeners: List[AsyncListener] = [] + for listener_list in [ + self._thread_started_listeners, + self._thread_context_changed_listeners, + self._user_message_listeners, + self._bot_message_listeners, + self._other_message_sub_event_listeners, + ]: + if listener_list is not None: + listeners.extend(listener_list) + return listeners + + def _append_listener(self, listeners: List[AsyncListener], listener: AsyncListener) -> None: + listeners.append(listener) + for registrar in self._app_listener_registrars: + registrar(listener) + + def _ensure_default_thread_context_changed_listener(self) -> None: + if self._thread_context_changed_listeners is None: + self.thread_context_changed(self.default_thread_context_changed) + + def _ensure_other_message_sub_event_listener(self) -> None: + if self._other_message_sub_event_listeners is None: + self._other_message_sub_event_listeners = [] + self._append_listener( + self._other_message_sub_event_listeners, + self.build_listener( + listener_or_functions=self.default_other_message_sub_event, + matchers=[ + cast( + AsyncListenerMatcher, + build_listener_matcher( + func=is_other_message_sub_event_in_assistant_thread, + asyncio=True, + base_logger=self.base_logger, + ), + ) + ], + ), + ) + + async def async_process( self, *, req: AsyncBoltRequest, resp: BoltResponse, next: Callable[[], Awaitable[BoltResponse]], ) -> Optional[BoltResponse]: - if self._thread_context_changed_listeners is None: - self.thread_context_changed(self.default_thread_context_changed) + self._ensure_default_thread_context_changed_listener() listener_runner: AsyncioListenerRunner = req.context.listener_runner for listeners in [ @@ -277,11 +335,7 @@ async def async_process( # type: ignore[return] if listeners is not None: for listener in listeners: if listener is not None and await listener.async_matches(req=req, resp=resp): - middleware_resp, next_was_not_called = await self._run_middleware( - listener=listener, - req=req, - resp=resp, - ) + middleware_resp, next_was_not_called = await listener.run_async_middleware(req=req, resp=resp) if next_was_not_called: if middleware_resp is not None: return middleware_resp @@ -301,33 +355,7 @@ async def async_process( # type: ignore[return] return await req.context.ack() await next() - - async def _run_middleware( - self, - *, - listener: AsyncListener, - req: AsyncBoltRequest, - resp: BoltResponse, - ) -> Tuple[Optional[BoltResponse], bool]: - middleware = list(listener.middleware) - if len(self._inherited_app_middleware) > 0: - insertion_index = 1 if len(middleware) > 0 and isinstance(middleware[0], AsyncAttachingConversationKwargs) else 0 - middleware = [ - *middleware[:insertion_index], - *self._inherited_app_middleware, - *middleware[insertion_index:], - ] - - for m in middleware: - middleware_state = {"next_called": False} - - async def next_(): - middleware_state["next_called"] = True - - resp = await m.async_process(req=req, resp=resp, next=next_) # type: ignore[assignment] - if not middleware_state["next_called"]: - return resp, True - return resp, False + return None def build_listener( self, diff --git a/tests/slack_bolt/app/test_app_assistant_middleware.py b/tests/slack_bolt/app/test_app_assistant_middleware.py index 9bcfd1980..bdc313e6a 100644 --- a/tests/slack_bolt/app/test_app_assistant_middleware.py +++ b/tests/slack_bolt/app/test_app_assistant_middleware.py @@ -30,6 +30,35 @@ async def async_authorize_test_app(context, enterprise_id, team_id, user_id): class TestAppAssistantMiddleware: + def test_auto_inherit_assistant_registers_handlers_as_app_listeners(self): + app = App(client=WebClient(token=None), authorize=authorize_test_app, process_before_response=True) + assistant = Assistant(auto_inherit_app_middleware=True) + + @app.message("") + def catch_all(): + pass + + @assistant.user_message + def handle_user_message(): + pass + + app.assistant(assistant) + + second_assistant = Assistant(auto_inherit_app_middleware=True) + + @second_assistant.user_message + def handle_second_user_message(): + pass + + app.assistant(second_assistant) + + assert assistant not in app._middleware_list + listener_functions = [listener.ack_function for listener in app._listeners] + assert handle_user_message in listener_functions + assert listener_functions.index(handle_user_message) < listener_functions.index(catch_all) + assert listener_functions.index(handle_user_message) < listener_functions.index(handle_second_user_message) + assert listener_functions.index(handle_second_user_message) < listener_functions.index(catch_all) + def test_assistant_inherits_app_middleware_registered_after_assistant(self): app = App(client=WebClient(token=None), authorize=authorize_test_app, process_before_response=True) assistant = Assistant(auto_inherit_app_middleware=True) @@ -38,6 +67,8 @@ def test_assistant_inherits_app_middleware_registered_after_assistant(self): class ListenerMiddleware(Middleware): def process(self, *, req: BoltRequestType, resp: BoltResponse, next: Callable[[], BoltResponse]): calls.append("listener") + assert req.context.get("set_status") is not None + assert req.context.get("set_title") is not None return next() @assistant.user_message(middleware=[ListenerMiddleware()]) @@ -49,8 +80,6 @@ def handle_user_message(): @app.middleware def app_middleware(req, next): calls.append("app") - assert req.context.get("set_status") is not None - assert req.context.get("set_title") is not None return next() request = BoltRequest(body=user_message_event_body, mode="socket_mode") @@ -126,6 +155,35 @@ def start_thread(): class TestAsyncAppAssistantMiddleware: + def test_auto_inherit_assistant_registers_handlers_as_app_listeners(self): + app = AsyncApp(client=AsyncWebClient(token=None), authorize=async_authorize_test_app, process_before_response=True) + assistant = AsyncAssistant(auto_inherit_app_middleware=True) + + @app.message("") + async def catch_all(): + pass + + @assistant.user_message + async def handle_user_message(): + pass + + app.assistant(assistant) + + second_assistant = AsyncAssistant(auto_inherit_app_middleware=True) + + @second_assistant.user_message + async def handle_second_user_message(): + pass + + app.assistant(second_assistant) + + assert assistant not in app._async_middleware_list + listener_functions = [listener.ack_function for listener in app._async_listeners] + assert handle_user_message in listener_functions + assert listener_functions.index(handle_user_message) < listener_functions.index(catch_all) + assert listener_functions.index(handle_user_message) < listener_functions.index(handle_second_user_message) + assert listener_functions.index(handle_second_user_message) < listener_functions.index(catch_all) + @pytest.mark.asyncio async def test_assistant_inherits_app_middleware_registered_after_assistant(self): app = AsyncApp(client=AsyncWebClient(token=None), authorize=async_authorize_test_app, process_before_response=True) @@ -141,6 +199,8 @@ async def async_process( next: Callable[[], Awaitable[BoltResponse]], ) -> Optional[BoltResponse]: calls.append("listener") + assert req.context.get("set_status") is not None + assert req.context.get("set_title") is not None return await next() @assistant.user_message(middleware=[ListenerMiddleware()]) @@ -152,8 +212,6 @@ async def handle_user_message(): @app.middleware async def app_middleware(req, next): calls.append("app") - assert req.context.get("set_status") is not None - assert req.context.get("set_title") is not None return await next() request = AsyncBoltRequest(body=user_message_event_body, mode="socket_mode") From fdc2e1761bbc0ad5d72a7ce31218aeef889fb5bd Mon Sep 17 00:00:00 2001 From: Voronin Sergei Date: Sun, 31 May 2026 19:36:12 +1200 Subject: [PATCH 3/6] Minimize assistant listener changes --- slack_bolt/app/app.py | 11 +-- slack_bolt/app/async_app.py | 11 +-- slack_bolt/middleware/assistant/assistant.py | 50 +++++-------- .../middleware/assistant/async_assistant.py | 71 ++++++++----------- 4 files changed, 59 insertions(+), 84 deletions(-) diff --git a/slack_bolt/app/app.py b/slack_bolt/app/app.py index 3a44c9afb..2c340f6e6 100644 --- a/slack_bolt/app/app.py +++ b/slack_bolt/app/app.py @@ -688,12 +688,13 @@ def middleware_func(logger, body, next): if isinstance(middleware, Assistant) and middleware.thread_context_store is not None: self._assistant_thread_context_store = middleware.thread_context_store elif callable(middleware_or_callable): - middleware = CustomMiddleware( - app_name=self.name, - func=middleware_or_callable, - base_logger=self._base_logger, + self._middleware_list.append( + CustomMiddleware( + app_name=self.name, + func=middleware_or_callable, + base_logger=self._base_logger, + ) ) - self._middleware_list.append(middleware) return middleware_or_callable else: raise BoltError(f"Unexpected type for a middleware ({type(middleware_or_callable)})") diff --git a/slack_bolt/app/async_app.py b/slack_bolt/app/async_app.py index 93cb70934..cca58b6b9 100644 --- a/slack_bolt/app/async_app.py +++ b/slack_bolt/app/async_app.py @@ -715,12 +715,13 @@ async def middleware_func(logger, body, next): if isinstance(middleware, AsyncAssistant) and middleware.thread_context_store is not None: self._assistant_thread_context_store = middleware.thread_context_store elif callable(middleware_or_callable): - middleware = AsyncCustomMiddleware( - app_name=self.name, - func=middleware_or_callable, - base_logger=self._base_logger, + self._async_middleware_list.append( + AsyncCustomMiddleware( + app_name=self.name, + func=middleware_or_callable, + base_logger=self._base_logger, + ) ) - self._async_middleware_list.append(middleware) return middleware_or_callable else: raise BoltError(f"Unexpected type for a middleware ({type(middleware_or_callable)})") diff --git a/slack_bolt/middleware/assistant/assistant.py b/slack_bolt/middleware/assistant/assistant.py index 2387be496..68d37dbc3 100644 --- a/slack_bolt/middleware/assistant/assistant.py +++ b/slack_bolt/middleware/assistant/assistant.py @@ -240,15 +240,17 @@ def default_other_message_sub_event(ack): ack() def _register_app_listeners(self, listener_registrar: Callable[[Listener], None]) -> None: - self._ensure_default_thread_context_changed_listener() - self._ensure_other_message_sub_event_listener() - for listener in self._app_listeners: - listener_registrar(listener) - self._app_listener_registrars.append(listener_registrar) - - @property - def _app_listeners(self) -> List[Listener]: - listeners: List[Listener] = [] + if self._thread_context_changed_listeners is None: + self.thread_context_changed(self.default_thread_context_changed) + if self._other_message_sub_event_listeners is None: + self._other_message_sub_event_listeners = [] + self._append_listener( + self._other_message_sub_event_listeners, + self.build_listener( + listener_or_functions=self.default_other_message_sub_event, + matchers=[is_other_message_sub_event_in_assistant_thread], + ), + ) for listener_list in [ self._thread_started_listeners, self._thread_context_changed_listeners, @@ -257,32 +259,21 @@ def _app_listeners(self) -> List[Listener]: self._other_message_sub_event_listeners, ]: if listener_list is not None: - listeners.extend(listener_list) - return listeners + for listener in listener_list: + listener_registrar(listener) + self._app_listener_registrars.append(listener_registrar) def _append_listener(self, listeners: List[Listener], listener: Listener) -> None: listeners.append(listener) for registrar in self._app_listener_registrars: registrar(listener) - def _ensure_default_thread_context_changed_listener(self) -> None: + def process( # type: ignore[return] + self, *, req: BoltRequest, resp: BoltResponse, next: Callable[[], BoltResponse] + ) -> Optional[BoltResponse]: if self._thread_context_changed_listeners is None: self.thread_context_changed(self.default_thread_context_changed) - def _ensure_other_message_sub_event_listener(self) -> None: - if self._other_message_sub_event_listeners is None: - self._other_message_sub_event_listeners = [] - self._append_listener( - self._other_message_sub_event_listeners, - self.build_listener( - listener_or_functions=self.default_other_message_sub_event, - matchers=[is_other_message_sub_event_in_assistant_thread], - ), - ) - - def process(self, *, req: BoltRequest, resp: BoltResponse, next: Callable[[], BoltResponse]) -> Optional[BoltResponse]: - self._ensure_default_thread_context_changed_listener() - listener_runner: ThreadListenerRunner = req.context.listener_runner for listeners in [ self._thread_started_listeners, @@ -313,7 +304,6 @@ def process(self, *, req: BoltRequest, resp: BoltResponse, next: Callable[[], Bo return req.context.ack() next() - return None def build_listener( self, @@ -328,10 +318,8 @@ def build_listener( if isinstance(listener_or_functions, Listener): return listener_or_functions elif isinstance(listener_or_functions, list): - middleware = [ - AttachingConversationKwargs(self.thread_context_store), - *(middleware if middleware else []), - ] + middleware = middleware if middleware else [] + middleware.insert(0, AttachingConversationKwargs(self.thread_context_store)) functions = listener_or_functions ack_function = functions.pop(0) diff --git a/slack_bolt/middleware/assistant/async_assistant.py b/slack_bolt/middleware/assistant/async_assistant.py index 58d8455f4..47acb0563 100644 --- a/slack_bolt/middleware/assistant/async_assistant.py +++ b/slack_bolt/middleware/assistant/async_assistant.py @@ -1,7 +1,7 @@ import logging from functools import wraps from logging import Logger -from typing import List, Optional, Union, Callable, Awaitable, cast +from typing import List, Optional, Union, Callable, Awaitable from slack_bolt.context.save_thread_context.async_save_thread_context import AsyncSaveThreadContext from slack_bolt.context.assistant.thread_context_store.async_store import AsyncAssistantThreadContextStore @@ -267,15 +267,25 @@ async def default_other_message_sub_event(ack): await ack() def _register_app_listeners(self, listener_registrar: Callable[[AsyncListener], None]) -> None: - self._ensure_default_thread_context_changed_listener() - self._ensure_other_message_sub_event_listener() - for listener in self._app_listeners: - listener_registrar(listener) - self._app_listener_registrars.append(listener_registrar) - - @property - def _app_listeners(self) -> List[AsyncListener]: - listeners: List[AsyncListener] = [] + if self._thread_context_changed_listeners is None: + self.thread_context_changed(self.default_thread_context_changed) + if self._other_message_sub_event_listeners is None: + self._other_message_sub_event_listeners = [] + all_matchers = self._merge_matchers( + build_listener_matcher( + func=is_other_message_sub_event_in_assistant_thread, + asyncio=True, + base_logger=self.base_logger, + ), # type: ignore[arg-type] + None, + ) + self._append_listener( + self._other_message_sub_event_listeners, + self.build_listener( + listener_or_functions=self.default_other_message_sub_event, + matchers=all_matchers, + ), + ) for listener_list in [ self._thread_started_listeners, self._thread_context_changed_listeners, @@ -284,46 +294,24 @@ def _app_listeners(self) -> List[AsyncListener]: self._other_message_sub_event_listeners, ]: if listener_list is not None: - listeners.extend(listener_list) - return listeners + for listener in listener_list: + listener_registrar(listener) + self._app_listener_registrars.append(listener_registrar) def _append_listener(self, listeners: List[AsyncListener], listener: AsyncListener) -> None: listeners.append(listener) for registrar in self._app_listener_registrars: registrar(listener) - def _ensure_default_thread_context_changed_listener(self) -> None: - if self._thread_context_changed_listeners is None: - self.thread_context_changed(self.default_thread_context_changed) - - def _ensure_other_message_sub_event_listener(self) -> None: - if self._other_message_sub_event_listeners is None: - self._other_message_sub_event_listeners = [] - self._append_listener( - self._other_message_sub_event_listeners, - self.build_listener( - listener_or_functions=self.default_other_message_sub_event, - matchers=[ - cast( - AsyncListenerMatcher, - build_listener_matcher( - func=is_other_message_sub_event_in_assistant_thread, - asyncio=True, - base_logger=self.base_logger, - ), - ) - ], - ), - ) - - async def async_process( + async def async_process( # type: ignore[return] self, *, req: AsyncBoltRequest, resp: BoltResponse, next: Callable[[], Awaitable[BoltResponse]], ) -> Optional[BoltResponse]: - self._ensure_default_thread_context_changed_listener() + if self._thread_context_changed_listeners is None: + self.thread_context_changed(self.default_thread_context_changed) listener_runner: AsyncioListenerRunner = req.context.listener_runner for listeners in [ @@ -355,7 +343,6 @@ async def async_process( return await req.context.ack() await next() - return None def build_listener( self, @@ -370,10 +357,8 @@ def build_listener( if isinstance(listener_or_functions, AsyncListener): return listener_or_functions elif isinstance(listener_or_functions, list): - middleware = [ - AsyncAttachingConversationKwargs(self.thread_context_store), - *(middleware if middleware else []), - ] + middleware = middleware if middleware else [] + middleware.insert(0, AsyncAttachingConversationKwargs(self.thread_context_store)) functions = listener_or_functions ack_function = functions.pop(0) From 66b8d170a9205ed6f71ef2f0747c134a5ec7317d Mon Sep 17 00:00:00 2001 From: Voronin Sergei Date: Sun, 31 May 2026 19:58:06 +1200 Subject: [PATCH 4/6] Clarify assistant listener registration --- slack_bolt/app/app.py | 2 ++ slack_bolt/app/async_app.py | 2 ++ slack_bolt/middleware/assistant/assistant.py | 21 ++++++++++--------- .../middleware/assistant/async_assistant.py | 21 ++++++++++--------- 4 files changed, 26 insertions(+), 20 deletions(-) diff --git a/slack_bolt/app/app.py b/slack_bolt/app/app.py index 2c340f6e6..bce705d49 100644 --- a/slack_bolt/app/app.py +++ b/slack_bolt/app/app.py @@ -682,6 +682,7 @@ def middleware_func(logger, body, next): if isinstance(middleware_or_callable, Middleware): middleware: Middleware = middleware_or_callable if isinstance(middleware, Assistant) and middleware.auto_inherit_app_middleware is True: + # In this opt-in mode, Assistant handlers should run through the app listener pipeline. self._register_assistant_listeners(middleware) return None self._middleware_list.append(middleware) @@ -705,6 +706,7 @@ def _register_assistant_listeners(self, assistant: Assistant) -> None: self._assistant_thread_context_store = assistant.thread_context_store def register_listener(listener: Listener) -> None: + # Keep Assistant listeners before catch-all listeners while preserving Assistant registration order. self._listeners.insert(self._assistant_listener_insertion_index, listener) self._assistant_listener_insertion_index += 1 diff --git a/slack_bolt/app/async_app.py b/slack_bolt/app/async_app.py index cca58b6b9..c821026b4 100644 --- a/slack_bolt/app/async_app.py +++ b/slack_bolt/app/async_app.py @@ -709,6 +709,7 @@ async def middleware_func(logger, body, next): if isinstance(middleware_or_callable, AsyncMiddleware): middleware: AsyncMiddleware = middleware_or_callable if isinstance(middleware, AsyncAssistant) and middleware.auto_inherit_app_middleware is True: + # In this opt-in mode, Assistant handlers should run through the app listener pipeline. self._register_assistant_listeners(middleware) return None self._async_middleware_list.append(middleware) @@ -732,6 +733,7 @@ def _register_assistant_listeners(self, assistant: AsyncAssistant) -> None: self._assistant_thread_context_store = assistant.thread_context_store def register_listener(listener: AsyncListener) -> None: + # Keep Assistant listeners before catch-all listeners while preserving Assistant registration order. self._async_listeners.insert(self._assistant_listener_insertion_index, listener) self._assistant_listener_insertion_index += 1 diff --git a/slack_bolt/middleware/assistant/assistant.py b/slack_bolt/middleware/assistant/assistant.py index 68d37dbc3..07fa15e8a 100644 --- a/slack_bolt/middleware/assistant/assistant.py +++ b/slack_bolt/middleware/assistant/assistant.py @@ -70,7 +70,7 @@ def thread_started( all_matchers = self._merge_matchers(is_assistant_thread_started_event, matchers) if is_used_without_argument(args): func = args[0] - self._append_listener( + self._append_and_register_listener( self._thread_started_listeners, self.build_listener( listener_or_functions=func, @@ -82,7 +82,7 @@ def thread_started( def _inner(func): functions = [func] + (lazy if lazy is not None else []) - self._append_listener( + self._append_and_register_listener( self._thread_started_listeners, self.build_listener( listener_or_functions=functions, @@ -111,7 +111,7 @@ def user_message( all_matchers = self._merge_matchers(is_user_message_event_in_assistant_thread, matchers) if is_used_without_argument(args): func = args[0] - self._append_listener( + self._append_and_register_listener( self._user_message_listeners, self.build_listener( listener_or_functions=func, @@ -123,7 +123,7 @@ def user_message( def _inner(func): functions = [func] + (lazy if lazy is not None else []) - self._append_listener( + self._append_and_register_listener( self._user_message_listeners, self.build_listener( listener_or_functions=functions, @@ -152,7 +152,7 @@ def bot_message( all_matchers = self._merge_matchers(is_bot_message_event_in_assistant_thread, matchers) if is_used_without_argument(args): func = args[0] - self._append_listener( + self._append_and_register_listener( self._bot_message_listeners, self.build_listener( listener_or_functions=func, @@ -164,7 +164,7 @@ def bot_message( def _inner(func): functions = [func] + (lazy if lazy is not None else []) - self._append_listener( + self._append_and_register_listener( self._bot_message_listeners, self.build_listener( listener_or_functions=functions, @@ -193,7 +193,7 @@ def thread_context_changed( all_matchers = self._merge_matchers(is_assistant_thread_context_changed_event, matchers) if is_used_without_argument(args): func = args[0] - self._append_listener( + self._append_and_register_listener( self._thread_context_changed_listeners, self.build_listener( listener_or_functions=func, @@ -205,7 +205,7 @@ def thread_context_changed( def _inner(func): functions = [func] + (lazy if lazy is not None else []) - self._append_listener( + self._append_and_register_listener( self._thread_context_changed_listeners, self.build_listener( listener_or_functions=functions, @@ -244,7 +244,8 @@ def _register_app_listeners(self, listener_registrar: Callable[[Listener], None] self.thread_context_changed(self.default_thread_context_changed) if self._other_message_sub_event_listeners is None: self._other_message_sub_event_listeners = [] - self._append_listener( + # Preserve the middleware path's ack behavior for message_changed, message_deleted, and similar subevents. + self._append_and_register_listener( self._other_message_sub_event_listeners, self.build_listener( listener_or_functions=self.default_other_message_sub_event, @@ -263,7 +264,7 @@ def _register_app_listeners(self, listener_registrar: Callable[[Listener], None] listener_registrar(listener) self._app_listener_registrars.append(listener_registrar) - def _append_listener(self, listeners: List[Listener], listener: Listener) -> None: + def _append_and_register_listener(self, listeners: List[Listener], listener: Listener) -> None: listeners.append(listener) for registrar in self._app_listener_registrars: registrar(listener) diff --git a/slack_bolt/middleware/assistant/async_assistant.py b/slack_bolt/middleware/assistant/async_assistant.py index 47acb0563..e57eb2bc2 100644 --- a/slack_bolt/middleware/assistant/async_assistant.py +++ b/slack_bolt/middleware/assistant/async_assistant.py @@ -77,7 +77,7 @@ def thread_started( ) if is_used_without_argument(args): func = args[0] - self._append_listener( + self._append_and_register_listener( self._thread_started_listeners, self.build_listener( listener_or_functions=func, @@ -89,7 +89,7 @@ def thread_started( def _inner(func): functions = [func] + (lazy if lazy is not None else []) - self._append_listener( + self._append_and_register_listener( self._thread_started_listeners, self.build_listener( listener_or_functions=functions, @@ -125,7 +125,7 @@ def user_message( ) if is_used_without_argument(args): func = args[0] - self._append_listener( + self._append_and_register_listener( self._user_message_listeners, self.build_listener( listener_or_functions=func, @@ -137,7 +137,7 @@ def user_message( def _inner(func): functions = [func] + (lazy if lazy is not None else []) - self._append_listener( + self._append_and_register_listener( self._user_message_listeners, self.build_listener( listener_or_functions=functions, @@ -173,7 +173,7 @@ def bot_message( ) if is_used_without_argument(args): func = args[0] - self._append_listener( + self._append_and_register_listener( self._bot_message_listeners, self.build_listener( listener_or_functions=func, @@ -185,7 +185,7 @@ def bot_message( def _inner(func): functions = [func] + (lazy if lazy is not None else []) - self._append_listener( + self._append_and_register_listener( self._bot_message_listeners, self.build_listener( listener_or_functions=functions, @@ -221,7 +221,7 @@ def thread_context_changed( ) if is_used_without_argument(args): func = args[0] - self._append_listener( + self._append_and_register_listener( self._thread_context_changed_listeners, self.build_listener( listener_or_functions=func, @@ -233,7 +233,7 @@ def thread_context_changed( def _inner(func): functions = [func] + (lazy if lazy is not None else []) - self._append_listener( + self._append_and_register_listener( self._thread_context_changed_listeners, self.build_listener( listener_or_functions=functions, @@ -279,7 +279,8 @@ def _register_app_listeners(self, listener_registrar: Callable[[AsyncListener], ), # type: ignore[arg-type] None, ) - self._append_listener( + # Preserve the middleware path's ack behavior for message_changed, message_deleted, and similar subevents. + self._append_and_register_listener( self._other_message_sub_event_listeners, self.build_listener( listener_or_functions=self.default_other_message_sub_event, @@ -298,7 +299,7 @@ def _register_app_listeners(self, listener_registrar: Callable[[AsyncListener], listener_registrar(listener) self._app_listener_registrars.append(listener_registrar) - def _append_listener(self, listeners: List[AsyncListener], listener: AsyncListener) -> None: + def _append_and_register_listener(self, listeners: List[AsyncListener], listener: AsyncListener) -> None: listeners.append(listener) for registrar in self._app_listener_registrars: registrar(listener) From 32e48e0445789fac3808edf391dc86ca4f0bc21b Mon Sep 17 00:00:00 2001 From: Voronin Sergei Date: Sun, 31 May 2026 22:29:01 +1200 Subject: [PATCH 5/6] Use listener registry for assistant ordering --- slack_bolt/app/app.py | 8 +++---- slack_bolt/app/async_app.py | 8 +++---- slack_bolt/app/listener_registry.py | 33 +++++++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 10 deletions(-) create mode 100644 slack_bolt/app/listener_registry.py diff --git a/slack_bolt/app/app.py b/slack_bolt/app/app.py index bce705d49..c2d43c04a 100644 --- a/slack_bolt/app/app.py +++ b/slack_bolt/app/app.py @@ -20,6 +20,7 @@ CallableAuthorize, ) +from slack_bolt.app.listener_registry import ListenerRegistry from slack_bolt.context.assistant.thread_context_store.store import AssistantThreadContextStore from slack_bolt.error import BoltError, BoltUnhandledRequestError @@ -348,8 +349,7 @@ def message_hello(message, say): # -------------------------------------- self._middleware_list: List[Middleware] = [] - self._listeners: List[Listener] = [] - self._assistant_listener_insertion_index = 0 + self._listeners: ListenerRegistry[Listener] = ListenerRegistry() if listener_executor is None: listener_executor = ThreadPoolExecutor(max_workers=5) @@ -706,9 +706,7 @@ def _register_assistant_listeners(self, assistant: Assistant) -> None: self._assistant_thread_context_store = assistant.thread_context_store def register_listener(listener: Listener) -> None: - # Keep Assistant listeners before catch-all listeners while preserving Assistant registration order. - self._listeners.insert(self._assistant_listener_insertion_index, listener) - self._assistant_listener_insertion_index += 1 + self._listeners.append_assistant(listener) assistant._register_app_listeners(register_listener) diff --git a/slack_bolt/app/async_app.py b/slack_bolt/app/async_app.py index c821026b4..6a61af3ea 100644 --- a/slack_bolt/app/async_app.py +++ b/slack_bolt/app/async_app.py @@ -8,6 +8,7 @@ from aiohttp import web from slack_bolt.app.async_server import AsyncSlackAppServer +from slack_bolt.app.listener_registry import ListenerRegistry from slack_bolt.context.assistant.thread_context_store.async_store import ( AsyncAssistantThreadContextStore, ) @@ -360,8 +361,7 @@ async def message_hello(message, say): # async function # -------------------------------------- self._async_middleware_list: List[AsyncMiddleware] = [] - self._async_listeners: List[AsyncListener] = [] - self._assistant_listener_insertion_index = 0 + self._async_listeners: ListenerRegistry[AsyncListener] = ListenerRegistry() self._assistant_thread_context_store = assistant_thread_context_store self._attaching_conversation_kwargs_enabled = attaching_conversation_kwargs_enabled @@ -733,9 +733,7 @@ def _register_assistant_listeners(self, assistant: AsyncAssistant) -> None: self._assistant_thread_context_store = assistant.thread_context_store def register_listener(listener: AsyncListener) -> None: - # Keep Assistant listeners before catch-all listeners while preserving Assistant registration order. - self._async_listeners.insert(self._assistant_listener_insertion_index, listener) - self._assistant_listener_insertion_index += 1 + self._async_listeners.append_assistant(listener) assistant._register_app_listeners(register_listener) diff --git a/slack_bolt/app/listener_registry.py b/slack_bolt/app/listener_registry.py new file mode 100644 index 000000000..28087e19d --- /dev/null +++ b/slack_bolt/app/listener_registry.py @@ -0,0 +1,33 @@ +from typing import Generic, Iterator, List, TypeVar, Union, overload + +ListenerT = TypeVar("ListenerT") + + +class ListenerRegistry(Generic[ListenerT]): + def __init__(self) -> None: + self._assistant_listeners: List[ListenerT] = [] + self._listeners: List[ListenerT] = [] + + def append(self, listener: ListenerT) -> None: + self._listeners.append(listener) + + def append_assistant(self, listener: ListenerT) -> None: + self._assistant_listeners.append(listener) + + def __iter__(self) -> Iterator[ListenerT]: + yield from self._assistant_listeners + yield from self._listeners + + def __len__(self) -> int: + return len(self._assistant_listeners) + len(self._listeners) + + @overload + def __getitem__(self, index: int) -> ListenerT: + pass + + @overload + def __getitem__(self, index: slice) -> List[ListenerT]: + pass + + def __getitem__(self, index: Union[int, slice]) -> Union[ListenerT, List[ListenerT]]: + return list(self)[index] From 8822ac4082bc382f4c4077d693a578f76495f6a2 Mon Sep 17 00:00:00 2001 From: Voronin Sergei Date: Sun, 31 May 2026 22:50:38 +1200 Subject: [PATCH 6/6] Add explicit assistant registration mode --- slack_bolt/app/app.py | 13 +- slack_bolt/app/async_app.py | 13 +- slack_bolt/middleware/assistant/assistant.py | 2 - .../middleware/assistant/async_assistant.py | 2 - .../app/test_app_assistant_middleware.py | 120 ++++++++++++++---- 5 files changed, 106 insertions(+), 44 deletions(-) diff --git a/slack_bolt/app/app.py b/slack_bolt/app/app.py index c2d43c04a..8686ea97f 100644 --- a/slack_bolt/app/app.py +++ b/slack_bolt/app/app.py @@ -681,10 +681,6 @@ def middleware_func(logger, body, next): middleware_or_callable = args[0] if isinstance(middleware_or_callable, Middleware): middleware: Middleware = middleware_or_callable - if isinstance(middleware, Assistant) and middleware.auto_inherit_app_middleware is True: - # In this opt-in mode, Assistant handlers should run through the app listener pipeline. - self._register_assistant_listeners(middleware) - return None self._middleware_list.append(middleware) if isinstance(middleware, Assistant) and middleware.thread_context_store is not None: self._assistant_thread_context_store = middleware.thread_context_store @@ -713,8 +709,13 @@ def register_listener(listener: Listener) -> None: # ------------------------- # AI Agents & Assistants - def assistant(self, assistant: Assistant) -> Optional[Callable]: - return self.middleware(assistant) + def assistant(self, assistant: Assistant, mode: str = "middleware") -> Optional[Callable]: + if mode == "middleware": + return self.middleware(assistant) + if mode == "listeners": + self._register_assistant_listeners(assistant) + return None + raise BoltError(f"Unsupported Assistant registration mode ({mode})") # ------------------------- # Workflows: Steps from apps diff --git a/slack_bolt/app/async_app.py b/slack_bolt/app/async_app.py index 6a61af3ea..a926a7e4b 100644 --- a/slack_bolt/app/async_app.py +++ b/slack_bolt/app/async_app.py @@ -708,10 +708,6 @@ async def middleware_func(logger, body, next): middleware_or_callable = args[0] if isinstance(middleware_or_callable, AsyncMiddleware): middleware: AsyncMiddleware = middleware_or_callable - if isinstance(middleware, AsyncAssistant) and middleware.auto_inherit_app_middleware is True: - # In this opt-in mode, Assistant handlers should run through the app listener pipeline. - self._register_assistant_listeners(middleware) - return None self._async_middleware_list.append(middleware) if isinstance(middleware, AsyncAssistant) and middleware.thread_context_store is not None: self._assistant_thread_context_store = middleware.thread_context_store @@ -737,8 +733,13 @@ def register_listener(listener: AsyncListener) -> None: assistant._register_app_listeners(register_listener) - def assistant(self, assistant: AsyncAssistant) -> Optional[Callable]: - return self.middleware(assistant) + def assistant(self, assistant: AsyncAssistant, mode: str = "middleware") -> Optional[Callable]: + if mode == "middleware": + return self.middleware(assistant) + if mode == "listeners": + self._register_assistant_listeners(assistant) + return None + raise BoltError(f"Unsupported Assistant registration mode ({mode})") # ------------------------- # Workflows: Steps from apps diff --git a/slack_bolt/middleware/assistant/assistant.py b/slack_bolt/middleware/assistant/assistant.py index 07fa15e8a..b88cfd48b 100644 --- a/slack_bolt/middleware/assistant/assistant.py +++ b/slack_bolt/middleware/assistant/assistant.py @@ -44,12 +44,10 @@ def __init__( app_name: str = "assistant", thread_context_store: Optional[AssistantThreadContextStore] = None, logger: Optional[logging.Logger] = None, - auto_inherit_app_middleware: bool = False, ): self.app_name = app_name self.thread_context_store = thread_context_store self.base_logger = logger - self.auto_inherit_app_middleware = auto_inherit_app_middleware self._thread_started_listeners = None self._thread_context_changed_listeners = None diff --git a/slack_bolt/middleware/assistant/async_assistant.py b/slack_bolt/middleware/assistant/async_assistant.py index e57eb2bc2..e89282745 100644 --- a/slack_bolt/middleware/assistant/async_assistant.py +++ b/slack_bolt/middleware/assistant/async_assistant.py @@ -44,12 +44,10 @@ def __init__( app_name: str = "assistant", thread_context_store: Optional[AsyncAssistantThreadContextStore] = None, logger: Optional[logging.Logger] = None, - auto_inherit_app_middleware: bool = False, ): self.app_name = app_name self.thread_context_store = thread_context_store self.base_logger = logger - self.auto_inherit_app_middleware = auto_inherit_app_middleware self._thread_started_listeners = None self._thread_context_changed_listeners = None diff --git a/tests/slack_bolt/app/test_app_assistant_middleware.py b/tests/slack_bolt/app/test_app_assistant_middleware.py index bdc313e6a..3a92c4fe9 100644 --- a/tests/slack_bolt/app/test_app_assistant_middleware.py +++ b/tests/slack_bolt/app/test_app_assistant_middleware.py @@ -7,6 +7,7 @@ from slack_bolt import App, Assistant, BoltRequest from slack_bolt.async_app import AsyncApp, AsyncAssistant, AsyncBoltRequest from slack_bolt.authorization import AuthorizeResult +from slack_bolt.error import BoltError from slack_bolt.middleware import Middleware from slack_bolt.middleware.async_middleware import AsyncMiddleware from slack_bolt.request import BoltRequest as BoltRequestType @@ -30,9 +31,9 @@ async def async_authorize_test_app(context, enterprise_id, team_id, user_id): class TestAppAssistantMiddleware: - def test_auto_inherit_assistant_registers_handlers_as_app_listeners(self): + def test_assistant_listener_mode_registers_handlers_as_app_listeners(self): app = App(client=WebClient(token=None), authorize=authorize_test_app, process_before_response=True) - assistant = Assistant(auto_inherit_app_middleware=True) + assistant = Assistant() @app.message("") def catch_all(): @@ -42,15 +43,15 @@ def catch_all(): def handle_user_message(): pass - app.assistant(assistant) + app.assistant(assistant, mode="listeners") - second_assistant = Assistant(auto_inherit_app_middleware=True) + second_assistant = Assistant() @second_assistant.user_message def handle_second_user_message(): pass - app.assistant(second_assistant) + app.assistant(second_assistant, mode="listeners") assert assistant not in app._middleware_list listener_functions = [listener.ack_function for listener in app._listeners] @@ -59,9 +60,9 @@ def handle_second_user_message(): assert listener_functions.index(handle_user_message) < listener_functions.index(handle_second_user_message) assert listener_functions.index(handle_second_user_message) < listener_functions.index(catch_all) - def test_assistant_inherits_app_middleware_registered_after_assistant(self): + def test_assistant_listener_mode_inherits_app_middleware_registered_after_assistant(self): app = App(client=WebClient(token=None), authorize=authorize_test_app, process_before_response=True) - assistant = Assistant(auto_inherit_app_middleware=True) + assistant = Assistant() calls = [] class ListenerMiddleware(Middleware): @@ -75,7 +76,7 @@ def process(self, *, req: BoltRequestType, resp: BoltResponse, next: Callable[[] def handle_user_message(): calls.append("handler") - app.assistant(assistant) + app.assistant(assistant, mode="listeners") @app.middleware def app_middleware(req, next): @@ -109,9 +110,9 @@ def handle_user_message(): assert response.status == 200 assert calls == ["handler"] - def test_assistant_inherits_app_middleware_for_listeners_registered_later(self): + def test_assistant_middleware_mode_does_not_inherit_app_middleware(self): app = App(client=WebClient(token=None), authorize=authorize_test_app, process_before_response=True) - assistant = Assistant(auto_inherit_app_middleware=True) + assistant = Assistant() calls = [] class AppMiddleware(Middleware): @@ -119,7 +120,29 @@ def process(self, *, req: BoltRequestType, resp: BoltResponse, next: Callable[[] calls.append("app") return next() - app.assistant(assistant) + @assistant.user_message + def handle_user_message(): + calls.append("handler") + + app.assistant(assistant, mode="middleware") + app.middleware(AppMiddleware()) + + request = BoltRequest(body=user_message_event_body, mode="socket_mode") + response = app.dispatch(request) + assert response.status == 200 + assert calls == ["handler"] + + def test_assistant_listener_mode_inherits_app_middleware_for_listeners_registered_later(self): + app = App(client=WebClient(token=None), authorize=authorize_test_app, process_before_response=True) + assistant = Assistant() + calls = [] + + class AppMiddleware(Middleware): + def process(self, *, req: BoltRequestType, resp: BoltResponse, next: Callable[[], BoltResponse]): + calls.append("app") + return next() + + app.assistant(assistant, mode="listeners") app.middleware(AppMiddleware()) @assistant.user_message @@ -131,9 +154,9 @@ def handle_user_message(): assert response.status == 200 assert calls == ["app", "handler"] - def test_assistant_inherited_app_middleware_can_short_circuit(self): + def test_assistant_listener_mode_inherited_app_middleware_can_short_circuit(self): app = App(client=WebClient(token=None), authorize=authorize_test_app, process_before_response=True) - assistant = Assistant(auto_inherit_app_middleware=True) + assistant = Assistant() calls = [] class BlockingMiddleware(Middleware): @@ -145,7 +168,7 @@ def process(self, *, req: BoltRequestType, resp: BoltResponse, next: Callable[[] def start_thread(): calls.append("handler") - app.assistant(assistant) + app.assistant(assistant, mode="listeners") app.middleware(BlockingMiddleware()) request = BoltRequest(body=thread_started_event_body, mode="socket_mode") @@ -153,11 +176,17 @@ def start_thread(): assert response.status == 201 assert calls == ["app"] + def test_assistant_rejects_unknown_mode(self): + app = App(client=WebClient(token=None), authorize=authorize_test_app, process_before_response=True) + + with pytest.raises(BoltError, match="Unsupported Assistant registration mode"): + app.assistant(Assistant(), mode="something") + class TestAsyncAppAssistantMiddleware: - def test_auto_inherit_assistant_registers_handlers_as_app_listeners(self): + def test_assistant_listener_mode_registers_handlers_as_app_listeners(self): app = AsyncApp(client=AsyncWebClient(token=None), authorize=async_authorize_test_app, process_before_response=True) - assistant = AsyncAssistant(auto_inherit_app_middleware=True) + assistant = AsyncAssistant() @app.message("") async def catch_all(): @@ -167,15 +196,15 @@ async def catch_all(): async def handle_user_message(): pass - app.assistant(assistant) + app.assistant(assistant, mode="listeners") - second_assistant = AsyncAssistant(auto_inherit_app_middleware=True) + second_assistant = AsyncAssistant() @second_assistant.user_message async def handle_second_user_message(): pass - app.assistant(second_assistant) + app.assistant(second_assistant, mode="listeners") assert assistant not in app._async_middleware_list listener_functions = [listener.ack_function for listener in app._async_listeners] @@ -185,9 +214,9 @@ async def handle_second_user_message(): assert listener_functions.index(handle_second_user_message) < listener_functions.index(catch_all) @pytest.mark.asyncio - async def test_assistant_inherits_app_middleware_registered_after_assistant(self): + async def test_assistant_listener_mode_inherits_app_middleware_registered_after_assistant(self): app = AsyncApp(client=AsyncWebClient(token=None), authorize=async_authorize_test_app, process_before_response=True) - assistant = AsyncAssistant(auto_inherit_app_middleware=True) + assistant = AsyncAssistant() calls = [] class ListenerMiddleware(AsyncMiddleware): @@ -207,7 +236,7 @@ async def async_process( async def handle_user_message(): calls.append("handler") - app.assistant(assistant) + app.assistant(assistant, mode="listeners") @app.middleware async def app_middleware(req, next): @@ -249,9 +278,9 @@ async def handle_user_message(): assert calls == ["handler"] @pytest.mark.asyncio - async def test_assistant_inherits_app_middleware_for_listeners_registered_later(self): + async def test_assistant_middleware_mode_does_not_inherit_app_middleware(self): app = AsyncApp(client=AsyncWebClient(token=None), authorize=async_authorize_test_app, process_before_response=True) - assistant = AsyncAssistant(auto_inherit_app_middleware=True) + assistant = AsyncAssistant() calls = [] class AppMiddleware(AsyncMiddleware): @@ -265,7 +294,36 @@ async def async_process( calls.append("app") return await next() - app.assistant(assistant) + @assistant.user_message + async def handle_user_message(): + calls.append("handler") + + app.assistant(assistant, mode="middleware") + app.middleware(AppMiddleware()) + + request = AsyncBoltRequest(body=user_message_event_body, mode="socket_mode") + response = await app.async_dispatch(request) + assert response.status == 200 + assert calls == ["handler"] + + @pytest.mark.asyncio + async def test_assistant_listener_mode_inherits_app_middleware_for_listeners_registered_later(self): + app = AsyncApp(client=AsyncWebClient(token=None), authorize=async_authorize_test_app, process_before_response=True) + assistant = AsyncAssistant() + calls = [] + + class AppMiddleware(AsyncMiddleware): + async def async_process( + self, + *, + req: AsyncBoltRequest, + resp: BoltResponse, + next: Callable[[], Awaitable[BoltResponse]], + ) -> Optional[BoltResponse]: + calls.append("app") + return await next() + + app.assistant(assistant, mode="listeners") app.middleware(AppMiddleware()) @assistant.user_message @@ -278,9 +336,9 @@ async def handle_user_message(): assert calls == ["app", "handler"] @pytest.mark.asyncio - async def test_assistant_inherited_app_middleware_can_short_circuit(self): + async def test_assistant_listener_mode_inherited_app_middleware_can_short_circuit(self): app = AsyncApp(client=AsyncWebClient(token=None), authorize=async_authorize_test_app, process_before_response=True) - assistant = AsyncAssistant(auto_inherit_app_middleware=True) + assistant = AsyncAssistant() calls = [] class BlockingMiddleware(AsyncMiddleware): @@ -298,10 +356,16 @@ async def async_process( async def start_thread(): calls.append("handler") - app.assistant(assistant) + app.assistant(assistant, mode="listeners") app.middleware(BlockingMiddleware()) request = AsyncBoltRequest(body=thread_started_event_body, mode="socket_mode") response = await app.async_dispatch(request) assert response.status == 201 assert calls == ["app"] + + def test_assistant_rejects_unknown_mode(self): + app = AsyncApp(client=AsyncWebClient(token=None), authorize=async_authorize_test_app, process_before_response=True) + + with pytest.raises(BoltError, match="Unsupported Assistant registration mode"): + app.assistant(AsyncAssistant(), mode="something")