diff --git a/slack_bolt/app/app.py b/slack_bolt/app/app.py index 0af27913c..8686ea97f 100644 --- a/slack_bolt/app/app.py +++ b/slack_bolt/app/app.py @@ -20,6 +20,7 @@ CallableAuthorize, ) +from slack_bolt.app.listener_registry import ListenerRegistry from slack_bolt.context.assistant.thread_context_store.store import AssistantThreadContextStore from slack_bolt.error import BoltError, BoltUnhandledRequestError @@ -348,7 +349,7 @@ def message_hello(message, say): # -------------------------------------- self._middleware_list: List[Middleware] = [] - self._listeners: List[Listener] = [] + self._listeners: ListenerRegistry[Listener] = ListenerRegistry() if listener_executor is None: listener_executor = ThreadPoolExecutor(max_workers=5) @@ -696,11 +697,25 @@ def middleware_func(logger, body, next): raise BoltError(f"Unexpected type for a middleware ({type(middleware_or_callable)})") return None + def _register_assistant_listeners(self, assistant: Assistant) -> None: + if assistant.thread_context_store is not None: + self._assistant_thread_context_store = assistant.thread_context_store + + def register_listener(listener: Listener) -> None: + self._listeners.append_assistant(listener) + + assistant._register_app_listeners(register_listener) + # ------------------------- # AI Agents & Assistants - def assistant(self, assistant: Assistant) -> Optional[Callable]: - return self.middleware(assistant) + def assistant(self, assistant: Assistant, mode: str = "middleware") -> Optional[Callable]: + if mode == "middleware": + return self.middleware(assistant) + if mode == "listeners": + self._register_assistant_listeners(assistant) + return None + raise BoltError(f"Unsupported Assistant registration mode ({mode})") # ------------------------- # Workflows: Steps from apps diff --git a/slack_bolt/app/async_app.py b/slack_bolt/app/async_app.py index cc94f9e15..a926a7e4b 100644 --- a/slack_bolt/app/async_app.py +++ b/slack_bolt/app/async_app.py @@ -8,6 +8,7 @@ from aiohttp import web from slack_bolt.app.async_server import AsyncSlackAppServer +from slack_bolt.app.listener_registry import ListenerRegistry from slack_bolt.context.assistant.thread_context_store.async_store import ( AsyncAssistantThreadContextStore, ) @@ -360,7 +361,7 @@ async def message_hello(message, say): # async function # -------------------------------------- self._async_middleware_list: List[AsyncMiddleware] = [] - self._async_listeners: List[AsyncListener] = [] + self._async_listeners: ListenerRegistry[AsyncListener] = ListenerRegistry() self._assistant_thread_context_store = assistant_thread_context_store self._attaching_conversation_kwargs_enabled = attaching_conversation_kwargs_enabled @@ -723,8 +724,22 @@ async def middleware_func(logger, body, next): raise BoltError(f"Unexpected type for a middleware ({type(middleware_or_callable)})") return None - def assistant(self, assistant: AsyncAssistant) -> Optional[Callable]: - return self.middleware(assistant) + def _register_assistant_listeners(self, assistant: AsyncAssistant) -> None: + if assistant.thread_context_store is not None: + self._assistant_thread_context_store = assistant.thread_context_store + + def register_listener(listener: AsyncListener) -> None: + self._async_listeners.append_assistant(listener) + + assistant._register_app_listeners(register_listener) + + def assistant(self, assistant: AsyncAssistant, mode: str = "middleware") -> Optional[Callable]: + if mode == "middleware": + return self.middleware(assistant) + if mode == "listeners": + self._register_assistant_listeners(assistant) + return None + raise BoltError(f"Unsupported Assistant registration mode ({mode})") # ------------------------- # Workflows: Steps from apps diff --git a/slack_bolt/app/listener_registry.py b/slack_bolt/app/listener_registry.py new file mode 100644 index 000000000..28087e19d --- /dev/null +++ b/slack_bolt/app/listener_registry.py @@ -0,0 +1,33 @@ +from typing import Generic, Iterator, List, TypeVar, Union, overload + +ListenerT = TypeVar("ListenerT") + + +class ListenerRegistry(Generic[ListenerT]): + def __init__(self) -> None: + self._assistant_listeners: List[ListenerT] = [] + self._listeners: List[ListenerT] = [] + + def append(self, listener: ListenerT) -> None: + self._listeners.append(listener) + + def append_assistant(self, listener: ListenerT) -> None: + self._assistant_listeners.append(listener) + + def __iter__(self) -> Iterator[ListenerT]: + yield from self._assistant_listeners + yield from self._listeners + + def __len__(self) -> int: + return len(self._assistant_listeners) + len(self._listeners) + + @overload + def __getitem__(self, index: int) -> ListenerT: + pass + + @overload + def __getitem__(self, index: slice) -> List[ListenerT]: + pass + + def __getitem__(self, index: Union[int, slice]) -> Union[ListenerT, List[ListenerT]]: + return list(self)[index] diff --git a/slack_bolt/middleware/assistant/assistant.py b/slack_bolt/middleware/assistant/assistant.py index ad842f94d..b88cfd48b 100644 --- a/slack_bolt/middleware/assistant/assistant.py +++ b/slack_bolt/middleware/assistant/assistant.py @@ -32,6 +32,8 @@ class Assistant(Middleware): _thread_context_changed_listeners: Optional[List[Listener]] _user_message_listeners: Optional[List[Listener]] _bot_message_listeners: Optional[List[Listener]] + _other_message_sub_event_listeners: Optional[List[Listener]] + _app_listener_registrars: List[Callable[[Listener], None]] thread_context_store: Optional[AssistantThreadContextStore] base_logger: Optional[logging.Logger] @@ -51,6 +53,8 @@ def __init__( self._thread_context_changed_listeners = None self._user_message_listeners = None self._bot_message_listeners = None + self._other_message_sub_event_listeners = None + self._app_listener_registrars = [] def thread_started( self, @@ -64,23 +68,25 @@ def thread_started( all_matchers = self._merge_matchers(is_assistant_thread_started_event, matchers) if is_used_without_argument(args): func = args[0] - self._thread_started_listeners.append( + self._append_and_register_listener( + self._thread_started_listeners, self.build_listener( listener_or_functions=func, matchers=all_matchers, middleware=middleware, # type: ignore[arg-type] - ) + ), ) return func def _inner(func): functions = [func] + (lazy if lazy is not None else []) - self._thread_started_listeners.append( + self._append_and_register_listener( + self._thread_started_listeners, self.build_listener( listener_or_functions=functions, matchers=all_matchers, middleware=middleware, - ) + ), ) @wraps(func) @@ -103,23 +109,25 @@ def user_message( all_matchers = self._merge_matchers(is_user_message_event_in_assistant_thread, matchers) if is_used_without_argument(args): func = args[0] - self._user_message_listeners.append( + self._append_and_register_listener( + self._user_message_listeners, self.build_listener( listener_or_functions=func, matchers=all_matchers, middleware=middleware, # type: ignore[arg-type] - ) + ), ) return func def _inner(func): functions = [func] + (lazy if lazy is not None else []) - self._user_message_listeners.append( + self._append_and_register_listener( + self._user_message_listeners, self.build_listener( listener_or_functions=functions, matchers=all_matchers, middleware=middleware, - ) + ), ) @wraps(func) @@ -142,23 +150,25 @@ def bot_message( all_matchers = self._merge_matchers(is_bot_message_event_in_assistant_thread, matchers) if is_used_without_argument(args): func = args[0] - self._bot_message_listeners.append( + self._append_and_register_listener( + self._bot_message_listeners, self.build_listener( listener_or_functions=func, matchers=all_matchers, middleware=middleware, # type: ignore[arg-type] - ) + ), ) return func def _inner(func): functions = [func] + (lazy if lazy is not None else []) - self._bot_message_listeners.append( + self._append_and_register_listener( + self._bot_message_listeners, self.build_listener( listener_or_functions=functions, matchers=all_matchers, middleware=middleware, - ) + ), ) @wraps(func) @@ -181,23 +191,25 @@ def thread_context_changed( all_matchers = self._merge_matchers(is_assistant_thread_context_changed_event, matchers) if is_used_without_argument(args): func = args[0] - self._thread_context_changed_listeners.append( + self._append_and_register_listener( + self._thread_context_changed_listeners, self.build_listener( listener_or_functions=func, matchers=all_matchers, middleware=middleware, # type: ignore[arg-type] - ) + ), ) return func def _inner(func): functions = [func] + (lazy if lazy is not None else []) - self._thread_context_changed_listeners.append( + self._append_and_register_listener( + self._thread_context_changed_listeners, self.build_listener( listener_or_functions=functions, matchers=all_matchers, middleware=middleware, - ) + ), ) @wraps(func) @@ -221,6 +233,40 @@ def _merge_matchers( def default_thread_context_changed(save_thread_context: SaveThreadContext, payload: dict): save_thread_context(payload["assistant_thread"]["context"]) + @staticmethod + def default_other_message_sub_event(ack): + ack() + + def _register_app_listeners(self, listener_registrar: Callable[[Listener], None]) -> None: + if self._thread_context_changed_listeners is None: + self.thread_context_changed(self.default_thread_context_changed) + if self._other_message_sub_event_listeners is None: + self._other_message_sub_event_listeners = [] + # Preserve the middleware path's ack behavior for message_changed, message_deleted, and similar subevents. + self._append_and_register_listener( + self._other_message_sub_event_listeners, + self.build_listener( + listener_or_functions=self.default_other_message_sub_event, + matchers=[is_other_message_sub_event_in_assistant_thread], + ), + ) + for listener_list in [ + self._thread_started_listeners, + self._thread_context_changed_listeners, + self._user_message_listeners, + self._bot_message_listeners, + self._other_message_sub_event_listeners, + ]: + if listener_list is not None: + for listener in listener_list: + listener_registrar(listener) + self._app_listener_registrars.append(listener_registrar) + + def _append_and_register_listener(self, listeners: List[Listener], listener: Listener) -> None: + listeners.append(listener) + for registrar in self._app_listener_registrars: + registrar(listener) + def process( # type: ignore[return] self, *, req: BoltRequest, resp: BoltResponse, next: Callable[[], BoltResponse] ) -> Optional[BoltResponse]: diff --git a/slack_bolt/middleware/assistant/async_assistant.py b/slack_bolt/middleware/assistant/async_assistant.py index 588de8b41..e89282745 100644 --- a/slack_bolt/middleware/assistant/async_assistant.py +++ b/slack_bolt/middleware/assistant/async_assistant.py @@ -32,6 +32,8 @@ class AsyncAssistant(AsyncMiddleware): _user_message_listeners: Optional[List[AsyncListener]] _bot_message_listeners: Optional[List[AsyncListener]] _thread_context_changed_listeners: Optional[List[AsyncListener]] + _other_message_sub_event_listeners: Optional[List[AsyncListener]] + _app_listener_registrars: List[Callable[[AsyncListener], None]] thread_context_store: Optional[AsyncAssistantThreadContextStore] base_logger: Optional[logging.Logger] @@ -51,6 +53,8 @@ def __init__( self._thread_context_changed_listeners = None self._user_message_listeners = None self._bot_message_listeners = None + self._other_message_sub_event_listeners = None + self._app_listener_registrars = [] def thread_started( self, @@ -71,23 +75,25 @@ def thread_started( ) if is_used_without_argument(args): func = args[0] - self._thread_started_listeners.append( + self._append_and_register_listener( + self._thread_started_listeners, self.build_listener( listener_or_functions=func, matchers=all_matchers, middleware=middleware, # type: ignore[arg-type] - ) + ), ) return func def _inner(func): functions = [func] + (lazy if lazy is not None else []) - self._thread_started_listeners.append( + self._append_and_register_listener( + self._thread_started_listeners, self.build_listener( listener_or_functions=functions, matchers=all_matchers, middleware=middleware, - ) + ), ) @wraps(func) @@ -117,23 +123,25 @@ def user_message( ) if is_used_without_argument(args): func = args[0] - self._user_message_listeners.append( + self._append_and_register_listener( + self._user_message_listeners, self.build_listener( listener_or_functions=func, matchers=all_matchers, middleware=middleware, # type: ignore[arg-type] - ) + ), ) return func def _inner(func): functions = [func] + (lazy if lazy is not None else []) - self._user_message_listeners.append( + self._append_and_register_listener( + self._user_message_listeners, self.build_listener( listener_or_functions=functions, matchers=all_matchers, middleware=middleware, - ) + ), ) @wraps(func) @@ -163,23 +171,25 @@ def bot_message( ) if is_used_without_argument(args): func = args[0] - self._bot_message_listeners.append( + self._append_and_register_listener( + self._bot_message_listeners, self.build_listener( listener_or_functions=func, matchers=all_matchers, middleware=middleware, # type: ignore[arg-type] - ) + ), ) return func def _inner(func): functions = [func] + (lazy if lazy is not None else []) - self._bot_message_listeners.append( + self._append_and_register_listener( + self._bot_message_listeners, self.build_listener( listener_or_functions=functions, matchers=all_matchers, middleware=middleware, - ) + ), ) @wraps(func) @@ -209,23 +219,25 @@ def thread_context_changed( ) if is_used_without_argument(args): func = args[0] - self._thread_context_changed_listeners.append( + self._append_and_register_listener( + self._thread_context_changed_listeners, self.build_listener( listener_or_functions=func, matchers=all_matchers, middleware=middleware, # type: ignore[arg-type] - ) + ), ) return func def _inner(func): functions = [func] + (lazy if lazy is not None else []) - self._thread_context_changed_listeners.append( + self._append_and_register_listener( + self._thread_context_changed_listeners, self.build_listener( listener_or_functions=functions, matchers=all_matchers, middleware=middleware, - ) + ), ) @wraps(func) @@ -248,6 +260,48 @@ async def default_thread_context_changed(save_thread_context: AsyncSaveThreadCon new_context: dict = payload["assistant_thread"]["context"] await save_thread_context(new_context) + @staticmethod + async def default_other_message_sub_event(ack): + await ack() + + def _register_app_listeners(self, listener_registrar: Callable[[AsyncListener], None]) -> None: + if self._thread_context_changed_listeners is None: + self.thread_context_changed(self.default_thread_context_changed) + if self._other_message_sub_event_listeners is None: + self._other_message_sub_event_listeners = [] + all_matchers = self._merge_matchers( + build_listener_matcher( + func=is_other_message_sub_event_in_assistant_thread, + asyncio=True, + base_logger=self.base_logger, + ), # type: ignore[arg-type] + None, + ) + # Preserve the middleware path's ack behavior for message_changed, message_deleted, and similar subevents. + self._append_and_register_listener( + self._other_message_sub_event_listeners, + self.build_listener( + listener_or_functions=self.default_other_message_sub_event, + matchers=all_matchers, + ), + ) + for listener_list in [ + self._thread_started_listeners, + self._thread_context_changed_listeners, + self._user_message_listeners, + self._bot_message_listeners, + self._other_message_sub_event_listeners, + ]: + if listener_list is not None: + for listener in listener_list: + listener_registrar(listener) + self._app_listener_registrars.append(listener_registrar) + + def _append_and_register_listener(self, listeners: List[AsyncListener], listener: AsyncListener) -> None: + listeners.append(listener) + for registrar in self._app_listener_registrars: + registrar(listener) + async def async_process( # type: ignore[return] self, *, diff --git a/tests/slack_bolt/app/test_app_assistant_middleware.py b/tests/slack_bolt/app/test_app_assistant_middleware.py new file mode 100644 index 000000000..3a92c4fe9 --- /dev/null +++ b/tests/slack_bolt/app/test_app_assistant_middleware.py @@ -0,0 +1,371 @@ +from typing import Awaitable, Callable, Optional + +import pytest +from slack_sdk import WebClient +from slack_sdk.web.async_client import AsyncWebClient + +from slack_bolt import App, Assistant, BoltRequest +from slack_bolt.async_app import AsyncApp, AsyncAssistant, AsyncBoltRequest +from slack_bolt.authorization import AuthorizeResult +from slack_bolt.error import BoltError +from slack_bolt.middleware import Middleware +from slack_bolt.middleware.async_middleware import AsyncMiddleware +from slack_bolt.request import BoltRequest as BoltRequestType +from slack_bolt.response import BoltResponse +from tests.scenario_tests.test_events_assistant import thread_started_event_body, user_message_event_body + + +def authorize_test_app(context, enterprise_id, team_id, user_id): + return AuthorizeResult( + enterprise_id=enterprise_id, + team_id=team_id, + user_id=user_id, + bot_user_id="W111", + bot_id="B111", + bot_token="xoxb-valid", + ) + + +async def async_authorize_test_app(context, enterprise_id, team_id, user_id): + return authorize_test_app(context, enterprise_id, team_id, user_id) + + +class TestAppAssistantMiddleware: + def test_assistant_listener_mode_registers_handlers_as_app_listeners(self): + app = App(client=WebClient(token=None), authorize=authorize_test_app, process_before_response=True) + assistant = Assistant() + + @app.message("") + def catch_all(): + pass + + @assistant.user_message + def handle_user_message(): + pass + + app.assistant(assistant, mode="listeners") + + second_assistant = Assistant() + + @second_assistant.user_message + def handle_second_user_message(): + pass + + app.assistant(second_assistant, mode="listeners") + + assert assistant not in app._middleware_list + listener_functions = [listener.ack_function for listener in app._listeners] + assert handle_user_message in listener_functions + assert listener_functions.index(handle_user_message) < listener_functions.index(catch_all) + assert listener_functions.index(handle_user_message) < listener_functions.index(handle_second_user_message) + assert listener_functions.index(handle_second_user_message) < listener_functions.index(catch_all) + + def test_assistant_listener_mode_inherits_app_middleware_registered_after_assistant(self): + app = App(client=WebClient(token=None), authorize=authorize_test_app, process_before_response=True) + assistant = Assistant() + calls = [] + + class ListenerMiddleware(Middleware): + def process(self, *, req: BoltRequestType, resp: BoltResponse, next: Callable[[], BoltResponse]): + calls.append("listener") + assert req.context.get("set_status") is not None + assert req.context.get("set_title") is not None + return next() + + @assistant.user_message(middleware=[ListenerMiddleware()]) + def handle_user_message(): + calls.append("handler") + + app.assistant(assistant, mode="listeners") + + @app.middleware + def app_middleware(req, next): + calls.append("app") + return next() + + request = BoltRequest(body=user_message_event_body, mode="socket_mode") + response = app.dispatch(request) + assert response.status == 200 + assert calls == ["app", "listener", "handler"] + + def test_assistant_does_not_inherit_app_middleware_by_default(self): + app = App(client=WebClient(token=None), authorize=authorize_test_app, process_before_response=True) + assistant = Assistant() + calls = [] + + class AppMiddleware(Middleware): + def process(self, *, req: BoltRequestType, resp: BoltResponse, next: Callable[[], BoltResponse]): + calls.append("app") + return next() + + @assistant.user_message + def handle_user_message(): + calls.append("handler") + + app.assistant(assistant) + app.middleware(AppMiddleware()) + + request = BoltRequest(body=user_message_event_body, mode="socket_mode") + response = app.dispatch(request) + assert response.status == 200 + assert calls == ["handler"] + + def test_assistant_middleware_mode_does_not_inherit_app_middleware(self): + app = App(client=WebClient(token=None), authorize=authorize_test_app, process_before_response=True) + assistant = Assistant() + calls = [] + + class AppMiddleware(Middleware): + def process(self, *, req: BoltRequestType, resp: BoltResponse, next: Callable[[], BoltResponse]): + calls.append("app") + return next() + + @assistant.user_message + def handle_user_message(): + calls.append("handler") + + app.assistant(assistant, mode="middleware") + app.middleware(AppMiddleware()) + + request = BoltRequest(body=user_message_event_body, mode="socket_mode") + response = app.dispatch(request) + assert response.status == 200 + assert calls == ["handler"] + + def test_assistant_listener_mode_inherits_app_middleware_for_listeners_registered_later(self): + app = App(client=WebClient(token=None), authorize=authorize_test_app, process_before_response=True) + assistant = Assistant() + calls = [] + + class AppMiddleware(Middleware): + def process(self, *, req: BoltRequestType, resp: BoltResponse, next: Callable[[], BoltResponse]): + calls.append("app") + return next() + + app.assistant(assistant, mode="listeners") + app.middleware(AppMiddleware()) + + @assistant.user_message + def handle_user_message(): + calls.append("handler") + + request = BoltRequest(body=user_message_event_body, mode="socket_mode") + response = app.dispatch(request) + assert response.status == 200 + assert calls == ["app", "handler"] + + def test_assistant_listener_mode_inherited_app_middleware_can_short_circuit(self): + app = App(client=WebClient(token=None), authorize=authorize_test_app, process_before_response=True) + assistant = Assistant() + calls = [] + + class BlockingMiddleware(Middleware): + def process(self, *, req: BoltRequestType, resp: BoltResponse, next: Callable[[], BoltResponse]): + calls.append("app") + return BoltResponse(status=201) + + @assistant.thread_started + def start_thread(): + calls.append("handler") + + app.assistant(assistant, mode="listeners") + app.middleware(BlockingMiddleware()) + + request = BoltRequest(body=thread_started_event_body, mode="socket_mode") + response = app.dispatch(request) + assert response.status == 201 + assert calls == ["app"] + + def test_assistant_rejects_unknown_mode(self): + app = App(client=WebClient(token=None), authorize=authorize_test_app, process_before_response=True) + + with pytest.raises(BoltError, match="Unsupported Assistant registration mode"): + app.assistant(Assistant(), mode="something") + + +class TestAsyncAppAssistantMiddleware: + def test_assistant_listener_mode_registers_handlers_as_app_listeners(self): + app = AsyncApp(client=AsyncWebClient(token=None), authorize=async_authorize_test_app, process_before_response=True) + assistant = AsyncAssistant() + + @app.message("") + async def catch_all(): + pass + + @assistant.user_message + async def handle_user_message(): + pass + + app.assistant(assistant, mode="listeners") + + second_assistant = AsyncAssistant() + + @second_assistant.user_message + async def handle_second_user_message(): + pass + + app.assistant(second_assistant, mode="listeners") + + assert assistant not in app._async_middleware_list + listener_functions = [listener.ack_function for listener in app._async_listeners] + assert handle_user_message in listener_functions + assert listener_functions.index(handle_user_message) < listener_functions.index(catch_all) + assert listener_functions.index(handle_user_message) < listener_functions.index(handle_second_user_message) + assert listener_functions.index(handle_second_user_message) < listener_functions.index(catch_all) + + @pytest.mark.asyncio + async def test_assistant_listener_mode_inherits_app_middleware_registered_after_assistant(self): + app = AsyncApp(client=AsyncWebClient(token=None), authorize=async_authorize_test_app, process_before_response=True) + assistant = AsyncAssistant() + calls = [] + + class ListenerMiddleware(AsyncMiddleware): + async def async_process( + self, + *, + req: AsyncBoltRequest, + resp: BoltResponse, + next: Callable[[], Awaitable[BoltResponse]], + ) -> Optional[BoltResponse]: + calls.append("listener") + assert req.context.get("set_status") is not None + assert req.context.get("set_title") is not None + return await next() + + @assistant.user_message(middleware=[ListenerMiddleware()]) + async def handle_user_message(): + calls.append("handler") + + app.assistant(assistant, mode="listeners") + + @app.middleware + async def app_middleware(req, next): + calls.append("app") + return await next() + + request = AsyncBoltRequest(body=user_message_event_body, mode="socket_mode") + response = await app.async_dispatch(request) + assert response.status == 200 + assert calls == ["app", "listener", "handler"] + + @pytest.mark.asyncio + async def test_assistant_does_not_inherit_app_middleware_by_default(self): + app = AsyncApp(client=AsyncWebClient(token=None), authorize=async_authorize_test_app, process_before_response=True) + assistant = AsyncAssistant() + calls = [] + + class AppMiddleware(AsyncMiddleware): + async def async_process( + self, + *, + req: AsyncBoltRequest, + resp: BoltResponse, + next: Callable[[], Awaitable[BoltResponse]], + ) -> Optional[BoltResponse]: + calls.append("app") + return await next() + + @assistant.user_message + async def handle_user_message(): + calls.append("handler") + + app.assistant(assistant) + app.middleware(AppMiddleware()) + + request = AsyncBoltRequest(body=user_message_event_body, mode="socket_mode") + response = await app.async_dispatch(request) + assert response.status == 200 + assert calls == ["handler"] + + @pytest.mark.asyncio + async def test_assistant_middleware_mode_does_not_inherit_app_middleware(self): + app = AsyncApp(client=AsyncWebClient(token=None), authorize=async_authorize_test_app, process_before_response=True) + assistant = AsyncAssistant() + calls = [] + + class AppMiddleware(AsyncMiddleware): + async def async_process( + self, + *, + req: AsyncBoltRequest, + resp: BoltResponse, + next: Callable[[], Awaitable[BoltResponse]], + ) -> Optional[BoltResponse]: + calls.append("app") + return await next() + + @assistant.user_message + async def handle_user_message(): + calls.append("handler") + + app.assistant(assistant, mode="middleware") + app.middleware(AppMiddleware()) + + request = AsyncBoltRequest(body=user_message_event_body, mode="socket_mode") + response = await app.async_dispatch(request) + assert response.status == 200 + assert calls == ["handler"] + + @pytest.mark.asyncio + async def test_assistant_listener_mode_inherits_app_middleware_for_listeners_registered_later(self): + app = AsyncApp(client=AsyncWebClient(token=None), authorize=async_authorize_test_app, process_before_response=True) + assistant = AsyncAssistant() + calls = [] + + class AppMiddleware(AsyncMiddleware): + async def async_process( + self, + *, + req: AsyncBoltRequest, + resp: BoltResponse, + next: Callable[[], Awaitable[BoltResponse]], + ) -> Optional[BoltResponse]: + calls.append("app") + return await next() + + app.assistant(assistant, mode="listeners") + app.middleware(AppMiddleware()) + + @assistant.user_message + async def handle_user_message(): + calls.append("handler") + + request = AsyncBoltRequest(body=user_message_event_body, mode="socket_mode") + response = await app.async_dispatch(request) + assert response.status == 200 + assert calls == ["app", "handler"] + + @pytest.mark.asyncio + async def test_assistant_listener_mode_inherited_app_middleware_can_short_circuit(self): + app = AsyncApp(client=AsyncWebClient(token=None), authorize=async_authorize_test_app, process_before_response=True) + assistant = AsyncAssistant() + calls = [] + + class BlockingMiddleware(AsyncMiddleware): + async def async_process( + self, + *, + req: AsyncBoltRequest, + resp: BoltResponse, + next: Callable[[], Awaitable[BoltResponse]], + ) -> Optional[BoltResponse]: + calls.append("app") + return BoltResponse(status=201) + + @assistant.thread_started + async def start_thread(): + calls.append("handler") + + app.assistant(assistant, mode="listeners") + app.middleware(BlockingMiddleware()) + + request = AsyncBoltRequest(body=thread_started_event_body, mode="socket_mode") + response = await app.async_dispatch(request) + assert response.status == 201 + assert calls == ["app"] + + def test_assistant_rejects_unknown_mode(self): + app = AsyncApp(client=AsyncWebClient(token=None), authorize=async_authorize_test_app, process_before_response=True) + + with pytest.raises(BoltError, match="Unsupported Assistant registration mode"): + app.assistant(AsyncAssistant(), mode="something")