Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions go.mod

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions go.sum

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions go.work.sum
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.30.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
Expand Down
248 changes: 106 additions & 142 deletions internal/daemon/discord.go
Original file line number Diff line number Diff line change
@@ -1,61 +1,32 @@
package daemon

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
"time"

"github.com/bwmarrin/discordgo"
)

// discordAPIBase is the Discord REST API root. It is a package var so tests can
// point it at a mock server.
var discordAPIBase = "https://discord.com/api/v10"

// DiscordGateway bridges hawk to Discord. Because no WebSocket library is present
// in go.mod, the gateway uses a REST long-poll-equivalent strategy: it does not
// open the Discord Gateway socket but instead relies on a poll over the bot's
// accessible channels for new @mentions/DMs. This keeps the bridge dependency-free
// while remaining bidirectional (it posts replies in-thread via the REST API).
//
// The polling transport is intentionally minimal and pluggable: fetchMessages is
// a field so a real Gateway-WebSocket implementation (or a test) can be swapped
// in without changing the message-handling logic.
// DiscordGateway bridges hawk to Discord via the official Gateway (WebSocket),
// using bwmarrin/discordgo. Unlike a REST-poll bridge it receives message events
// in real time — guild @mentions and direct messages — with reconnection,
// heartbeats, and rate limiting handled by the library. Authorized prompts are
// forwarded to the daemon's /v1/chat endpoint and the reply is posted back to the
// originating channel.
type DiscordGateway struct {
cfg DiscordConfig
daemonAddr string
apiKey string
client *http.Client
client *http.Client // for forwardToHawk
auth *authorizer
dispatch *asyncDispatcher

// pollChannels lists channel IDs the bot watches. When empty the gateway
// discovers DM channels lazily as it sees them via fetchMessages.
pollChannels []string
// lastSeen maps channelID -> last processed message ID for poll dedup.
lastSeen map[string]string

// fetchMessages retrieves new messages for a channel since the given message
// ID. Overridable for tests. Default uses the Discord REST API.
fetchMessages func(ctx context.Context, channelID, afterID string) ([]discordMessage, error)
}

// discordMessage is the subset of a Discord message object we use.
type discordMessage struct {
ID string `json:"id"`
ChannelID string `json:"channel_id"`
Content string `json:"content"`
Author discordUser `json:"author"`
Mentions []discordUser `json:"mentions"`
GuildID string `json:"guild_id,omitempty"`
}

type discordUser struct {
ID string `json:"id"`
Username string `json:"username"`
Bot bool `json:"bot,omitempty"`
// openSession is overridable in tests; production opens a real Gateway socket.
openSession func(ctx context.Context) (*discordgo.Session, error)
}

// newDiscordGateway builds a Discord gateway from config.
Expand All @@ -66,147 +37,140 @@ func newDiscordGateway(cfg DiscordConfig, daemonAddr, apiKey string) *DiscordGat
apiKey: apiKey,
client: &http.Client{Timeout: 30 * time.Second},
auth: newAuthorizer(cfg.PairingCode, cfg.AllowList),
lastSeen: make(map[string]string),
dispatch: newAsyncDispatcher(8),
}
g.fetchMessages = g.fetchMessagesREST
g.openSession = g.openGatewaySession
return g
}

// Name implements Gateway by returning the fixed identifier of this bridge,
// "discord".
func (g *DiscordGateway) Name() string {
	const gatewayName = "discord"
	return gatewayName
}

// Stop implements Gateway. Polling is driven by the Start context.
// setDaemonURL implements daemonURLSetter: it overwrites the daemon address
// stored on the gateway with url.
func (g *DiscordGateway) setDaemonURL(url string) {
	g.daemonAddr = url
}

// Stop implements Gateway. It does nothing and always returns nil: the Gateway
// connection's lifetime is tied to the context passed to Start.
func (g *DiscordGateway) Stop() error {
	return nil
}

