Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions slack_bolt/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
CallableAuthorize,
)

from slack_bolt.app.listener_registry import ListenerRegistry
from slack_bolt.context.assistant.thread_context_store.store import AssistantThreadContextStore

from slack_bolt.error import BoltError, BoltUnhandledRequestError
Expand Down Expand Up @@ -348,7 +349,7 @@ def message_hello(message, say):
# --------------------------------------

self._middleware_list: List[Middleware] = []
self._listeners: List[Listener] = []
self._listeners: ListenerRegistry[Listener] = ListenerRegistry()

if listener_executor is None:
listener_executor = ThreadPoolExecutor(max_workers=5)
Expand Down Expand Up @@ -696,11 +697,25 @@ def middleware_func(logger, body, next):
raise BoltError(f"Unexpected type for a middleware ({type(middleware_or_callable)})")
return None

def _register_assistant_listeners(self, assistant: Assistant) -> None:
if assistant.thread_context_store is not None:
self._assistant_thread_context_store = assistant.thread_context_store

def register_listener(listener: Listener) -> None:
self._listeners.append_assistant(listener)

assistant._register_app_listeners(register_listener)

# -------------------------
# AI Agents & Assistants

def assistant(self, assistant: Assistant) -> Optional[Callable]:
return self.middleware(assistant)
def assistant(self, assistant: Assistant, mode: str = "middleware") -> Optional[Callable]:
if mode == "middleware":
return self.middleware(assistant)
if mode == "listeners":
self._register_assistant_listeners(assistant)
return None
raise BoltError(f"Unsupported Assistant registration mode ({mode})")

# -------------------------
# Workflows: Steps from apps
Expand Down
21 changes: 18 additions & 3 deletions slack_bolt/app/async_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from aiohttp import web

from slack_bolt.app.async_server import AsyncSlackAppServer
from slack_bolt.app.listener_registry import ListenerRegistry
from slack_bolt.context.assistant.thread_context_store.async_store import (
AsyncAssistantThreadContextStore,
)
Expand Down Expand Up @@ -360,7 +361,7 @@ async def message_hello(message, say): # async function
# --------------------------------------

self._async_middleware_list: List[AsyncMiddleware] = []
self._async_listeners: List[AsyncListener] = []
self._async_listeners: ListenerRegistry[AsyncListener] = ListenerRegistry()

self._assistant_thread_context_store = assistant_thread_context_store
self._attaching_conversation_kwargs_enabled = attaching_conversation_kwargs_enabled
Expand Down Expand Up @@ -723,8 +724,22 @@ async def middleware_func(logger, body, next):
raise BoltError(f"Unexpected type for a middleware ({type(middleware_or_callable)})")
return None

def assistant(self, assistant: AsyncAssistant) -> Optional[Callable]:
return self.middleware(assistant)
def _register_assistant_listeners(self, assistant: AsyncAssistant) -> None:
if assistant.thread_context_store is not None:
self._assistant_thread_context_store = assistant.thread_context_store

def register_listener(listener: AsyncListener) -> None:
self._async_listeners.append_assistant(listener)

assistant._register_app_listeners(register_listener)

def assistant(self, assistant: AsyncAssistant, mode: str = "middleware") -> Optional[Callable]:
if mode == "middleware":
return self.middleware(assistant)
if mode == "listeners":
self._register_assistant_listeners(assistant)
return None
raise BoltError(f"Unsupported Assistant registration mode ({mode})")

# -------------------------
# Workflows: Steps from apps
Expand Down
33 changes: 33 additions & 0 deletions slack_bolt/app/listener_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import Generic, Iterator, List, TypeVar, Union, overload

ListenerT = TypeVar("ListenerT")


class ListenerRegistry(Generic[ListenerT]):
def __init__(self) -> None:
self._assistant_listeners: List[ListenerT] = []
self._listeners: List[ListenerT] = []

def append(self, listener: ListenerT) -> None:
self._listeners.append(listener)

def append_assistant(self, listener: ListenerT) -> None:
self._assistant_listeners.append(listener)

def __iter__(self) -> Iterator[ListenerT]:
yield from self._assistant_listeners
yield from self._listeners

def __len__(self) -> int:
return len(self._assistant_listeners) + len(self._listeners)

@overload
def __getitem__(self, index: int) -> ListenerT:
pass

@overload
def __getitem__(self, index: slice) -> List[ListenerT]:
pass

def __getitem__(self, index: Union[int, slice]) -> Union[ListenerT, List[ListenerT]]:
return list(self)[index]
78 changes: 62 additions & 16 deletions slack_bolt/middleware/assistant/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class Assistant(Middleware):
_thread_context_changed_listeners: Optional[List[Listener]]
_user_message_listeners: Optional[List[Listener]]
_bot_message_listeners: Optional[List[Listener]]
_other_message_sub_event_listeners: Optional[List[Listener]]
_app_listener_registrars: List[Callable[[Listener], None]]

