From 99d0966a1c756cec7a33601ac67ee20c0c066f7b Mon Sep 17 00:00:00 2001 From: Lakshman Patel Date: Sun, 28 Jun 2026 08:28:33 +0530 Subject: [PATCH] feat(daemon): real Discord Gateway + harden Telegram/Slack bridges Discord: replace the REST-poll bridge (which only watched hand-configured channels and never actually fetched anything as wired) with the official Discord Gateway over WebSocket via bwmarrin/discordgo. Receives guild @mentions and DMs in real time; the bot user is learned from READY so no app/channel IDs need configuring. Library handles heartbeats, resume, reconnect, and rate limits. Message filtering and the pairing/allowlist forward policy are split into pure, unit-tested helpers. Cross-gateway hardening: - Bounded, lifecycle-tracked goroutines (asyncDispatcher) drained on Stop; Slack handlers now run under the gateway context, not context.Background. - Platform API errors are checked and logged instead of swallowed (Slack ok=false, Discord/Telegram non-2xx, Telegram getUpdates ok=false). - Telegram: fail-closed bare constructor, context-aware exponential backoff. - setDaemonURL via a daemonURLSetter interface instead of a type switch. Adds github.com/bwmarrin/discordgo (+ gorilla/websocket indirect). --- go.mod | 2 + go.sum | 4 + go.work.sum | 1 + internal/daemon/discord.go | 248 +++++++++++++------------------ internal/daemon/gateway.go | 62 ++++++-- internal/daemon/gateway_test.go | 139 +++++++++++------ internal/daemon/slack.go | 67 ++++++++- internal/daemon/telegram.go | 106 ++++++++++--- internal/daemon/telegram_test.go | 36 ++++- 9 files changed, 428 insertions(+), 237 deletions(-) diff --git a/go.mod b/go.mod index e6e3fde9..651f0f4f 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/GrayCodeAI/sight v0.1.0 github.com/GrayCodeAI/tok v0.1.0 github.com/GrayCodeAI/yaad v0.1.0 + github.com/bwmarrin/discordgo v0.28.1 github.com/charmbracelet/bubbles v1.0.0 github.com/charmbracelet/bubbletea v1.3.10 github.com/charmbracelet/lipgloss v1.1.0 @@ -81,6 +82,7 @@ require ( github.com/godbus/dbus/v5 v5.2.2 // indirect github.com/google/cel-go v0.27.0 // indirect github.com/gorilla/css v1.0.1 // indirect + github.com/gorilla/websocket v1.4.2 // indirect github.com/h2non/filetype v1.1.3 // indirect github.com/hashicorp/go-version v1.7.0 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect diff --git a/go.sum b/go.sum index 171b99d9..4f9a7123 100644 --- a/go.sum +++ b/go.sum @@ -299,3 +299,7 @@ modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= +github.com/bwmarrin/discordgo v0.28.1 h1:gXsuo2GBO7NbR6uqmrrBDplPUx2T3nzu775q/Rd1aG4= +github.com/bwmarrin/discordgo v0.28.1/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= diff --git a/go.work.sum b/go.work.sum index 38e64695..fa06d341 100644 --- a/go.work.sum +++ b/go.work.sum @@ -576,6 +576,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.30.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= diff --git a/internal/daemon/discord.go b/internal/daemon/discord.go index 323a8748..a0e59b0a 100644 --- a/internal/daemon/discord.go +++ b/internal/daemon/discord.go @@ -1,61 +1,32 @@ package daemon import ( - "bytes" "context" - "encoding/json" "fmt" - "io" + "log/slog" "net/http" "strings" "time" + + "github.com/bwmarrin/discordgo" ) -// discordAPIBase is the Discord REST API root. It is a package var so tests can -// point it at a mock server. -var discordAPIBase = "https://discord.com/api/v10" - -// DiscordGateway bridges hawk to Discord. Because no WebSocket library is present -// in go.mod, the gateway uses a REST long-poll-equivalent strategy: it does not -// open the Discord Gateway socket but instead relies on a poll over the bot's -// accessible channels for new @mentions/DMs. This keeps the bridge dependency-free -// while remaining bidirectional (it posts replies in-thread via the REST API). -// -// The polling transport is intentionally minimal and pluggable: fetchMessages is -// a field so a real Gateway-WebSocket implementation (or a test) can be swapped -// in without changing the message-handling logic. +// DiscordGateway bridges hawk to Discord via the official Gateway (WebSocket), +// using bwmarrin/discordgo. Unlike a REST-poll bridge it receives message events +// in real time — guild @mentions and direct messages — with reconnection, +// heartbeats, and rate limiting handled by the library. Authorized prompts are +// forwarded to the daemon's /v1/chat endpoint and the reply is posted back to the +// originating channel. type DiscordGateway struct { cfg DiscordConfig daemonAddr string apiKey string - client *http.Client + client *http.Client // for forwardToHawk auth *authorizer + dispatch *asyncDispatcher - // pollChannels lists channel IDs the bot watches. When empty the gateway - // discovers DM channels lazily as it sees them via fetchMessages. - pollChannels []string - // lastSeen maps channelID -> last processed message ID for poll dedup. - lastSeen map[string]string - - // fetchMessages retrieves new messages for a channel since the given message - // ID. Overridable for tests. Default uses the Discord REST API. - fetchMessages func(ctx context.Context, channelID, afterID string) ([]discordMessage, error) -} - -// discordMessage is the subset of a Discord message object we use. -type discordMessage struct { - ID string `json:"id"` - ChannelID string `json:"channel_id"` - Content string `json:"content"` - Author discordUser `json:"author"` - Mentions []discordUser `json:"mentions"` - GuildID string `json:"guild_id,omitempty"` -} - -type discordUser struct { - ID string `json:"id"` - Username string `json:"username"` - Bot bool `json:"bot,omitempty"` + // openSession is overridable in tests; production opens a real Gateway socket. + openSession func(ctx context.Context) (*discordgo.Session, error) } // newDiscordGateway builds a Discord gateway from config. @@ -66,147 +37,140 @@ func newDiscordGateway(cfg DiscordConfig, daemonAddr, apiKey string) *DiscordGat apiKey: apiKey, client: &http.Client{Timeout: 30 * time.Second}, auth: newAuthorizer(cfg.PairingCode, cfg.AllowList), - lastSeen: make(map[string]string), + dispatch: newAsyncDispatcher(8), } - g.fetchMessages = g.fetchMessagesREST + g.openSession = g.openGatewaySession return g } // Name implements Gateway. func (g *DiscordGateway) Name() string { return "discord" } -// Stop implements Gateway. Polling is driven by the Start context. +// setDaemonURL implements daemonURLSetter. +func (g *DiscordGateway) setDaemonURL(url string) { g.daemonAddr = url } + +// Stop implements Gateway. The Gateway connection is driven by the Start context. func (g *DiscordGateway) Stop() error { return nil } -// Start implements Gateway: poll watched channels until ctx is cancelled. +// Start opens the Discord Gateway, registers the message handler, and blocks +// until ctx is cancelled, then drains in-flight handlers and closes the socket. func (g *DiscordGateway) Start(ctx context.Context) error { - ticker := time.NewTicker(2 * time.Second) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: - g.pollOnce(ctx) - } + s, err := g.openSession(ctx) + if err != nil { + return err } + s.AddHandler(func(sess *discordgo.Session, m *discordgo.MessageCreate) { + g.onMessageCreate(ctx, sess, m) + }) + if err := s.Open(); err != nil { + return fmt.Errorf("discord gateway open: %w", err) + } + defer func() { _ = s.Close() }() + + <-ctx.Done() + g.dispatch.wait() + return ctx.Err() } -func (g *DiscordGateway) pollOnce(ctx context.Context) { - for _, ch := range g.pollChannels { - msgs, err := g.fetchMessages(ctx, ch, g.lastSeen[ch]) - if err != nil { - continue - } - for _, m := range msgs { - g.lastSeen[m.ChannelID] = m.ID - g.handleMessage(ctx, m) - } +// openGatewaySession creates a discordgo session configured with the intents the +// bridge needs: guild messages, message content (privileged — enable it in the +// Discord developer portal), and direct messages. +func (g *DiscordGateway) openGatewaySession(_ context.Context) (*discordgo.Session, error) { + s, err := discordgo.New("Bot " + g.cfg.Token) + if err != nil { + return nil, fmt.Errorf("discord session: %w", err) } + s.Identify.Intents = discordgo.IntentsGuildMessages | + discordgo.IntentMessageContent | + discordgo.IntentsDirectMessages + return s, nil } -// mentionsBot reports whether the message is a DM or @mentions the bot user. -func (g *DiscordGateway) mentionsBot(m discordMessage) bool { - // DM: no guild. +// onMessageCreate filters inbound events and dispatches authorized messages to +// the shared handler. Replies go back through the live Gateway session. +func (g *DiscordGateway) onMessageCreate(ctx context.Context, s *discordgo.Session, m *discordgo.MessageCreate) { + selfID := "" + if s.State != nil && s.State.User != nil { + selfID = s.State.User.ID + } + if !wantsDiscordMessage(m, selfID) { + return + } + text := stripDiscordMention(m.Content, selfID) + channelID := m.ChannelID + send := func(content string) error { + _, err := s.ChannelMessageSend(channelID, content) + return err + } + g.dispatch.run(ctx, func() { + g.handleMessage(ctx, m.Author.ID, channelID, text, send) + }) +} + +// wantsDiscordMessage reports whether a message should be processed: skip our own +// messages and other bots; accept direct messages and any message that @mentions +// the bot. selfID is the bot's own user ID (empty until the READY event lands). +func wantsDiscordMessage(m *discordgo.MessageCreate, selfID string) bool { + if m == nil || m.Author == nil || m.Author.Bot { + return false + } + if selfID != "" && m.Author.ID == selfID { + return false + } if m.GuildID == "" { - return true + return true // direct message } for _, u := range m.Mentions { - if u.ID == g.cfg.AppID { + if u != nil && u.ID == selfID { return true } } return false } -// stripMention removes a leading <@id> / <@!id> mention token from content. -func stripDiscordMention(content, appID string) string { - content = strings.TrimSpace(content) - for _, tok := range []string{"<@" + appID + ">", "<@!" + appID + ">"} { - if strings.HasPrefix(content, tok) { - return strings.TrimSpace(content[len(tok):]) +// handleMessage applies the pairing/allowlist policy and forwards authorized +// prompts to the daemon. It is transport-agnostic: send delivers the reply, so +// the policy/forwarding logic is unit-testable without a live Gateway session. +func (g *DiscordGateway) handleMessage(ctx context.Context, senderID, channelID, text string, send func(string) error) { + reply := func(content string) { + if err := send(content); err != nil { + slog.Error("discord reply failed", "channel", channelID, "error", err) } } - return content -} - -func (g *DiscordGateway) handleMessage(ctx context.Context, m discordMessage) { - if m.Author.Bot { - return // ignore other bots / our own echoes - } - if !g.mentionsBot(m) { - return - } - text := stripDiscordMention(m.Content, g.cfg.AppID) - sender := m.Author.ID - if isPair, ok := g.auth.tryPair(sender, text); isPair { + if isPair, ok := g.auth.tryPair(senderID, text); isPair { if ok { - _ = g.postMessage(ctx, m.ChannelID, "Paired. You can now chat with hawk.") + reply("Paired. You can now chat with hawk.") } else { - _ = g.postMessage(ctx, m.ChannelID, "Pairing failed: invalid code.") + reply("Pairing failed: invalid code.") } return } - if !g.auth.allowed(sender) { - _ = g.postMessage(ctx, m.ChannelID, "Unauthorized. Send /pair to authorize.") + if !g.auth.allowed(senderID) { + reply("Unauthorized. Send /pair to authorize.") return } - reply, err := forwardToHawk(ctx, g.client, g.daemonAddr, g.apiKey, text) - if err != nil { - reply = fmt.Sprintf("Error: %v", err) - } - if len(reply) > 1900 { // Discord 2000-char limit, leave headroom - reply = reply[:1900] + "\n\n... (truncated)" - } - _ = g.postMessage(ctx, m.ChannelID, reply) -} - -// postMessage sends a message to a Discord channel (in-thread reply context). -func (g *DiscordGateway) postMessage(ctx context.Context, channelID, content string) error { - if g.cfg.Token == "" || channelID == "" { - return nil - } - body, _ := json.Marshal(map[string]string{"content": content}) - apiURL := fmt.Sprintf("%s/channels/%s/messages", discordAPIBase, channelID) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) + resp, err := forwardToHawk(ctx, g.client, g.daemonAddr, g.apiKey, text) if err != nil { - return err + resp = fmt.Sprintf("Error: %v", err) } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bot "+g.cfg.Token) - resp, err := g.client.Do(req) - if err != nil { - return err + if len(resp) > 1900 { // Discord 2000-char limit, leave headroom + resp = resp[:1900] + "\n\n... (truncated)" } - _ = resp.Body.Close() - return nil + reply(resp) } -// fetchMessagesREST retrieves channel messages via the Discord REST API. -func (g *DiscordGateway) fetchMessagesREST(ctx context.Context, channelID, afterID string) ([]discordMessage, error) { - apiURL := fmt.Sprintf("%s/channels/%s/messages?limit=20", discordAPIBase, channelID) - if afterID != "" { - apiURL += "&after=" + afterID - } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL, nil) - if err != nil { - return nil, err - } - req.Header.Set("Authorization", "Bot "+g.cfg.Token) - resp, err := g.client.Do(req) - if err != nil { - return nil, err - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - data, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) - return nil, fmt.Errorf("discord messages: HTTP %d: %s", resp.StatusCode, string(data)) +// stripDiscordMention removes a leading <@id> / <@!id> mention token from content. +func stripDiscordMention(content, selfID string) string { + content = strings.TrimSpace(content) + if selfID == "" { + return content } - var msgs []discordMessage - if err := json.NewDecoder(resp.Body).Decode(&msgs); err != nil { - return nil, err + for _, tok := range []string{"<@" + selfID + ">", "<@!" + selfID + ">"} { + if strings.HasPrefix(content, tok) { + return strings.TrimSpace(content[len(tok):]) + } } - return msgs, nil + return content } diff --git a/internal/daemon/gateway.go b/internal/daemon/gateway.go index 8dcd004a..739537a1 100644 --- a/internal/daemon/gateway.go +++ b/internal/daemon/gateway.go @@ -77,12 +77,13 @@ type TelegramConfig struct { AllowList []string `json:"allow_list,omitempty"` } -// DiscordConfig configures the Discord gateway. +// DiscordConfig configures the Discord Gateway (WebSocket) bridge. The bot's own +// user ID is learned from the Gateway READY event, so no application/channel IDs +// need to be configured; the bot responds to DMs and to @mentions in any guild +// channel it can see. type DiscordConfig struct { - Enabled bool `json:"enabled,omitempty"` - Token string `json:"token,omitempty"` // bot token - // AppID is the application/bot user ID, used to detect @mentions. - AppID string `json:"app_id,omitempty"` + Enabled bool `json:"enabled,omitempty"` + Token string `json:"token,omitempty"` // bot token PairingCode string `json:"pairing_code,omitempty"` AllowList []string `json:"allow_list,omitempty"` } @@ -157,6 +158,48 @@ func (a *authorizer) tryPair(senderID, text string) (isPair, ok bool) { return true, true } +// asyncDispatcher runs message handlers in bounded goroutines and tracks them so +// a gateway can drain in-flight work on shutdown. It replaces ad-hoc `go f()` +// spawns that had no concurrency limit and were not tied to the gateway lifecycle. +type asyncDispatcher struct { + sem chan struct{} + wg sync.WaitGroup +} + +// newAsyncDispatcher builds a dispatcher allowing at most max concurrent handlers. +func newAsyncDispatcher(max int) *asyncDispatcher { + if max <= 0 { + max = 8 + } + return &asyncDispatcher{sem: make(chan struct{}, max)} +} + +// run executes fn in a bounded goroutine unless ctx is already done. It blocks +// only to acquire a concurrency slot (respecting ctx cancellation), then runs fn +// asynchronously. Handlers are tracked so wait can drain them on shutdown. +func (d *asyncDispatcher) run(ctx context.Context, fn func()) { + select { + case <-ctx.Done(): + return + case d.sem <- struct{}{}: + } + d.wg.Add(1) + go func() { + defer d.wg.Done() + defer func() { <-d.sem }() + fn() + }() +} + +// wait blocks until all dispatched handlers have finished. +func (d *asyncDispatcher) wait() { d.wg.Wait() } + +// daemonURLSetter is implemented by poll/forward gateways whose forward target is +// only known once the daemon's real listening address resolves (port 0 at Start). +type daemonURLSetter interface { + setDaemonURL(url string) +} + // gatewayManager owns the lifecycle of all enabled gateways for a daemon. type gatewayManager struct { mu sync.Mutex @@ -190,13 +233,8 @@ func (m *gatewayManager) setDaemonURL(url string) { m.mu.Lock() defer m.mu.Unlock() for _, g := range m.gateways { - switch gw := g.(type) { - case *TelegramGateway: - gw.DaemonAddr = url - case *DiscordGateway: - gw.daemonAddr = url - case *SlackGateway: - gw.daemonAddr = url + if s, ok := g.(daemonURLSetter); ok { + s.setDaemonURL(url) } } } diff --git a/internal/daemon/gateway_test.go b/internal/daemon/gateway_test.go index 3d81cc82..ffe340b7 100644 --- a/internal/daemon/gateway_test.go +++ b/internal/daemon/gateway_test.go @@ -7,14 +7,35 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" + "errors" "io" + "net" "net/http" "net/http/httptest" + "os" "strconv" + "strings" "testing" "time" + + "github.com/bwmarrin/discordgo" ) +func newIPv4GatewayServer(t *testing.T, h http.Handler) *httptest.Server { + t.Helper() + ln, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + if errors.Is(err, os.ErrPermission) || strings.Contains(err.Error(), "operation not permitted") { + t.Skipf("sandbox does not allow local listeners: %v", err) + } + t.Fatalf("listen tcp4: %v", err) + } + srv := httptest.NewUnstartedServer(h) + srv.Listener = ln + srv.Start() + return srv +} + // Compile-time assertions that all three gateways satisfy Gateway. var ( _ Gateway = (*TelegramGateway)(nil) @@ -192,7 +213,7 @@ func TestAuthorizer_PairingAndAllowlist(t *testing.T) { func TestForwardToHawk(t *testing.T) { var gotAuth, gotPrompt string - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ts := newIPv4GatewayServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { gotAuth = r.Header.Get("Authorization") var body struct { Prompt string `json:"prompt"` @@ -346,59 +367,85 @@ func TestSlackGateway_RejectsBadSignature(t *testing.T) { } func TestDiscordGateway_HandleMessage_FlowsThroughAllowlist(t *testing.T) { - var posted []string - // Mock Discord REST for postMessage. - mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var body struct { - Content string `json:"content"` - } - _ = decodeForTest(r, &body) - posted = append(posted, body.Content) - w.WriteHeader(http.StatusOK) + hawk := newIPv4GatewayServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + writeJSON(w, http.StatusOK, ChatResponse{Response: "hawk-reply"}) })) - defer mock.Close() - old := discordAPIBase - discordAPIBase = mock.URL - defer func() { discordAPIBase = old }() - - g := newDiscordGateway(DiscordConfig{Token: "bot", AppID: "BOT", PairingCode: "code"}, "http://x", "") - - // Unauthorized non-pair mention -> "Unauthorized" reply. - g.handleMessage(context.Background(), discordMessage{ - ID: "1", ChannelID: "C", Content: "<@BOT> hello", GuildID: "G", - Author: discordUser{ID: "u1"}, - Mentions: []discordUser{{ID: "BOT"}}, - }) - // Pair, then it should be authorized. - g.handleMessage(context.Background(), discordMessage{ - ID: "2", ChannelID: "C", Content: "<@BOT> /pair code", GuildID: "G", - Author: discordUser{ID: "u1"}, - Mentions: []discordUser{{ID: "BOT"}}, - }) - - if len(posted) < 2 { - t.Fatalf("expected at least 2 posts, got %d: %v", len(posted), posted) + defer hawk.Close() + + g := newDiscordGateway(DiscordConfig{Token: "bot", PairingCode: "code"}, hawk.URL, "") + var sent []string + send := func(s string) error { sent = append(sent, s); return nil } + + // Unauthorized non-pair message -> "Unauthorized", no hawk call. + g.handleMessage(context.Background(), "u1", "C", "hello", send) + // Wrong pairing code -> failure. + g.handleMessage(context.Background(), "u1", "C", "/pair nope", send) + // Correct pairing code -> paired. + g.handleMessage(context.Background(), "u1", "C", "/pair code", send) + // Authorized -> forwarded to hawk. + g.handleMessage(context.Background(), "u1", "C", "do it", send) + + if len(sent) != 4 { + t.Fatalf("expected 4 sends, got %d: %v", len(sent), sent) + } + if !strings.Contains(sent[0], "Unauthorized") { + t.Errorf("sent[0]=%q want Unauthorized", sent[0]) } - if posted[0] == "" || posted[1] == "" { - t.Errorf("unexpected empty replies: %v", posted) + if !strings.Contains(sent[1], "failed") { + t.Errorf("sent[1]=%q want pairing failure", sent[1]) + } + if !strings.Contains(sent[2], "Paired") { + t.Errorf("sent[2]=%q want Paired", sent[2]) + } + if sent[3] != "hawk-reply" { + t.Errorf("sent[3]=%q want hawk-reply", sent[3]) } if !g.auth.allowed("u1") { t.Errorf("u1 should be allowed after pairing") } } -func TestDiscordGateway_IgnoresBotsAndNonMentions(t *testing.T) { - g := newDiscordGateway(DiscordConfig{Token: "bot", AppID: "BOT"}, "http://x", "") - // Bot author -> ignored (no panic, no post attempt path beyond guard). - g.handleMessage(context.Background(), discordMessage{ - ID: "1", ChannelID: "C", GuildID: "G", Author: discordUser{ID: "x", Bot: true}, - }) - // Guild message without mention -> ignored. - g.handleMessage(context.Background(), discordMessage{ - ID: "2", ChannelID: "C", GuildID: "G", Content: "no mention", Author: discordUser{ID: "y"}, - }) - if g.auth.allowed("x") || g.auth.allowed("y") { - t.Errorf("no sender should be allowed") +func TestWantsDiscordMessage(t *testing.T) { + const self = "BOT" + mc := func(authorID string, bot bool, guildID string, mentionIDs ...string) *discordgo.MessageCreate { + mentions := make([]*discordgo.User, 0, len(mentionIDs)) + for _, id := range mentionIDs { + mentions = append(mentions, &discordgo.User{ID: id}) + } + return &discordgo.MessageCreate{Message: &discordgo.Message{ + Author: &discordgo.User{ID: authorID, Bot: bot}, + GuildID: guildID, + Mentions: mentions, + }} + } + + cases := []struct { + name string + msg *discordgo.MessageCreate + want bool + }{ + {"dm accepted", mc("u1", false, ""), true}, + {"guild mention accepted", mc("u1", false, "G", self), true}, + {"guild no mention ignored", mc("u1", false, "G"), false}, + {"bot author ignored", mc("x", true, "G", self), false}, + {"self ignored", mc(self, false, ""), false}, + } + for _, tc := range cases { + if got := wantsDiscordMessage(tc.msg, self); got != tc.want { + t.Errorf("%s: wantsDiscordMessage=%v want %v", tc.name, got, tc.want) + } + } +} + +func TestStripDiscordMention(t *testing.T) { + for _, tc := range []struct{ in, want string }{ + {"<@BOT> hello", "hello"}, + {"<@!BOT> spaced ", "spaced"}, + {"no mention here", "no mention here"}, + } { + if got := stripDiscordMention(tc.in, "BOT"); got != tc.want { + t.Errorf("stripDiscordMention(%q)=%q want %q", tc.in, got, tc.want) + } } } diff --git a/internal/daemon/slack.go b/internal/daemon/slack.go index c52a68ee..1ac7993d 100644 --- a/internal/daemon/slack.go +++ b/internal/daemon/slack.go @@ -10,9 +10,11 @@ import ( "encoding/json" "fmt" "io" + "log/slog" "net/http" "strconv" "strings" + "sync" "time" ) @@ -39,6 +41,13 @@ type SlackGateway struct { auth *authorizer server *Server path string + dispatch *asyncDispatcher + + // mu guards runCtx, which is set to the Start context so webhook-triggered + // handlers run under (and are cancelled by) the gateway lifecycle rather than + // a detached context.Background(). + mu sync.Mutex + runCtx context.Context // now is overridable for deterministic signature-skew tests. now func() time.Time @@ -58,6 +67,8 @@ func newSlackGateway(cfg SlackConfig, daemonAddr, apiKey string, s *Server) *Sla auth: newAuthorizer(cfg.PairingCode, cfg.AllowList), server: s, path: path, + dispatch: newAsyncDispatcher(8), + runCtx: context.Background(), now: time.Now, } if s != nil { @@ -69,16 +80,33 @@ func newSlackGateway(cfg SlackConfig, daemonAddr, apiKey string, s *Server) *Sla // Name implements Gateway. func (g *SlackGateway) Name() string { return "slack" } +// setDaemonURL implements daemonURLSetter. Slack forwards to the daemon via +// forwardToHawk, so it needs the resolved address even though it replies via the +// Slack Web API. +func (g *SlackGateway) setDaemonURL(url string) { g.daemonAddr = url } + // Start implements Gateway. The webhook is already registered at construction; we -// simply block until the daemon shuts the gateway down. +// record the lifecycle context (so handlers are cancelled on shutdown), block +// until the daemon stops the gateway, then drain in-flight handlers. func (g *SlackGateway) Start(ctx context.Context) error { + g.mu.Lock() + g.runCtx = ctx + g.mu.Unlock() <-ctx.Done() + g.dispatch.wait() return ctx.Err() } // Stop implements Gateway. func (g *SlackGateway) Stop() error { return nil } +// handlerContext returns the current lifecycle context for webhook-triggered work. +func (g *SlackGateway) handlerContext() context.Context { + g.mu.Lock() + defer g.mu.Unlock() + return g.runCtx +} + // slackEnvelope is the outer Events API payload. type slackEnvelope struct { Type string `json:"type"` @@ -124,7 +152,9 @@ func (g *SlackGateway) handleEvents(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) if env.Type == "event_callback" && env.Event.Type == "app_mention" && env.Event.BotID == "" { - go g.handleMention(context.Background(), env.Event) + ctx := g.handlerContext() + ev := env.Event + g.dispatch.run(ctx, func() { g.handleMention(ctx, ev) }) } } @@ -137,14 +167,14 @@ func (g *SlackGateway) handleMention(ctx context.Context, ev slackEventInner) { if isPair, ok := g.auth.tryPair(ev.User, text); isPair { if ok { - _ = g.postMessage(ctx, ev.Channel, threadTS, "Paired. You can now chat with hawk.") + g.reply(ctx, ev.Channel, threadTS, "Paired. You can now chat with hawk.") } else { - _ = g.postMessage(ctx, ev.Channel, threadTS, "Pairing failed: invalid code.") + g.reply(ctx, ev.Channel, threadTS, "Pairing failed: invalid code.") } return } if !g.auth.allowed(ev.User) { - _ = g.postMessage(ctx, ev.Channel, threadTS, "Unauthorized. Send /pair to authorize.") + g.reply(ctx, ev.Channel, threadTS, "Unauthorized. Send /pair to authorize.") return } @@ -152,7 +182,14 @@ func (g *SlackGateway) handleMention(ctx context.Context, ev slackEventInner) { if err != nil { reply = fmt.Sprintf("Error: %v", err) } - _ = g.postMessage(ctx, ev.Channel, threadTS, reply) + g.reply(ctx, ev.Channel, threadTS, reply) +} + +// reply posts a threaded message and logs (rather than swallows) any failure. +func (g *SlackGateway) reply(ctx context.Context, channel, threadTS, text string) { + if err := g.postMessage(ctx, channel, threadTS, text); err != nil { + slog.Error("slack postMessage failed", "channel", channel, "error", err) + } } // stripSlackMention removes a leading <@U...> mention from the message text. @@ -218,6 +255,22 @@ func (g *SlackGateway) postMessage(ctx context.Context, channel, threadTS, text if err != nil { return err } - _ = resp.Body.Close() + defer func() { _ = resp.Body.Close() }() + // Slack returns HTTP 200 with {"ok":false,"error":"..."} on logical failures + // (bad token, not_in_channel, ...), so the status code alone is insufficient. + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + var result struct { + OK bool `json:"ok"` + Error string `json:"error"` + } + if err := json.Unmarshal(body, &result); err != nil { + if resp.StatusCode/100 != 2 { + return fmt.Errorf("slack chat.postMessage: HTTP %d", resp.StatusCode) + } + return nil + } + if !result.OK { + return fmt.Errorf("slack chat.postMessage: %s", result.Error) + } return nil } diff --git a/internal/daemon/telegram.go b/internal/daemon/telegram.go index 4f0d70ef..d9ff1361 100644 --- a/internal/daemon/telegram.go +++ b/internal/daemon/telegram.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "io" + "log/slog" "net/http" "net/url" "strings" @@ -22,9 +23,11 @@ type TelegramGateway struct { // apiKey, when set, is sent as a Bearer token on forwarded daemon requests. apiKey string - // auth, when set, enforces a pairing-code/allowlist policy. When nil, all - // senders are permitted (legacy behaviour for the bare constructor). - auth *authorizer + // auth enforces a pairing-code/allowlist policy. It is always non-nil; the + // bare constructor seeds an empty (fail-closed) authorizer that refuses all + // senders until a pairing code or allowlist is configured. + auth *authorizer + dispatch *asyncDispatcher } // TelegramUpdate represents an incoming Telegram message. @@ -45,12 +48,16 @@ type TelegramMessage struct { } `json:"from"` } -// NewTelegramGateway creates a gateway with the given bot token. +// NewTelegramGateway creates a gateway with the given bot token. The authorizer +// is seeded empty and therefore fails closed: no sender is authorized until a +// pairing code or allowlist is supplied (see newTelegramGatewayFromConfig). func NewTelegramGateway(token, daemonAddr string) *TelegramGateway { return &TelegramGateway{ Token: token, DaemonAddr: daemonAddr, client: &http.Client{Timeout: 30 * time.Second}, + auth: newAuthorizer("", nil), + dispatch: newAsyncDispatcher(8), } } @@ -66,6 +73,9 @@ func newTelegramGatewayFromConfig(cfg TelegramConfig, daemonAddr, apiKey string) // Name implements Gateway. func (tg *TelegramGateway) Name() string { return "telegram" } +// setDaemonURL implements daemonURLSetter. +func (tg *TelegramGateway) setDaemonURL(url string) { tg.DaemonAddr = url } + // Start implements Gateway by delegating to the long-poll Run loop. func (tg *TelegramGateway) Start(ctx context.Context) error { return tg.Run(ctx) } @@ -82,27 +92,51 @@ func telegramSenderID(msg *TelegramMessage) string { return fmt.Sprintf("%d", msg.Chat.ID) } -// Run starts the long-polling loop. Blocks until context is cancelled. +// Run starts the long-polling loop. Blocks until context is cancelled, then +// drains any in-flight message handlers before returning. func (tg *TelegramGateway) Run(ctx context.Context) error { + const ( + baseBackoff = time.Second + maxBackoff = 30 * time.Second + ) + backoff := baseBackoff for { select { case <-ctx.Done(): + tg.dispatch.wait() return ctx.Err() default: } updates, err := tg.getUpdates(ctx) if err != nil { - time.Sleep(5 * time.Second) + if ctx.Err() != nil { + tg.dispatch.wait() + return ctx.Err() + } + slog.Warn("telegram getUpdates failed; backing off", "backoff", backoff, "error", err) + // Context-aware sleep with exponential backoff so a bad token or + // outage does not hot-loop and Stop is honored promptly. + select { + case <-ctx.Done(): + tg.dispatch.wait() + return ctx.Err() + case <-time.After(backoff): + } + if backoff *= 2; backoff > maxBackoff { + backoff = maxBackoff + } continue } + backoff = baseBackoff for _, u := range updates { tg.offset = u.UpdateID + 1 if u.Message == nil || u.Message.Text == "" { continue } - go tg.handleMessage(ctx, u.Message) + msg := u.Message + tg.dispatch.run(ctx, func() { tg.handleMessage(ctx, msg) }) } } } @@ -120,30 +154,32 @@ func (tg *TelegramGateway) getUpdates(ctx context.Context) ([]TelegramUpdate, er defer func() { _ = resp.Body.Close() }() var result struct { - OK bool `json:"ok"` - Result []TelegramUpdate `json:"result"` + OK bool `json:"ok"` + Description string `json:"description"` + Result []TelegramUpdate `json:"result"` } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { return nil, err } + if !result.OK { + return nil, fmt.Errorf("telegram getUpdates: %s", result.Description) + } return result.Result, nil } func (tg *TelegramGateway) handleMessage(ctx context.Context, msg *TelegramMessage) { - if tg.auth != nil { - sender := telegramSenderID(msg) - if isPair, ok := tg.auth.tryPair(sender, msg.Text); isPair { - if ok { - _ = tg.sendMessage(ctx, msg.Chat.ID, "Paired. You can now chat with hawk.") - } else { - _ = tg.sendMessage(ctx, msg.Chat.ID, "Pairing failed: invalid code.") - } - return - } - if !tg.auth.allowed(sender) { - _ = tg.sendMessage(ctx, msg.Chat.ID, "Unauthorized. Send /pair to authorize.") - return + sender := telegramSenderID(msg) + if isPair, ok := tg.auth.tryPair(sender, msg.Text); isPair { + if ok { + tg.reply(ctx, msg.Chat.ID, "Paired. You can now chat with hawk.") + } else { + tg.reply(ctx, msg.Chat.ID, "Pairing failed: invalid code.") } + return + } + if !tg.auth.allowed(sender) { + tg.reply(ctx, msg.Chat.ID, "Unauthorized. Send /pair to authorize.") + return } // Forward to hawk daemon @@ -157,7 +193,14 @@ func (tg *TelegramGateway) handleMessage(ctx context.Context, msg *TelegramMessa response = string([]rune(response)[:4000]) + "\n\n... (truncated)" } - _ = tg.sendMessage(ctx, msg.Chat.ID, response) + tg.reply(ctx, msg.Chat.ID, response) +} + +// reply sends text and logs (rather than swallows) any delivery failure. +func (tg *TelegramGateway) reply(ctx context.Context, chatID int64, text string) { + if err := tg.sendMessage(ctx, chatID, text); err != nil { + slog.Error("telegram sendMessage failed", "chat", chatID, "error", err) + } } func (tg *TelegramGateway) forwardToHawk(ctx context.Context, prompt string) (string, error) { @@ -211,6 +254,21 @@ func (tg *TelegramGateway) sendMessage(ctx context.Context, chatID int64, text s if err != nil { return err } - _ = resp.Body.Close() + defer func() { _ = resp.Body.Close() }() + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + var result struct { + OK bool `json:"ok"` + Description string `json:"description"` + } + if err := json.Unmarshal(body, &result); err != nil { + // Non-JSON body: fall back to the HTTP status for a useful error. + if resp.StatusCode/100 != 2 { + return fmt.Errorf("telegram sendMessage: HTTP %d", resp.StatusCode) + } + return nil + } + if !result.OK { + return fmt.Errorf("telegram sendMessage: %s", result.Description) + } return nil } diff --git a/internal/daemon/telegram_test.go b/internal/daemon/telegram_test.go index d9fbf7e8..182e8c19 100644 --- a/internal/daemon/telegram_test.go +++ b/internal/daemon/telegram_test.go @@ -2,20 +2,38 @@ package daemon import ( "context" + "errors" "io" + "net" "net/http" "net/http/httptest" + "os" "strings" "sync" "testing" ) +func newIPv4TelegramServer(t *testing.T, h http.Handler) *httptest.Server { + t.Helper() + ln, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + if errors.Is(err, os.ErrPermission) || strings.Contains(err.Error(), "operation not permitted") { + t.Skipf("sandbox does not allow local listeners: %v", err) + } + t.Fatalf("listen tcp4: %v", err) + } + srv := httptest.NewUnstartedServer(h) + srv.Listener = ln + srv.Start() + return srv +} + // telegramMockAPI captures sendMessage calls and serves a fake Telegram + hawk // endpoint. The Telegram bot API base is hardcoded in telegram.go, so we instead // drive handleMessage directly and observe outbound sends via a mock daemon. func TestTelegram_HandleMessage_Authorization(t *testing.T) { // Mock hawk daemon /v1/chat. - hawk := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hawk := newIPv4TelegramServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, ChatResponse{Response: "hawk-reply"}) })) defer hawk.Close() @@ -92,9 +110,12 @@ func TestTelegramSenderID(t *testing.T) { } } -func TestTelegram_NilAuthAllowsAll(t *testing.T) { - // The bare constructor leaves auth nil: legacy behaviour, all senders pass. - hawk := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +func TestTelegram_BareConstructorFailsClosed(t *testing.T) { + // The bare constructor seeds an empty authorizer, so it must refuse all + // senders (no pairing code / allowlist) rather than forwarding to hawk. + hawkCalled := false + hawk := newIPv4TelegramServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hawkCalled = true writeJSON(w, http.StatusOK, ChatResponse{Response: "ok"}) })) defer hawk.Close() @@ -113,8 +134,11 @@ func TestTelegram_NilAuthAllowsAll(t *testing.T) { m := &TelegramMessage{Text: "hi"} m.Chat.ID = 1 tg.handleMessage(context.Background(), m) - if len(sends) != 1 || sends[0] != "ok" { - t.Fatalf("expected forwarded reply 'ok', got %v", sends) + if len(sends) != 1 || !strings.Contains(sends[0], "Unauthorized") { + t.Fatalf("expected an Unauthorized reply, got %v", sends) + } + if hawkCalled { + t.Fatal("bare constructor must not forward unauthorized messages to hawk") } }