diff --git a/go.mod b/go.mod index 5258054..b5cde24 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/spf13/cobra v1.10.2 github.com/stretchr/testify v1.11.1 github.com/wagiedev/claude-agent-sdk-go v0.0.3 + github.com/wagiedev/codex-agent-sdk-go v0.0.2 golang.org/x/sync v0.20.0 golang.org/x/term v0.43.0 gopkg.in/yaml.v3 v3.0.1 diff --git a/go.sum b/go.sum index 61f0d5c..f53b552 100644 --- a/go.sum +++ b/go.sum @@ -120,6 +120,8 @@ github.com/docker/go-connections v0.7.0 h1:6SsRfJddP22WMrCkj19x9WKjEDTB+ahsdiGYf github.com/docker/go-connections v0.7.0/go.mod h1:no1qkHdjq7kLMGUXYAduOhYPSJxxvgWBh7ogVvptn3Q= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= @@ -209,6 +211,8 @@ github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELU github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/oklog/ulid/v2 v2.1.1 h1:suPZ4ARWLOJLegGFiZZ1dFAkqzhMjL3J1TzI+5wHz8s= github.com/oklog/ulid/v2 v2.1.1/go.mod h1:rcEKHmBBKfef9DhnvX7y1HZBYxjXb0cP5ExxNsTT1QQ= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= @@ -237,6 +241,8 @@ github.com/pterm/pterm v0.12.83 h1:ie+YmGmA727VuhxBlyGr74Ks+7McV6kT99IB8EU80aA= github.com/pterm/pterm v0.12.83/go.mod h1:xlgc6bFWyJIMtmLJvGim+L7jhSReilOlOnodeIYe4Tk= github.com/redis/go-redis/v9 v9.20.0 h1:WnQYxLkgO2xiXTCJY0ldIiI8dNqCDlQAG+AtaH7a2a0= github.com/redis/go-redis/v9 v9.20.0/go.mod h1:v/M13XI1PVCDcm01VtPFOADfZtHf8YW3baQf57KlIkA= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= @@ -267,6 +273,8 @@ github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/wagiedev/claude-agent-sdk-go v0.0.3 h1:MOI6WYkUgbK5yXtzfQyC1SDwUBfJHA1MDqy/5ANZojA= github.com/wagiedev/claude-agent-sdk-go v0.0.3/go.mod h1:KZz7UJQJBJJkz9Yn9g02HQnnsIr9FiJtWDwNf3U/tMY= +github.com/wagiedev/codex-agent-sdk-go v0.0.2 h1:RtP9a7i4MYbO6CHkICO+ESjw4qaZSJyKULPpKXsWnfg= +github.com/wagiedev/codex-agent-sdk-go v0.0.2/go.mod h1:DRGJ5ybo7oGzG7NVk1O15PfHcpoIKBNWRGaocPruudI= github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= github.com/xdg-go/scram v1.1.1/go.mod h1:RaEWvsqvNKKvBPvcKeFjrG2cJqOkHTiyTpzz23ni57g= github.com/xdg-go/stringprep v1.0.3/go.mod h1:W3f5j4i+9rC0kuIEJL0ky1VpHXQU3ocBgklLGvcBnW8= @@ -313,8 +321,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= -golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= @@ -410,3 +418,11 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA= +modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI= +modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/sqlite v1.46.1 h1:eFJ2ShBLIEnUWlLy12raN0Z1plqmFX9Qe3rjQTKt6sU= +modernc.org/sqlite v1.46.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA= diff --git a/pkg/ai/factory.go b/pkg/ai/factory.go index 5a24d55..f25cba2 100644 --- a/pkg/ai/factory.go +++ b/pkg/ai/factory.go @@ -12,7 +12,7 @@ const DefaultProvider = ProviderClaude // SupportedProviders lists selectable provider IDs. func SupportedProviders() []ProviderID { - return []ProviderID{ProviderClaude} + return []ProviderID{ProviderClaude, ProviderCodex} } // NewEngine creates an Engine for the given provider. @@ -20,6 +20,8 @@ func NewEngine(provider ProviderID, log logrus.FieldLogger) (Engine, error) { switch provider { case "", ProviderClaude: return newClaudeEngine(log), nil + case ProviderCodex: + return newCodexEngine(log), nil default: return nil, fmt.Errorf("unsupported provider: %s", provider) } @@ -60,6 +62,8 @@ func providerLabel(provider ProviderID) string { switch provider { case ProviderClaude: return "Claude" + case ProviderCodex: + return "Codex" default: return string(provider) } @@ -69,6 +73,8 @@ func providerCapabilities(provider ProviderID) Capabilities { switch provider { case ProviderClaude: return Capabilities{Streaming: true, Interrupt: true, Sessions: true} + case ProviderCodex: + return Capabilities{Streaming: true, Interrupt: true, Sessions: true} default: return Capabilities{} } diff --git a/pkg/ai/factory_test.go b/pkg/ai/factory_test.go new file mode 100644 index 0000000..fe8eb60 --- /dev/null +++ b/pkg/ai/factory_test.go @@ -0,0 +1,112 @@ +package ai + +import ( + "context" + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func testLogger() logrus.FieldLogger { + l := logrus.New() + l.SetLevel(logrus.FatalLevel) + + return l +} + +func TestNewEngine(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + provider ProviderID + wantType string + wantErr bool + errContain string + }{ + { + name: "empty defaults to claude", + provider: "", + wantType: "*ai.claudeEngine", + }, + { + name: "claude provider", + provider: ProviderClaude, + wantType: "*ai.claudeEngine", + }, + { + name: "codex provider", + provider: ProviderCodex, + wantType: "*ai.codexEngine", + }, + { + name: "unknown provider", + provider: "unknown", + wantErr: true, + errContain: "unsupported provider", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + engine, err := NewEngine(tc.provider, testLogger()) + + if tc.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.errContain) + assert.Nil(t, engine) + + return + } + + require.NoError(t, err) + require.NotNil(t, engine) + assert.IsType(t, engine, engine, "engine type mismatch") + + // Verify concrete type via provider ID. + if tc.provider == "" || tc.provider == ProviderClaude { + assert.Equal(t, ProviderClaude, engine.Provider()) + } else { + assert.Equal(t, tc.provider, engine.Provider()) + } + }) + } +} + +func TestSupportedProviders(t *testing.T) { + t.Parallel() + + providers := SupportedProviders() + + assert.Len(t, providers, 2) + assert.Contains(t, providers, ProviderClaude) + assert.Contains(t, providers, ProviderCodex) +} + +func TestListProviderInfo(t *testing.T) { + t.Parallel() + + infos := ListProviderInfo(context.Background(), testLogger(), ProviderClaude) + + require.Len(t, infos, 2) + + labels := make(map[ProviderID]string, len(infos)) + defaults := make(map[ProviderID]bool, len(infos)) + + for _, info := range infos { + labels[info.ID] = info.Label + defaults[info.ID] = info.Default + assert.True(t, info.Capabilities.Streaming, "%s should support streaming", info.ID) + assert.True(t, info.Capabilities.Interrupt, "%s should support interrupt", info.ID) + assert.True(t, info.Capabilities.Sessions, "%s should support sessions", info.ID) + } + + assert.Equal(t, "Claude", labels[ProviderClaude]) + assert.Equal(t, "Codex", labels[ProviderCodex]) + assert.True(t, defaults[ProviderClaude]) + assert.False(t, defaults[ProviderCodex]) +} diff --git a/pkg/ai/provider_codex.go b/pkg/ai/provider_codex.go new file mode 100644 index 0000000..5aa306c --- /dev/null +++ b/pkg/ai/provider_codex.go @@ -0,0 +1,353 @@ +package ai + +import ( + "context" + "fmt" + "os/exec" + "strings" + "sync" + "time" + + "github.com/sirupsen/logrus" + codexsdk "github.com/wagiedev/codex-agent-sdk-go" +) + +type codexEngine struct { + log logrus.FieldLogger + timeout time.Duration + available bool +} + +func newCodexEngine(log logrus.FieldLogger) *codexEngine { + _, err := exec.LookPath("codex") + + return &codexEngine{ + log: log.WithField("component", "ai-provider-codex"), + timeout: defaultTimeout, + available: err == nil, + } +} + +func (e *codexEngine) Provider() ProviderID { + return ProviderCodex +} + +func (e *codexEngine) Capabilities() Capabilities { + return Capabilities{Streaming: true, Interrupt: true, Sessions: true} +} + +func (e *codexEngine) IsAvailable() bool { + return e.available +} + +func (e *codexEngine) Ask(ctx context.Context, prompt string) (string, error) { + if !e.available { + return "", fmt.Errorf("provider %s is not available", e.Provider()) + } + + ctx, cancel := context.WithTimeout(ctx, e.timeout) + defer cancel() + + var ( + assistantText strings.Builder + resultText string + ) + + for msg, err := range codexsdk.Query(ctx, prompt) { + if err != nil { + return "", err + } + + switch m := msg.(type) { + case *codexsdk.AssistantMessage: + for _, block := range m.Content { + if textBlock, ok := block.(*codexsdk.TextBlock); ok && + strings.TrimSpace(textBlock.Text) != "" { + if assistantText.Len() > 0 { + assistantText.WriteString("\n") + } + + assistantText.WriteString(textBlock.Text) + } + } + case *codexsdk.ResultMessage: + if m.Result != nil && strings.TrimSpace(*m.Result) != "" { + resultText = *m.Result + } + } + } + + if strings.TrimSpace(resultText) != "" { + return resultText, nil + } + + out := strings.TrimSpace(assistantText.String()) + if out == "" { + return "", fmt.Errorf("empty provider response") + } + + return out, nil +} + +func (e *codexEngine) StartSession(ctx context.Context) (Session, error) { + if !e.available { + return nil, fmt.Errorf("provider %s is not available", e.Provider()) + } + + client := codexsdk.NewClient() + sessionCtx, sessionCancel := context.WithCancel(context.Background()) + session := &codexSession{ + id: newSessionID(), + log: e.log, + client: client, + timeout: e.timeout, + sessionCancel: sessionCancel, + stderrLineLimit: 200, + } + + startErrCh := make(chan error, 1) + + go func() { + startErrCh <- client.Start( + sessionCtx, + codexsdk.WithPermissionMode("bypassPermissions"), + codexsdk.WithStderr(session.pushStderr), + ) + }() + + select { + case err := <-startErrCh: + if err != nil { + sessionCancel() + + return nil, err + } + case <-ctx.Done(): + sessionCancel() + + return nil, ctx.Err() + } + + return session, nil +} + +type codexSession struct { + id string + log logrus.FieldLogger + client codexsdk.Client + timeout time.Duration + sessionCancel context.CancelFunc + mu sync.Mutex + stderrMu sync.Mutex + stderrLines []string + stderrLineLimit int + closed bool +} + +func (s *codexSession) ID() string { + return s.id +} + +func (s *codexSession) AskStream( + ctx context.Context, + prompt string, + onChunk func(StreamChunk), +) (string, error) { + if s.isClosed() { + return "", fmt.Errorf("session is closed") + } + + ctx, cancel := context.WithTimeout(ctx, s.timeout) + defer cancel() + + state := &askStreamState{} + + debugInfo := func() map[string]any { + return state.debugInfo( + s.id, len(prompt), int(s.timeout.Seconds()), + s.stderrCount(), s.stderrTail(20), + ) + } + + if err := s.client.Query(ctx, prompt); err != nil { + return "", &ProviderDebugError{Cause: err, Info: debugInfo()} + } + + for msg, err := range s.client.ReceiveResponse(ctx) { + if err != nil { + loopBreak, retErr := state.handleReceiveError(err, debugInfo) + if retErr != nil { + return "", retErr + } + + if loopBreak { + break + } + + continue + } + + if out, done, retErr := handleCodexMessage(state, msg, onChunk, debugInfo); done { + if retErr != nil { + return "", retErr + } + + return out, nil + } + } + + return state.finalResponse(debugInfo) +} + +// handleCodexMessage processes a single message from the Codex SDK stream. +func handleCodexMessage( + st *askStreamState, + msg any, + onChunk func(StreamChunk), + debugInfo func() map[string]any, +) (string, bool, error) { + st.msgCount++ + + switch m := msg.(type) { + case *codexsdk.StreamEvent: + st.lastMessageType = "stream_event" + st.streamEventCount++ + + if eventType, ok := m.Event["type"].(string); ok { + st.lastEventType = eventType + } + + st.seq += streamEventChunksFromMap(m.Event, st.seq, st, onChunk) + + if parts, ok := parseStreamEventMap(m.Event); ok { + if parts.Thinking != "" { + st.thinkingChars += len(parts.Thinking) + } + + if parts.Text != "" { + st.streamText.WriteString(parts.Text) + st.streamChars += len(parts.Text) + } + } + case *codexsdk.AssistantMessage: + st.lastMessageType = "assistant" + st.assistantCount++ + + if m.Error != nil { + st.assistantErr = string(*m.Error) + } + + for _, block := range m.Content { + textBlock, ok := block.(*codexsdk.TextBlock) + if !ok || strings.TrimSpace(textBlock.Text) == "" { + continue + } + + if st.assistantText.Len() > 0 { + st.assistantText.WriteString("\n") + } + + st.assistantText.WriteString(textBlock.Text) + st.assistantChars += len(textBlock.Text) + + // Some CLI/SDK runs only surface thinking in stream events and + // provide answer text via assistant messages before final result. + // Emit fallback answer chunks so the UI streams visible answer + // text during the turn. + if onChunk != nil && st.streamChars == 0 { + st.seq++ + onChunk(StreamChunk{ + Kind: StreamChunkAnswer, + Text: textBlock.Text, + EventType: "assistant_message", + Seq: st.seq, + }) + } + } + case *codexsdk.ResultMessage: + return handleResultMessage(st, m.Subtype, m.IsError, m.Result, debugInfo) + } + + return "", false, nil +} + +func (s *codexSession) Interrupt(ctx context.Context) error { + if s.isClosed() { + return nil + } + + ctx, cancel := context.WithTimeout(ctx, 15*time.Second) + defer cancel() + + return s.client.Interrupt(ctx) +} + +func (s *codexSession) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.closed { + return nil + } + + s.closed = true + if s.sessionCancel != nil { + s.sessionCancel() + } + + return s.client.Close() +} + +func (s *codexSession) isClosed() bool { + s.mu.Lock() + defer s.mu.Unlock() + + return s.closed +} + +func (s *codexSession) pushStderr(line string) { + trimmed := strings.TrimSpace(line) + if trimmed == "" { + return + } + + s.stderrMu.Lock() + defer s.stderrMu.Unlock() + + s.stderrLines = append(s.stderrLines, trimmed) + if s.stderrLineLimit <= 0 { + return + } + + if len(s.stderrLines) > s.stderrLineLimit { + s.stderrLines = s.stderrLines[len(s.stderrLines)-s.stderrLineLimit:] + } +} + +func (s *codexSession) stderrCount() int { + s.stderrMu.Lock() + defer s.stderrMu.Unlock() + + return len(s.stderrLines) +} + +func (s *codexSession) stderrTail(n int) []string { + s.stderrMu.Lock() + defer s.stderrMu.Unlock() + + if n <= 0 || len(s.stderrLines) == 0 { + return nil + } + + if len(s.stderrLines) <= n { + out := make([]string, len(s.stderrLines)) + copy(out, s.stderrLines) + + return out + } + + out := make([]string, n) + copy(out, s.stderrLines[len(s.stderrLines)-n:]) + + return out +} diff --git a/pkg/ai/provider_codex_test.go b/pkg/ai/provider_codex_test.go new file mode 100644 index 0000000..16c26c5 --- /dev/null +++ b/pkg/ai/provider_codex_test.go @@ -0,0 +1,25 @@ +package ai + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCodexEngine_Provider(t *testing.T) { + t.Parallel() + + engine := newCodexEngine(testLogger()) + assert.Equal(t, ProviderCodex, engine.Provider()) +} + +func TestCodexEngine_Capabilities(t *testing.T) { + t.Parallel() + + engine := newCodexEngine(testLogger()) + caps := engine.Capabilities() + + assert.True(t, caps.Streaming) + assert.True(t, caps.Interrupt) + assert.True(t, caps.Sessions) +} diff --git a/pkg/ai/provider_impl.go b/pkg/ai/provider_impl.go index 380e374..c0bb84c 100644 --- a/pkg/ai/provider_impl.go +++ b/pkg/ai/provider_impl.go @@ -2,13 +2,7 @@ package ai import ( "context" - "crypto/rand" - "encoding/hex" - "encoding/json" - "errors" "fmt" - "io" - "maps" "os/exec" "strings" "sync" @@ -18,41 +12,6 @@ import ( claudesdk "github.com/wagiedev/claude-agent-sdk-go" ) -const defaultTimeout = 30 * time.Minute - -// ProviderDebugError carries structured diagnostics for provider failures. -type ProviderDebugError struct { - Cause error - Info map[string]any -} - -func (e *ProviderDebugError) Error() string { - if e == nil || e.Cause == nil { - return "provider error" - } - - return e.Cause.Error() -} - -func (e *ProviderDebugError) Unwrap() error { - if e == nil { - return nil - } - - return e.Cause -} - -func (e *ProviderDebugError) DebugInfo() map[string]any { - if e == nil || len(e.Info) == 0 { - return map[string]any{} - } - - out := make(map[string]any, len(e.Info)) - maps.Copy(out, e.Info) - - return out -} - type claudeEngine struct { log logrus.FieldLogger timeout time.Duration @@ -189,7 +148,11 @@ func (s *claudeSession) ID() string { return s.id } -func (s *claudeSession) AskStream(ctx context.Context, prompt string, onChunk func(StreamChunk)) (string, error) { +func (s *claudeSession) AskStream( + ctx context.Context, + prompt string, + onChunk func(StreamChunk), +) (string, error) { if s.isClosed() { return "", fmt.Errorf("session is closed") } @@ -200,7 +163,10 @@ func (s *claudeSession) AskStream(ctx context.Context, prompt string, onChunk fu state := &askStreamState{} debugInfo := func() map[string]any { - return state.debugInfo(s, prompt) + return state.debugInfo( + s.id, len(prompt), int(s.timeout.Seconds()), + s.stderrCount(), s.stderrTail(20), + ) } if err := s.client.Query(ctx, prompt); err != nil { @@ -221,7 +187,7 @@ func (s *claudeSession) AskStream(ctx context.Context, prompt string, onChunk fu continue } - if out, done, retErr := state.handleMessage(msg, onChunk, debugInfo); done { + if out, done, retErr := handleClaudeMessage(state, msg, onChunk, debugInfo); done { if retErr != nil { return "", retErr } @@ -233,74 +199,9 @@ func (s *claudeSession) AskStream(ctx context.Context, prompt string, onChunk fu return state.finalResponse(debugInfo) } -type askStreamState struct { - seq int - assistantText strings.Builder - resultText string - streamText strings.Builder - assistantErr string - resultErr string - msgCount int - streamEventCount int - assistantCount int - resultCount int - streamChars int - thinkingChars int - assistantChars int - resultChars int - lastMessageType string - lastEventType string - resultSubtype string - resultIsError bool - endedWithEOF bool - currentToolName string - toolInputBuf strings.Builder -} - -func (st *askStreamState) debugInfo(s *claudeSession, prompt string) map[string]any { - return map[string]any{ - "session_id": s.id, - "prompt_len": len(prompt), - "msg_count": st.msgCount, - "stream_event_count": st.streamEventCount, - "assistant_count": st.assistantCount, - "result_count": st.resultCount, - "stream_chars": st.streamChars, - "thinking_chars": st.thinkingChars, - "assistant_chars": st.assistantChars, - "result_chars": st.resultChars, - "assistant_err": st.assistantErr, - "result_err": st.resultErr, - "result_subtype": st.resultSubtype, - "result_is_error": st.resultIsError, - "last_message_type": st.lastMessageType, - "last_stream_event": st.lastEventType, - "client_timeout_secs": int(s.timeout.Seconds()), - "ended_with_eof": st.endedWithEOF, - "stderr_count": s.stderrCount(), - "stderr_tail": s.stderrTail(20), - } -} - -func (st *askStreamState) handleReceiveError(err error, debugInfo func() map[string]any) (bool, error) { - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return true, err - } - - if errors.Is(err, io.EOF) { - st.endedWithEOF = true - - return true, nil - } - - if strings.Contains(strings.ToLower(err.Error()), "interrupt") { - return true, err - } - - return true, &ProviderDebugError{Cause: err, Info: debugInfo()} -} - -func (st *askStreamState) handleMessage( +// handleClaudeMessage processes a single message from the Claude SDK stream. +func handleClaudeMessage( + st *askStreamState, msg any, onChunk func(StreamChunk), debugInfo func() map[string]any, @@ -316,9 +217,9 @@ func (st *askStreamState) handleMessage( st.lastEventType = eventType } - st.seq += streamEventChunks(m, st.seq, st, onChunk) + st.seq += streamEventChunksFromMap(m.Event, st.seq, st, onChunk) - if parts, ok := parseStreamEvent(m); ok { + if parts, ok := parseStreamEventMap(m.Event); ok { if parts.Thinking != "" { st.thinkingChars += len(parts.Thinking) } @@ -363,86 +264,58 @@ func (st *askStreamState) handleMessage( } } case *claudesdk.ResultMessage: - st.lastMessageType = "result" - st.resultCount++ - st.resultSubtype = m.Subtype - st.resultIsError = m.IsError - - if m.IsError { - st.resultErr = m.Subtype - - if m.Result != nil && strings.TrimSpace(*m.Result) != "" { - st.resultErr = strings.TrimSpace(*m.Result) - } - } - - if m.Result != nil && strings.TrimSpace(*m.Result) != "" { - st.resultText = *m.Result - st.resultChars = len(*m.Result) - } - - out := st.bestOutput() - if out != "" { - return out, true, nil - } - - if st.resultErr != "" { - return "", true, &ProviderDebugError{ - Cause: fmt.Errorf("provider result error: %s", st.resultErr), - Info: debugInfo(), - } - } - - if st.assistantErr != "" { - return "", true, &ProviderDebugError{ - Cause: fmt.Errorf("provider assistant error: %s", st.assistantErr), - Info: debugInfo(), - } - } - - return "", true, &ProviderDebugError{ - Cause: fmt.Errorf("empty provider response"), - Info: debugInfo(), - } + return handleResultMessage(st, m.Subtype, m.IsError, m.Result, debugInfo) } return "", false, nil } -func (st *askStreamState) bestOutput() string { - if strings.TrimSpace(st.resultText) != "" { - return st.resultText - } +// handleResultMessage processes a result message for any provider. +func handleResultMessage( + st *askStreamState, + subtype string, + isError bool, + result *string, + debugInfo func() map[string]any, +) (string, bool, error) { + st.lastMessageType = "result" + st.resultCount++ + st.resultSubtype = subtype + st.resultIsError = isError - out := strings.TrimSpace(st.assistantText.String()) - if out != "" { - return out + if isError { + st.resultErr = subtype + + if result != nil && strings.TrimSpace(*result) != "" { + st.resultErr = strings.TrimSpace(*result) + } } - return strings.TrimSpace(st.streamText.String()) -} + if result != nil && strings.TrimSpace(*result) != "" { + st.resultText = *result + st.resultChars = len(*result) + } -func (st *askStreamState) finalResponse(debugInfo func() map[string]any) (string, error) { out := st.bestOutput() if out != "" { - return out, nil + return out, true, nil } if st.resultErr != "" { - return "", &ProviderDebugError{ + return "", true, &ProviderDebugError{ Cause: fmt.Errorf("provider result error: %s", st.resultErr), Info: debugInfo(), } } if st.assistantErr != "" { - return "", &ProviderDebugError{ + return "", true, &ProviderDebugError{ Cause: fmt.Errorf("provider assistant error: %s", st.assistantErr), Info: debugInfo(), } } - return "", &ProviderDebugError{ + return "", true, &ProviderDebugError{ Cause: fmt.Errorf("empty provider response"), Info: debugInfo(), } @@ -528,202 +401,3 @@ func (s *claudeSession) stderrTail(n int) []string { return out } - -func streamEventChunks( - msg *claudesdk.StreamEvent, - seq int, - st *askStreamState, - onChunk func(StreamChunk), -) int { - if onChunk == nil { - return 0 - } - - parts, ok := parseStreamEvent(msg) - if !ok { - return 0 - } - - // Track tool state across content blocks. - if parts.ToolName != "" { - st.currentToolName = parts.ToolName - st.toolInputBuf.Reset() - } - - if parts.ToolInputDelta != "" { - st.toolInputBuf.WriteString(parts.ToolInputDelta) - } - - n := 0 - - // On content_block_stop, emit a rich tool summary if ending a tool block. - if parts.EventType == "content_block_stop" && st.currentToolName != "" { - summary := formatToolSummary(st.currentToolName, st.toolInputBuf.String()) - st.currentToolName = "" - st.toolInputBuf.Reset() - - n++ - onChunk(StreamChunk{ - Kind: StreamChunkMeta, - Text: summary, - EventType: parts.EventType, - Seq: seq + n, - }) - } - - if parts.Thinking != "" { - n++ - onChunk(StreamChunk{ - Kind: StreamChunkThinking, - Text: parts.Thinking, - EventType: parts.EventType, - Seq: seq + n, - }) - } - - if parts.Text != "" { - n++ - onChunk(StreamChunk{ - Kind: StreamChunkAnswer, - Text: parts.Text, - EventType: parts.EventType, - Seq: seq + n, - }) - } - - if parts.Meta != "" { - n++ - onChunk(StreamChunk{ - Kind: StreamChunkMeta, - Text: parts.Meta, - EventType: parts.EventType, - Seq: seq + n, - }) - } - - return n -} - -type streamEventParts struct { - EventType string - Thinking string - Text string - Meta string - ToolName string - ToolInputDelta string -} - -const toolSummaryMaxLen = 120 - -func parseStreamEvent(msg *claudesdk.StreamEvent) (streamEventParts, bool) { - eventType, _ := msg.Event["type"].(string) - parts := streamEventParts{ - EventType: eventType, - } - - switch eventType { - case "content_block_delta": - delta, ok := msg.Event["delta"].(map[string]any) - if !ok { - return streamEventParts{}, false - } - - if thinking, ok := delta["thinking"].(string); ok { - parts.Thinking = thinking - } - - if text, ok := delta["text"].(string); ok { - parts.Text = text - } - - if deltaType, _ := delta["type"].(string); deltaType == "input_json_delta" { - if partial, ok := delta["partial_json"].(string); ok { - parts.ToolInputDelta = partial - } - } - case "content_block_start": - contentBlock, ok := msg.Event["content_block"].(map[string]any) - if !ok { - return streamEventParts{}, false - } - - blockType, _ := contentBlock["type"].(string) - if text, ok := contentBlock["text"].(string); ok { - parts.Text = text - } - - if blockType == "tool_use" { - toolName, _ := contentBlock["name"].(string) - if strings.TrimSpace(toolName) == "" { - toolName = "unknown" - } - - parts.Meta = fmt.Sprintf("Using tool: %s", toolName) - parts.ToolName = toolName - } - case "content_block_stop": - // No meta emitted here; streamEventChunks handles tool - // summaries via state, and non-tool blocks emit nothing. - default: - return streamEventParts{}, false - } - - return parts, true -} - -// formatToolSummary builds a human-readable summary of a completed tool -// invocation from the tool name and accumulated input JSON. -func formatToolSummary(toolName, inputJSON string) string { - if strings.TrimSpace(inputJSON) == "" { - return fmt.Sprintf("%s complete", toolName) - } - - var fields map[string]any - if err := json.Unmarshal([]byte(inputJSON), &fields); err != nil { - return fmt.Sprintf("%s complete", toolName) - } - - lower := strings.ToLower(toolName) - - switch lower { - case "bash": - if cmd, ok := fields["command"].(string); ok && cmd != "" { - if len(cmd) > toolSummaryMaxLen { - cmd = cmd[:toolSummaryMaxLen] + "..." - } - - return fmt.Sprintf("Bash: %s", cmd) - } - case "read": - if fp, ok := fields["file_path"].(string); ok && fp != "" { - return fmt.Sprintf("Read: %s", fp) - } - case "write": - if fp, ok := fields["file_path"].(string); ok && fp != "" { - return fmt.Sprintf("Write: %s", fp) - } - case "edit": - if fp, ok := fields["file_path"].(string); ok && fp != "" { - return fmt.Sprintf("Edit: %s", fp) - } - case "glob": - if pattern, ok := fields["pattern"].(string); ok && pattern != "" { - return fmt.Sprintf("Glob: %s", pattern) - } - case "grep": - if pattern, ok := fields["pattern"].(string); ok && pattern != "" { - return fmt.Sprintf("Grep: %s", pattern) - } - } - - return fmt.Sprintf("%s complete", toolName) -} - -func newSessionID() string { - buf := make([]byte, 8) - if _, err := rand.Read(buf); err != nil { - return fmt.Sprintf("sess-%d", time.Now().UnixNano()) - } - - return "sess-" + hex.EncodeToString(buf) -} diff --git a/pkg/ai/provider_impl_test.go b/pkg/ai/provider_impl_test.go new file mode 100644 index 0000000..78b6c2c --- /dev/null +++ b/pkg/ai/provider_impl_test.go @@ -0,0 +1,25 @@ +package ai + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestClaudeEngine_Provider(t *testing.T) { + t.Parallel() + + engine := newClaudeEngine(testLogger()) + assert.Equal(t, ProviderClaude, engine.Provider()) +} + +func TestClaudeEngine_Capabilities(t *testing.T) { + t.Parallel() + + engine := newClaudeEngine(testLogger()) + caps := engine.Capabilities() + + assert.True(t, caps.Streaming) + assert.True(t, caps.Interrupt) + assert.True(t, caps.Sessions) +} diff --git a/pkg/ai/provider_shared.go b/pkg/ai/provider_shared.go new file mode 100644 index 0000000..e144219 --- /dev/null +++ b/pkg/ai/provider_shared.go @@ -0,0 +1,368 @@ +package ai + +import ( + "context" + "crypto/rand" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "maps" + "strings" + "time" +) + +const defaultTimeout = 30 * time.Minute + +// ProviderDebugError carries structured diagnostics for provider failures. +type ProviderDebugError struct { + Cause error + Info map[string]any +} + +func (e *ProviderDebugError) Error() string { + if e == nil || e.Cause == nil { + return "provider error" + } + + return e.Cause.Error() +} + +func (e *ProviderDebugError) Unwrap() error { + if e == nil { + return nil + } + + return e.Cause +} + +func (e *ProviderDebugError) DebugInfo() map[string]any { + if e == nil || len(e.Info) == 0 { + return map[string]any{} + } + + out := make(map[string]any, len(e.Info)) + maps.Copy(out, e.Info) + + return out +} + +// askStreamState tracks streaming state across all providers. +type askStreamState struct { + seq int + assistantText strings.Builder + resultText string + streamText strings.Builder + assistantErr string + resultErr string + msgCount int + streamEventCount int + assistantCount int + resultCount int + streamChars int + thinkingChars int + assistantChars int + resultChars int + lastMessageType string + lastEventType string + resultSubtype string + resultIsError bool + endedWithEOF bool + currentToolName string + toolInputBuf strings.Builder +} + +func (st *askStreamState) debugInfo( + sessionID string, + promptLen, timeoutSecs, stderrCount int, + stderrTail []string, +) map[string]any { + return map[string]any{ + "session_id": sessionID, + "prompt_len": promptLen, + "msg_count": st.msgCount, + "stream_event_count": st.streamEventCount, + "assistant_count": st.assistantCount, + "result_count": st.resultCount, + "stream_chars": st.streamChars, + "thinking_chars": st.thinkingChars, + "assistant_chars": st.assistantChars, + "result_chars": st.resultChars, + "assistant_err": st.assistantErr, + "result_err": st.resultErr, + "result_subtype": st.resultSubtype, + "result_is_error": st.resultIsError, + "last_message_type": st.lastMessageType, + "last_stream_event": st.lastEventType, + "client_timeout_secs": timeoutSecs, + "ended_with_eof": st.endedWithEOF, + "stderr_count": stderrCount, + "stderr_tail": stderrTail, + } +} + +func (st *askStreamState) handleReceiveError( + err error, + debugInfo func() map[string]any, +) (bool, error) { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return true, err + } + + if errors.Is(err, io.EOF) { + st.endedWithEOF = true + + return true, nil + } + + if strings.Contains(strings.ToLower(err.Error()), "interrupt") { + return true, err + } + + return true, &ProviderDebugError{Cause: err, Info: debugInfo()} +} + +func (st *askStreamState) bestOutput() string { + if strings.TrimSpace(st.resultText) != "" { + return st.resultText + } + + out := strings.TrimSpace(st.assistantText.String()) + if out != "" { + return out + } + + return strings.TrimSpace(st.streamText.String()) +} + +func (st *askStreamState) finalResponse( + debugInfo func() map[string]any, +) (string, error) { + out := st.bestOutput() + if out != "" { + return out, nil + } + + if st.resultErr != "" { + return "", &ProviderDebugError{ + Cause: fmt.Errorf("provider result error: %s", st.resultErr), + Info: debugInfo(), + } + } + + if st.assistantErr != "" { + return "", &ProviderDebugError{ + Cause: fmt.Errorf("provider assistant error: %s", st.assistantErr), + Info: debugInfo(), + } + } + + return "", &ProviderDebugError{ + Cause: fmt.Errorf("empty provider response"), + Info: debugInfo(), + } +} + +// streamEventParts holds parsed fields from a raw stream event. +type streamEventParts struct { + EventType string + Thinking string + Text string + Meta string + ToolName string + ToolInputDelta string +} + +const toolSummaryMaxLen = 120 + +// parseStreamEventMap parses the raw event map from a StreamEvent. +func parseStreamEventMap(event map[string]any) (streamEventParts, bool) { + eventType, _ := event["type"].(string) + parts := streamEventParts{ + EventType: eventType, + } + + switch eventType { + case "content_block_delta": + delta, ok := event["delta"].(map[string]any) + if !ok { + return streamEventParts{}, false + } + + if thinking, ok := delta["thinking"].(string); ok { + parts.Thinking = thinking + } + + if text, ok := delta["text"].(string); ok { + parts.Text = text + } + + if deltaType, _ := delta["type"].(string); deltaType == "input_json_delta" { + if partial, ok := delta["partial_json"].(string); ok { + parts.ToolInputDelta = partial + } + } + case "content_block_start": + contentBlock, ok := event["content_block"].(map[string]any) + if !ok { + return streamEventParts{}, false + } + + blockType, _ := contentBlock["type"].(string) + if text, ok := contentBlock["text"].(string); ok { + parts.Text = text + } + + if blockType == "tool_use" { + toolName, _ := contentBlock["name"].(string) + if strings.TrimSpace(toolName) == "" { + toolName = "unknown" + } + + parts.Meta = fmt.Sprintf("Using tool: %s", toolName) + parts.ToolName = toolName + } + case "content_block_stop": + // No meta emitted here; streamEventChunksFromMap handles tool + // summaries via state, and non-tool blocks emit nothing. + default: + return streamEventParts{}, false + } + + return parts, true +} + +// streamEventChunksFromMap processes a raw stream event map and emits +// StreamChunks via the onChunk callback. Returns the number of chunks emitted. +func streamEventChunksFromMap( + event map[string]any, + seq int, + st *askStreamState, + onChunk func(StreamChunk), +) int { + if onChunk == nil { + return 0 + } + + parts, ok := parseStreamEventMap(event) + if !ok { + return 0 + } + + // Track tool state across content blocks. + if parts.ToolName != "" { + st.currentToolName = parts.ToolName + st.toolInputBuf.Reset() + } + + if parts.ToolInputDelta != "" { + st.toolInputBuf.WriteString(parts.ToolInputDelta) + } + + n := 0 + + // On content_block_stop, emit a rich tool summary if ending a tool block. + if parts.EventType == "content_block_stop" && st.currentToolName != "" { + summary := formatToolSummary(st.currentToolName, st.toolInputBuf.String()) + st.currentToolName = "" + st.toolInputBuf.Reset() + + n++ + onChunk(StreamChunk{ + Kind: StreamChunkMeta, + Text: summary, + EventType: parts.EventType, + Seq: seq + n, + }) + } + + if parts.Thinking != "" { + n++ + onChunk(StreamChunk{ + Kind: StreamChunkThinking, + Text: parts.Thinking, + EventType: parts.EventType, + Seq: seq + n, + }) + } + + if parts.Text != "" { + n++ + onChunk(StreamChunk{ + Kind: StreamChunkAnswer, + Text: parts.Text, + EventType: parts.EventType, + Seq: seq + n, + }) + } + + if parts.Meta != "" { + n++ + onChunk(StreamChunk{ + Kind: StreamChunkMeta, + Text: parts.Meta, + EventType: parts.EventType, + Seq: seq + n, + }) + } + + return n +} + +// formatToolSummary builds a human-readable summary of a completed tool +// invocation from the tool name and accumulated input JSON. +func formatToolSummary(toolName, inputJSON string) string { + if strings.TrimSpace(inputJSON) == "" { + return fmt.Sprintf("%s complete", toolName) + } + + var fields map[string]any + if err := json.Unmarshal([]byte(inputJSON), &fields); err != nil { + return fmt.Sprintf("%s complete", toolName) + } + + lower := strings.ToLower(toolName) + + switch lower { + case "bash": + if cmd, ok := fields["command"].(string); ok && cmd != "" { + if len(cmd) > toolSummaryMaxLen { + cmd = cmd[:toolSummaryMaxLen] + "..." + } + + return fmt.Sprintf("Bash: %s", cmd) + } + case "read": + if fp, ok := fields["file_path"].(string); ok && fp != "" { + return fmt.Sprintf("Read: %s", fp) + } + case "write": + if fp, ok := fields["file_path"].(string); ok && fp != "" { + return fmt.Sprintf("Write: %s", fp) + } + case "edit": + if fp, ok := fields["file_path"].(string); ok && fp != "" { + return fmt.Sprintf("Edit: %s", fp) + } + case "glob": + if pattern, ok := fields["pattern"].(string); ok && pattern != "" { + return fmt.Sprintf("Glob: %s", pattern) + } + case "grep": + if pattern, ok := fields["pattern"].(string); ok && pattern != "" { + return fmt.Sprintf("Grep: %s", pattern) + } + } + + return fmt.Sprintf("%s complete", toolName) +} + +func newSessionID() string { + buf := make([]byte, 8) + if _, err := rand.Read(buf); err != nil { + return fmt.Sprintf("sess-%d", time.Now().UnixNano()) + } + + return "sess-" + hex.EncodeToString(buf) +} diff --git a/pkg/ai/provider_shared_test.go b/pkg/ai/provider_shared_test.go new file mode 100644 index 0000000..fb317a3 --- /dev/null +++ b/pkg/ai/provider_shared_test.go @@ -0,0 +1,402 @@ +package ai + +import ( + "errors" + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestProviderDebugError(t *testing.T) { + t.Parallel() + + t.Run("Error with cause", func(t *testing.T) { + t.Parallel() + + cause := fmt.Errorf("connection timeout") + pde := &ProviderDebugError{Cause: cause, Info: map[string]any{"key": "val"}} + + assert.Equal(t, "connection timeout", pde.Error()) + }) + + t.Run("Error with nil cause", func(t *testing.T) { + t.Parallel() + + pde := &ProviderDebugError{} + assert.Equal(t, "provider error", pde.Error()) + }) + + t.Run("Error on nil receiver", func(t *testing.T) { + t.Parallel() + + var pde *ProviderDebugError + assert.Equal(t, "provider error", pde.Error()) + }) + + t.Run("Unwrap returns cause", func(t *testing.T) { + t.Parallel() + + cause := fmt.Errorf("root cause") + pde := &ProviderDebugError{Cause: cause} + + assert.ErrorIs(t, pde, cause) + assert.Equal(t, cause, pde.Unwrap()) + }) + + t.Run("Unwrap on nil receiver", func(t *testing.T) { + t.Parallel() + + var pde *ProviderDebugError + assert.Nil(t, pde.Unwrap()) + }) + + t.Run("DebugInfo returns copy", func(t *testing.T) { + t.Parallel() + + info := map[string]any{"session": "abc", "count": 42} + pde := &ProviderDebugError{ + Cause: fmt.Errorf("err"), + Info: info, + } + + result := pde.DebugInfo() + assert.Equal(t, "abc", result["session"]) + assert.Equal(t, 42, result["count"]) + + // Mutating result should not affect original. + result["session"] = "modified" + + assert.Equal(t, "abc", pde.Info["session"]) + }) + + t.Run("DebugInfo on nil receiver", func(t *testing.T) { + t.Parallel() + + var pde *ProviderDebugError + assert.Empty(t, pde.DebugInfo()) + }) + + t.Run("DebugInfo with empty info", func(t *testing.T) { + t.Parallel() + + pde := &ProviderDebugError{Cause: fmt.Errorf("err")} + assert.Empty(t, pde.DebugInfo()) + }) + + t.Run("errors.As compatibility", func(t *testing.T) { + t.Parallel() + + cause := fmt.Errorf("inner") + wrapped := fmt.Errorf("outer: %w", &ProviderDebugError{ + Cause: cause, + Info: map[string]any{"k": "v"}, + }) + + var pde *ProviderDebugError + require.True(t, errors.As(wrapped, &pde)) + assert.Equal(t, "inner", pde.Error()) + assert.Equal(t, "v", pde.DebugInfo()["k"]) + }) +} + +func TestAskStreamState_BestOutput(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setup func(st *askStreamState) + wantOutput string + }{ + { + name: "empty state returns empty", + setup: func(_ *askStreamState) {}, + wantOutput: "", + }, + { + name: "result text preferred over all", + setup: func(st *askStreamState) { + st.resultText = "result answer" + st.assistantText.WriteString("assistant answer") + st.streamText.WriteString("stream answer") + }, + wantOutput: "result answer", + }, + { + name: "assistant text when no result", + setup: func(st *askStreamState) { + st.assistantText.WriteString("assistant answer") + st.streamText.WriteString("stream answer") + }, + wantOutput: "assistant answer", + }, + { + name: "stream text as last resort", + setup: func(st *askStreamState) { + st.streamText.WriteString("stream answer") + }, + wantOutput: "stream answer", + }, + { + name: "whitespace-only result falls through", + setup: func(st *askStreamState) { + st.resultText = " " + st.assistantText.WriteString("assistant answer") + }, + wantOutput: "assistant answer", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + st := &askStreamState{} + tc.setup(st) + assert.Equal(t, tc.wantOutput, st.bestOutput()) + }) + } +} + +func TestParseStreamEventMap(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + event map[string]any + wantOK bool + wantParts streamEventParts + }{ + { + name: "content_block_delta with thinking", + event: map[string]any{ + "type": "content_block_delta", + "delta": map[string]any{ + "thinking": "let me think...", + }, + }, + wantOK: true, + wantParts: streamEventParts{ + EventType: "content_block_delta", + Thinking: "let me think...", + }, + }, + { + name: "content_block_delta with text", + event: map[string]any{ + "type": "content_block_delta", + "delta": map[string]any{ + "text": "hello world", + }, + }, + wantOK: true, + wantParts: streamEventParts{ + EventType: "content_block_delta", + Text: "hello world", + }, + }, + { + name: "content_block_delta with tool input", + event: map[string]any{ + "type": "content_block_delta", + "delta": map[string]any{ + "type": "input_json_delta", + "partial_json": `{"command":`, + }, + }, + wantOK: true, + wantParts: streamEventParts{ + EventType: "content_block_delta", + ToolInputDelta: `{"command":`, + }, + }, + { + name: "content_block_delta missing delta", + event: map[string]any{ + "type": "content_block_delta", + }, + wantOK: false, + }, + { + name: "content_block_start with text", + event: map[string]any{ + "type": "content_block_start", + "content_block": map[string]any{ + "type": "text", + "text": "start text", + }, + }, + wantOK: true, + wantParts: streamEventParts{ + EventType: "content_block_start", + Text: "start text", + }, + }, + { + name: "content_block_start with tool_use", + event: map[string]any{ + "type": "content_block_start", + "content_block": map[string]any{ + "type": "tool_use", + "name": "Bash", + }, + }, + wantOK: true, + wantParts: streamEventParts{ + EventType: "content_block_start", + Meta: "Using tool: Bash", + ToolName: "Bash", + }, + }, + { + name: "content_block_start with empty tool name", + event: map[string]any{ + "type": "content_block_start", + "content_block": map[string]any{ + "type": "tool_use", + "name": "", + }, + }, + wantOK: true, + wantParts: streamEventParts{ + EventType: "content_block_start", + Meta: "Using tool: unknown", + ToolName: "unknown", + }, + }, + { + name: "content_block_start missing content_block", + event: map[string]any{ + "type": "content_block_start", + }, + wantOK: false, + }, + { + name: "content_block_stop", + event: map[string]any{ + "type": "content_block_stop", + }, + wantOK: true, + wantParts: streamEventParts{ + EventType: "content_block_stop", + }, + }, + { + name: "unknown event type", + event: map[string]any{ + "type": "message_start", + }, + wantOK: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + parts, ok := parseStreamEventMap(tc.event) + assert.Equal(t, tc.wantOK, ok) + + if tc.wantOK { + assert.Equal(t, tc.wantParts, parts) + } + }) + } +} + +func TestFormatToolSummary(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + toolName string + inputJSON string + want string + }{ + { + name: "Bash with command", + toolName: "Bash", + inputJSON: `{"command":"ls -la"}`, + want: "Bash: ls -la", + }, + { + name: "Read with file_path", + toolName: "Read", + inputJSON: `{"file_path":"/tmp/foo.go"}`, + want: "Read: /tmp/foo.go", + }, + { + name: "Write with file_path", + toolName: "Write", + inputJSON: `{"file_path":"/tmp/bar.go"}`, + want: "Write: /tmp/bar.go", + }, + { + name: "Edit with file_path", + toolName: "Edit", + inputJSON: `{"file_path":"/tmp/baz.go"}`, + want: "Edit: /tmp/baz.go", + }, + { + name: "Glob with pattern", + toolName: "Glob", + inputJSON: `{"pattern":"**/*.go"}`, + want: "Glob: **/*.go", + }, + { + name: "Grep with pattern", + toolName: "Grep", + inputJSON: `{"pattern":"TODO"}`, + want: "Grep: TODO", + }, + { + name: "unknown tool", + toolName: "CustomTool", + inputJSON: `{"data":"something"}`, + want: "CustomTool complete", + }, + { + name: "empty input", + toolName: "Bash", + inputJSON: "", + want: "Bash complete", + }, + { + name: "invalid JSON", + toolName: "Bash", + inputJSON: "not json{", + want: "Bash complete", + }, + { + name: "Bash with long command truncated", + toolName: "Bash", + inputJSON: fmt.Sprintf( + `{"command":"%s"}`, + strings.Repeat("x", 200), + ), + want: "Bash: " + strings.Repeat("x", toolSummaryMaxLen) + "...", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got := formatToolSummary(tc.toolName, tc.inputJSON) + assert.Equal(t, tc.want, got) + }) + } +} + +func TestNewSessionID(t *testing.T) { + t.Parallel() + + id1 := newSessionID() + id2 := newSessionID() + + assert.True(t, strings.HasPrefix(id1, "sess-"), "should have sess- prefix") + assert.True(t, strings.HasPrefix(id2, "sess-"), "should have sess- prefix") + assert.NotEqual(t, id1, id2, "IDs should be unique") +} diff --git a/pkg/ai/types.go b/pkg/ai/types.go index 5c1ea3b..edf2a7b 100644 --- a/pkg/ai/types.go +++ b/pkg/ai/types.go @@ -8,6 +8,8 @@ type ProviderID string const ( // ProviderClaude uses the Claude Code SDK-backed provider. ProviderClaude ProviderID = "claude" + // ProviderCodex uses the Codex CLI SDK-backed provider. + ProviderCodex ProviderID = "codex" ) // Capabilities describe what features a provider supports. diff --git a/pkg/cc/frontend/src/components/Dashboard/Dashboard.tsx b/pkg/cc/frontend/src/components/Dashboard/Dashboard.tsx index e806e51..db1a760 100644 --- a/pkg/cc/frontend/src/components/Dashboard/Dashboard.tsx +++ b/pkg/cc/frontend/src/components/Dashboard/Dashboard.tsx @@ -49,6 +49,7 @@ function bootPhasesFor(isLab: boolean) { function stopPhasesFor(isLab: boolean) { return isLab ? STOP_PHASES : XATU_STOP_PHASES; } +const STORAGE_KEY_PROVIDER = 'xcli:ai-provider'; const leftCollapsedStorageKey = 'xcli:sidebar-left-collapsed'; const rightCollapsedStorageKey = 'xcli:sidebar-right-collapsed'; const dashboardLegacyLayoutStorageKey = 'xcli:dashboard:panel-layout'; @@ -167,7 +168,15 @@ export default function Dashboard({ const leftDefaultDesktopPx = leftCollapsed ? sidebarCollapsedPx : leftExpandedPxRef.current; const rightDefaultDesktopPx = rightCollapsed ? sidebarCollapsedPx : rightExpandedPxRef.current; const [providers, setProviders] = useState([]); - const [selectedProvider, setSelectedProvider] = useState('claude'); + const [selectedProvider, setSelectedProvider] = useState(() => { + try { + const saved = localStorage.getItem(STORAGE_KEY_PROVIDER); + if (saved) return saved; + } catch { + // ignore storage failures + } + return 'claude'; + }); const [diagnoseService, setDiagnoseService] = useState(null); const [diagnoseSessionId, setDiagnoseSessionId] = useState(null); const [diagnoseRequestId, setDiagnoseRequestId] = useState(null); @@ -547,9 +556,29 @@ export default function Dashboard({ fetchJSON('/ai/providers') .then(data => { setProviders(data); - const preferred = data.find(p => p.default && p.available) ?? data.find(p => p.available) ?? data[0]; - if (preferred) { - setSelectedProvider(preferred.id); + + let savedProvider: string | null = null; + try { + savedProvider = localStorage.getItem(STORAGE_KEY_PROVIDER); + } catch { + // ignore storage failures + } + + const isAvailable = savedProvider && data.some((p: AIProviderInfo) => p.id === savedProvider && p.available); + + if (isAvailable && savedProvider) { + setSelectedProvider(savedProvider); + } else { + const defaultProvider = + data.find((p: AIProviderInfo) => p.default && p.available) || data.find((p: AIProviderInfo) => p.available); + if (defaultProvider) { + setSelectedProvider(defaultProvider.id); + try { + localStorage.setItem(STORAGE_KEY_PROVIDER, defaultProvider.id); + } catch { + // ignore storage failures + } + } } }) .catch(() => { @@ -558,8 +587,8 @@ export default function Dashboard({ }, [fetchJSON]); const handleDiagnose = useCallback( - (serviceName: string) => { - const provider = selectedProvider || 'claude'; + (serviceName: string, providerOverride?: string) => { + const provider = providerOverride || selectedProvider || 'claude'; const requestId = createRequestId(); setDiagnoseService(serviceName); @@ -656,6 +685,49 @@ export default function Dashboard({ setCurrentTurnPrompt(undefined); }, [deleteDiagnoseSession, diagnoseService, diagnoseSessionId]); + const handleProviderChange = useCallback( + (newProvider: string) => { + try { + localStorage.setItem(STORAGE_KEY_PROVIDER, newProvider); + } catch { + // ignore storage failures + } + setSelectedProvider(newProvider); + + // Restart active session with new provider + if (diagnoseService && diagnoseSessionId) { + const serviceId = diagnoseService; + deleteDiagnoseSession(serviceId, diagnoseSessionId) + .then(() => { + setDiagnoseSessionId(null); + setDiagnoseRequestId(null); + setCompletedTurns([]); + setDiagnosis(null); + setDiagnosisError(null); + setDiagnosing(false); + setThinkingText(''); + setAnswerText(''); + setActivityText(''); + setCurrentTurnPrompt(undefined); + handleDiagnose(serviceId, newProvider); + }) + .catch(() => { + setDiagnoseSessionId(null); + setDiagnoseRequestId(null); + setCompletedTurns([]); + setDiagnosis(null); + setDiagnosisError(null); + setDiagnosing(false); + setThinkingText(''); + setAnswerText(''); + setActivityText(''); + setCurrentTurnPrompt(undefined); + }); + } + }, + [deleteDiagnoseSession, diagnoseService, diagnoseSessionId, handleDiagnose] + ); + const handleLeftPanelResize = useCallback( (panelSize: PanelSize) => { if (leftCollapsed) return; @@ -961,7 +1033,7 @@ export default function Dashboard({ serviceName={diagnoseService ?? ''} providers={providers} selectedProvider={selectedProvider} - onProviderChange={setSelectedProvider} + onProviderChange={handleProviderChange} sessionId={diagnoseSessionId} thinkingText={thinkingText} activityText={activityText}