thread_context_store: Optional[AssistantThreadContextStore]
base_logger: Optional[logging.Logger]
Expand All @@ -51,6 +53,8 @@ def __init__(
self._thread_context_changed_listeners = None
self._user_message_listeners = None
self._bot_message_listeners = None
self._other_message_sub_event_listeners = None
self._app_listener_registrars = []

def thread_started(
self,
Expand All @@ -64,23 +68,25 @@ def thread_started(
all_matchers = self._merge_matchers(is_assistant_thread_started_event, matchers)
if is_used_without_argument(args):
func = args[0]
self._thread_started_listeners.append(
self._append_and_register_listener(
self._thread_started_listeners,
self.build_listener(
listener_or_functions=func,
matchers=all_matchers,
middleware=middleware, # type: ignore[arg-type]
)
),
)
return func

def _inner(func):
functions = [func] + (lazy if lazy is not None else [])
self._thread_started_listeners.append(
self._append_and_register_listener(
self._thread_started_listeners,
self.build_listener(
listener_or_functions=functions,
matchers=all_matchers,
middleware=middleware,
)
),
)

@wraps(func)
Expand All @@ -103,23 +109,25 @@ def user_message(
all_matchers = self._merge_matchers(is_user_message_event_in_assistant_thread, matchers)
if is_used_without_argument(args):
func = args[0]
self._user_message_listeners.append(
self._append_and_register_listener(
self._user_message_listeners,
self.build_listener(
listener_or_functions=func,
matchers=all_matchers,
middleware=middleware, # type: ignore[arg-type]
)
),
)
return func

def _inner(func):
functions = [func] + (lazy if lazy is not None else [])
self._user_message_listeners.append(
self._append_and_register_listener(
self._user_message_listeners,
self.build_listener(
listener_or_functions=functions,
matchers=all_matchers,
middleware=middleware,
)
),
)

@wraps(func)
Expand All @@ -142,23 +150,25 @@ def bot_message(
all_matchers = self._merge_matchers(is_bot_message_event_in_assistant_thread, matchers)
if is_used_without_argument(args):
func = args[0]
self._bot_message_listeners.append(
self._append_and_register_listener(
self._bot_message_listeners,
self.build_listener(
listener_or_functions=func,
matchers=all_matchers,
middleware=middleware, # type: ignore[arg-type]
)
),
)
return func

def _inner(func):
functions = [func] + (lazy if lazy is not None else [])
self._bot_message_listeners.append(
self._append_and_register_listener(
self._bot_message_listeners,
self.build_listener(
listener_or_functions=functions,
matchers=all_matchers,
middleware=middleware,
)
),
)

@wraps(func)
Expand All @@ -181,23 +191,25 @@ def thread_context_changed(
all_matchers = self._merge_matchers(is_assistant_thread_context_changed_event, matchers)
if is_used_without_argument(args):
func = args[0]
self._thread_context_changed_listeners.append(
self._append_and_register_listener(
self._thread_context_changed_listeners,
self.build_listener(
listener_or_functions=func,
matchers=all_matchers,
middleware=middleware, # type: ignore[arg-type]
)
),
)
return func

def _inner(func):
functions = [func] + (lazy if lazy is not None else [])
self._thread_context_changed_listeners.append(
self._append_and_register_listener(
self._thread_context_changed_listeners,
self.build_listener(
listener_or_functions=functions,
matchers=all_matchers,
middleware=middleware,
)
),
)

@wraps(func)
Expand All @@ -221,6 +233,40 @@ def _merge_matchers(
def default_thread_context_changed(save_thread_context: SaveThreadContext, payload: dict):
save_thread_context(payload["assistant_thread"]["context"])

@staticmethod
def default_other_message_sub_event(ack):
ack()

def _register_app_listeners(self, listener_registrar: Callable[[Listener], None]) -> None:
if self._thread_context_changed_listeners is None:
self.thread_context_changed(self.default_thread_context_changed)
if self._other_message_sub_event_listeners is None:
self._other_message_sub_event_listeners = []
# Preserve the middleware path's ack behavior for message_changed, message_deleted, and similar subevents.
self._append_and_register_listener(
self._other_message_sub_event_listeners,
self.build_listener(
listener_or_functions=self.default_other_message_sub_event,
matchers=[is_other_message_sub_event_in_assistant_thread],
),
)
for listener_list in [
self._thread_started_listeners,
self._thread_context_changed_listeners,
self._user_message_listeners,
self._bot_message_listeners,
self._other_message_sub_event_listeners,
]:
if listener_list is not None:
for listener in listener_list:
listener_registrar(listener)
self._app_listener_registrars.append(listener_registrar)

def _append_and_register_listener(self, listeners: List[Listener], listener: Listener) -> None:
listeners.append(listener)
for registrar in self._app_listener_registrars:
registrar(listener)

def process( # type: ignore[return]
self, *, req: BoltRequest, resp: BoltResponse, next: Callable[[], BoltResponse]
) -> Optional[BoltResponse]:
Expand Down
Loading