// Start implements Gateway: poll watched channels until ctx is cancelled.
// Start opens the Discord Gateway, registers the message handler, and blocks
// until ctx is cancelled, then drains in-flight handlers and closes the socket.
func (g *DiscordGateway) Start(ctx context.Context) error {
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
g.pollOnce(ctx)
}
s, err := g.openSession(ctx)
if err != nil {
return err
}
s.AddHandler(func(sess *discordgo.Session, m *discordgo.MessageCreate) {
g.onMessageCreate(ctx, sess, m)
})
if err := s.Open(); err != nil {
return fmt.Errorf("discord gateway open: %w", err)
}
defer func() { _ = s.Close() }()

<-ctx.Done()
g.dispatch.wait()
return ctx.Err()
}

func (g *DiscordGateway) pollOnce(ctx context.Context) {
for _, ch := range g.pollChannels {
msgs, err := g.fetchMessages(ctx, ch, g.lastSeen[ch])
if err != nil {
continue
}
for _, m := range msgs {
g.lastSeen[m.ChannelID] = m.ID
g.handleMessage(ctx, m)
}
// openGatewaySession builds a discordgo session authenticated with the
// configured bot token and selects the intents the bridge subscribes to:
// guild messages, message content (a privileged intent that must be enabled
// in the Discord developer portal), and direct messages.
//
// The context argument is unused. The returned session has not been opened;
// the caller is responsible for calling Open on it.
func (g *DiscordGateway) openGatewaySession(_ context.Context) (*discordgo.Session, error) {
	session, err := discordgo.New("Bot " + g.cfg.Token)
	if err != nil {
		return nil, fmt.Errorf("discord session: %w", err)
	}

	intents := discordgo.IntentsGuildMessages
	intents |= discordgo.IntentMessageContent
	intents |= discordgo.IntentsDirectMessages
	session.Identify.Intents = intents

	return session, nil
}

// mentionsBot reports whether the message is a DM or @mentions the bot user.
func (g *DiscordGateway) mentionsBot(m discordMessage) bool {
// DM: no guild.
// onMessageCreate is the MessageCreate event handler. It resolves the bot's
// own user ID from the session state (left empty when the state or its user is
// not populated), drops events rejected by wantsDiscordMessage, strips the
// bot mention from the content, and hands the prompt to handleMessage through
// the async dispatcher. Replies are sent to the originating channel with the
// session's ChannelMessageSend.
func (g *DiscordGateway) onMessageCreate(ctx context.Context, s *discordgo.Session, m *discordgo.MessageCreate) {
	var botID string
	if state := s.State; state != nil {
		if self := state.User; self != nil {
			botID = self.ID
		}
	}

	if !wantsDiscordMessage(m, botID) {
		return
	}

	target := m.ChannelID
	prompt := stripDiscordMention(m.Content, botID)

	// deliver posts a reply to the channel the message came from and reports
	// only the error; the sent message object is not needed.
	deliver := func(content string) error {
		if _, err := s.ChannelMessageSend(target, content); err != nil {
			return err
		}
		return nil
	}

	g.dispatch.run(ctx, func() {
		g.handleMessage(ctx, m.Author.ID, target, prompt, deliver)
	})
}

// wantsDiscordMessage reports whether a message should be processed: skip our own
// messages and other bots; accept direct messages and any message that @mentions
// the bot. selfID is the bot's own user ID (empty until the READY event lands).
func wantsDiscordMessage(m *discordgo.MessageCreate, selfID string) bool {
if m == nil || m.Author == nil || m.Author.Bot {
return false
}
if selfID != "" && m.Author.ID == selfID {
return false
}
if m.GuildID == "" {
return true
return true // direct message
}
for _, u := range m.Mentions {
if u.ID == g.cfg.AppID {
if u != nil && u.ID == selfID {
return true
}
}
return false
}

// stripMention removes a leading <@id> / <@!id> mention token from content.
func stripDiscordMention(content, appID string) string {
content = strings.TrimSpace(content)
for _, tok := range []string{"<@" + appID + ">", "<@!" + appID + ">"} {
if strings.HasPrefix(content, tok) {
return strings.TrimSpace(content[len(tok):])
// handleMessage applies the pairing/allowlist policy and forwards authorized
// prompts to the daemon. It is transport-agnostic: send delivers the reply, so
// the policy/forwarding logic is unit-testable without a live Gateway session.
func (g *DiscordGateway) handleMessage(ctx context.Context, senderID, channelID, text string, send func(string) error) {
reply := func(content string) {
if err := send(content); err != nil {
slog.Error("discord reply failed", "channel", channelID, "error", err)
}
}
return content
}

func (g *DiscordGateway) handleMessage(ctx context.Context, m discordMessage) {
if m.Author.Bot {
return // ignore other bots / our own echoes
}
if !g.mentionsBot(m) {
return
}
text := stripDiscordMention(m.Content, g.cfg.AppID)
sender := m.Author.ID

if isPair, ok := g.auth.tryPair(sender, text); isPair {
if isPair, ok := g.auth.tryPair(senderID, text); isPair {
if ok {
_ = g.postMessage(ctx, m.ChannelID, "Paired. You can now chat with hawk.")
reply("Paired. You can now chat with hawk.")
} else {
_ = g.postMessage(ctx, m.ChannelID, "Pairing failed: invalid code.")
reply("Pairing failed: invalid code.")
}
return
}
if !g.auth.allowed(sender) {
_ = g.postMessage(ctx, m.ChannelID, "Unauthorized. Send /pair <code> to authorize.")
if !g.auth.allowed(senderID) {
reply("Unauthorized. Send /pair <code> to authorize.")
return
}

reply, err := forwardToHawk(ctx, g.client, g.daemonAddr, g.apiKey, text)
if err != nil {
reply = fmt.Sprintf("Error: %v", err)
}
if len(reply) > 1900 { // Discord 2000-char limit, leave headroom
reply = reply[:1900] + "\n\n... (truncated)"
}
_ = g.postMessage(ctx, m.ChannelID, reply)
}

// postMessage sends a message to a Discord channel (in-thread reply context).
func (g *DiscordGateway) postMessage(ctx context.Context, channelID, content string) error {
if g.cfg.Token == "" || channelID == "" {
return nil
}
body, _ := json.Marshal(map[string]string{"content": content})
apiURL := fmt.Sprintf("%s/channels/%s/messages", discordAPIBase, channelID)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body))
resp, err := forwardToHawk(ctx, g.client, g.daemonAddr, g.apiKey, text)
if err != nil {
return err
resp = fmt.Sprintf("Error: %v", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bot "+g.cfg.Token)
resp, err := g.client.Do(req)
if err != nil {
return err
if len(resp) > 1900 { // Discord 2000-char limit, leave headroom
resp = resp[:1900] + "\n\n... (truncated)"
}
_ = resp.Body.Close()
return nil
reply(resp)
}

// fetchMessagesREST retrieves channel messages via the Discord REST API.
func (g *DiscordGateway) fetchMessagesREST(ctx context.Context, channelID, afterID string) ([]discordMessage, error) {
apiURL := fmt.Sprintf("%s/channels/%s/messages?limit=20", discordAPIBase, channelID)
if afterID != "" {
apiURL += "&after=" + afterID
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bot "+g.cfg.Token)
resp, err := g.client.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
data, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return nil, fmt.Errorf("discord messages: HTTP %d: %s", resp.StatusCode, string(data))
// stripDiscordMention removes a leading <@id> / <@!id> mention token from content.
func stripDiscordMention(content, selfID string) string {
content = strings.TrimSpace(content)
if selfID == "" {
return content
}
var msgs []discordMessage
if err := json.NewDecoder(resp.Body).Decode(&msgs); err != nil {
return nil, err
for _, tok := range []string{"<@" + selfID + ">", "<@!" + selfID + ">"} {
if strings.HasPrefix(content, tok) {
return strings.TrimSpace(content[len(tok):])
}
}
return msgs, nil
return content
}
Loading
Loading