Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func newChatModel(ref *progRef, systemPrompt string, settings hawkconfig.Setting
startup.EndPhase("newChatModel:newHawkSession")

startup.MarkPhase("newChatModel:configureSession")
syncSessionFromPersistedSelection(sess, settings)
syncSessionFromPersistedSelection(sess)
sess.SetLogger(logger.New(io.Discard, logger.Error))
if cfgErr := configureSession(sess, settings); cfgErr != nil {
return chatModel{}, cfgErr
Expand Down
13 changes: 6 additions & 7 deletions cmd/chat_commands_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@ func (m *chatModel) handleConfigCommand(parts []string, text string) (tea.Model,
m.messages = append(m.messages, displayMsg{role: "error", content: err.Error()})
return m, nil
}
engineProvider := hawkconfig.NormalizeProviderForEngine(value)
m.session.SetProvider(engineProvider)
m.syncSessionSelection()
// Use cached model or set first from cache
modelCacheMu.RLock()
cached, cacheHit := modelCache[engineProvider]
cached, cacheHit := modelCache[m.session.Provider()]
modelCacheMu.RUnlock()
if cacheHit && len(cached) > 0 {
m.session.SetModel(cached[0].ID)
Expand Down Expand Up @@ -52,8 +51,8 @@ func (m *chatModel) handleConfigCommand(parts []string, text string) (tea.Model,
m.messages = append(m.messages, displayMsg{role: "error", content: err.Error()})
return m, nil
}
m.session.SetModel(value)
m.messages = append(m.messages, displayMsg{role: "system", content: fmt.Sprintf("Model switched to: %s\nSaved in eyrie (provider.json).", value)})
m.syncSessionSelection()
m.messages = append(m.messages, displayMsg{role: "system", content: fmt.Sprintf("Model switched to: %s\nSaved in eyrie (provider.json).", m.session.Model())})
return m, nil
}
if len(parts) >= 2 && parts[1] == "keys" {
Expand Down Expand Up @@ -95,9 +94,9 @@ func (m *chatModel) handleConfigCommand(parts []string, text string) (tea.Model,
normalizedKey := strings.ToLower(strings.ReplaceAll(strings.ReplaceAll(key, "-", ""), "_", ""))
switch normalizedKey {
case "model":
m.session.SetModel(value)
m.syncSessionSelection()
case "provider":
m.session.SetProvider(hawkconfig.NormalizeProviderForEngine(value))
m.syncSessionSelection()
}
m.messages = append(m.messages, displayMsg{role: "system", content: fmt.Sprintf("Updated %s = %s", key, value)})
return m, nil
Expand Down
2 changes: 2 additions & 0 deletions cmd/chat_commands_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ func TestHandleCommandAddDir(t *testing.T) {
}

func TestLocalSlashCommands(t *testing.T) {
preserveCLICompilerVersionState(t)
version = "test-version"
sess := engine.NewSession("openai", "gpt-4o", "base", tool.NewRegistry(tool.LSTool{}))
m := &chatModel{
Expand Down Expand Up @@ -80,6 +81,7 @@ func TestLocalSlashCommands(t *testing.T) {
}

func TestDiagnosticSummaries(t *testing.T) {
preserveCLICompilerVersionState(t)
version = "test-version"
settings := hawkconfig.Settings{
Provider: "openai",
Expand Down
11 changes: 0 additions & 11 deletions cmd/chat_config_cache.go

This file was deleted.

10 changes: 7 additions & 3 deletions cmd/chat_config_deployment.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ func (m chatModel) handleConfigApplyCredentialsMsg(msg configApplyCredentialsMsg
}
if msg.providerID == configProviderOllama {
_ = hawkconfig.SetGlobalSetting("provider", configProviderOllama)
next.session.SetProvider(hawkconfig.NormalizeProviderForEngine(configProviderOllama))
next.syncSessionSelection()
}
next.configGuideAfterKey = false
if len(msg.modelOptions) == 0 {
Expand All @@ -221,10 +221,14 @@ func (m chatModel) handleConfigApplyCredentialsMsg(msg configApplyCredentialsMsg
}

func (m chatModel) rebuildSessionTransport() (chatModel, tea.Cmd) {
if err := engine.RebuildSessionTransport(context.Background(), m.session, hawkconfig.DeploymentRoutingEnabled(m.settings), m.session.Provider()); err != nil {
selection := runtime.EffectiveSelection(context.Background(), runtime.SelectionOpts{
ProviderOverride: firstNonEmptyTrimmed(m.session.Provider(), m.settings.Provider),
ModelOverride: firstNonEmptyTrimmed(m.session.Model(), m.settings.Model),
})
if err := engine.RebuildSessionTransport(context.Background(), m.session, selection, m.session.Provider()); err != nil {
m.configNotice = sanitizeConfigNotice(err.Error())
}
syncSessionFromPersistedSelection(m.session, m.settings)
syncSessionFromPersistedSelection(m.session)
m.invalidateConnStatus()
return m, nil
}
75 changes: 46 additions & 29 deletions cmd/chat_config_gateways.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ import (
)

type configGatewayRow struct {
ID string
DisplayName string
HasKey bool
ModelCount int
Active bool
ID string
DisplayName string
HasKey bool
ModelCount int
Active bool
RegionLabel string
RegionRequired bool
}

type configGatewayRefreshMsg struct {
Expand All @@ -27,47 +29,52 @@ type configGatewayRefreshMsg struct {

func (m chatModel) configGatewayRows() []configGatewayRow {
ctx := context.Background()
providers := hawkconfig.AllSetupGateways()
configured := configuredGatewayKeys()
active := strings.TrimSpace(m.configModelProvider)
activeModel := ""
if active == "" && m.session != nil {
active = strings.TrimSpace(m.session.Provider())
}
var rows []configGatewayRow
for _, id := range providers {
if id == "" {
if m.session != nil {
activeModel = strings.TrimSpace(m.session.Model())
}
statuses := hawkconfig.GatewayStatuses(ctx, active, activeModel)
rows := make([]configGatewayRow, 0, len(statuses))
for _, status := range statuses {
if status.ID == "" {
continue
}
count := hawkconfig.CachedModelCountForProvider(id)
count := status.ModelCount
if count == 0 {
modelCacheMu.RLock()
if cached, ok := modelCache[id]; ok {
if cached, ok := modelCache[status.ID]; ok {
count = len(cached)
}
modelCacheMu.RUnlock()
}
hasKey := configured[id] || hawkconfig.HasStoredCredentialForProvider(ctx, id)
display := hawkconfig.GatewayDisplayName(id)
if id == hawkconfig.ProviderXiaomiTokenPlan {
if reg := hawkconfig.XiaomiTokenPlanRegionLabel(); reg != "" {
hasKey := status.HasStoredCredential
display := status.DisplayName
if status.ID == hawkconfig.ProviderXiaomiTokenPlan {
if reg := status.RegionLabel; reg != "" {
display += " · " + reg
} else {
display += " · region required"
}
}
if id == hawkconfig.ProviderZAICoding {
if reg := hawkconfig.ZAIRegionLabel(id); reg != "" {
if status.ID == hawkconfig.ProviderZAICoding {
if reg := status.RegionLabel; reg != "" {
display += " · " + reg
} else {
display += " · region"
}
}
rows = append(rows, configGatewayRow{
ID: id,
DisplayName: display,
HasKey: hasKey,
ModelCount: count,
Active: hawkconfig.NormalizeProviderForEngine(id) == hawkconfig.NormalizeProviderForEngine(active),
ID: status.ID,
DisplayName: display,
HasKey: hasKey,
ModelCount: count,
Active: status.Active || hawkconfig.ActiveProviderID(status.ID) == hawkconfig.ActiveProviderID(active),
RegionLabel: status.RegionLabel,
RegionRequired: status.RegionRequired,
})
}
return rows
Expand Down Expand Up @@ -117,11 +124,11 @@ func (m chatModel) refreshConfigGateway() (chatModel, tea.Cmd) {
}
idx := m.configGatewayRefreshTargetIndex(rows)
row := rows[idx]
if row.ID == hawkconfig.ProviderXiaomiTokenPlan && hawkconfig.NeedsXiaomiTokenPlanRegion(row.ID) {
if row.ID == hawkconfig.ProviderXiaomiTokenPlan && row.RegionRequired {
m.configNotice = "Pick Token Plan region (cn / sgp / ams) before refresh"
return m.startConfigXiaomiTokenPlanRegion(), nil
}
if row.ID == hawkconfig.ProviderZAICoding && hawkconfig.NeedsZAIRegion(row.ID) {
if row.ID == hawkconfig.ProviderZAICoding && row.RegionRequired {
m.configNotice = "Pick Coding Plan region (international / cn) before refresh"
return m.startConfigZAIRegion(row.ID), nil
}
Expand Down Expand Up @@ -274,12 +281,12 @@ func (m chatModel) handleConfigGatewaysSelect() (chatModel, tea.Cmd) {
}
row := rows[m.configSel]
if row.ID == hawkconfig.ProviderXiaomiTokenPlan {
if !row.HasKey || hawkconfig.NeedsXiaomiTokenPlanRegion(row.ID) {
if !row.HasKey || row.RegionRequired {
m.configGatewayFocus = m.configSel
return m.startConfigXiaomiTokenPlanRegion(), nil
}
}
if row.ID == hawkconfig.ProviderZAICoding && (!row.HasKey || hawkconfig.NeedsZAIRegion(row.ID)) {
if row.ID == hawkconfig.ProviderZAICoding && (!row.HasKey || row.RegionRequired) {
m.configGatewayFocus = m.configSel
return m.startConfigZAIRegion(row.ID), nil
}
Expand All @@ -295,7 +302,9 @@ func (m chatModel) handleConfigGatewaysSelect() (chatModel, tea.Cmd) {
m.configGatewayFocus = m.configSel
m.configModelProvider = gw
_ = hawkconfig.SetGlobalSetting("provider", gw)
m.session.SetProvider(hawkconfig.NormalizeProviderForEngine(gw))
if active := hawkconfig.ActiveProvider(context.Background()); active != "" {
m.session.SetProvider(active)
}
m.configTab = configTabModels
m.configSel = 0
m.configScroll = 0
Expand Down Expand Up @@ -343,7 +352,15 @@ func (m chatModel) trackConfigGatewayFocus() chatModel {
if m.configTab != configTabGateways {
return m
}
rows := len(hawkconfig.AllSetupGateways())
active := strings.TrimSpace(m.configModelProvider)
activeModel := ""
if m.session != nil {
if active == "" {
active = strings.TrimSpace(m.session.Provider())
}
activeModel = strings.TrimSpace(m.session.Model())
}
rows := len(hawkconfig.GatewayStatuses(context.Background(), active, activeModel))
if m.configSel >= 0 && m.configSel < rows {
m.configGatewayFocus = m.configSel
}
Expand Down
3 changes: 1 addition & 2 deletions cmd/chat_config_panel.go
Original file line number Diff line number Diff line change
Expand Up @@ -697,11 +697,10 @@ func (m chatModel) selectConfigModelFromOptions(opts []configModelOption) (chatM
m.session.SetModel(modelID)
if gw := strings.TrimSpace(m.configModelProvider); gw != "" {
_ = hawkconfig.SetGlobalSetting("provider", gw)
m.session.SetProvider(hawkconfig.NormalizeProviderForEngine(gw))
} else if prov := hawkconfig.ProviderOfModel(modelID); prov != "" {
_ = hawkconfig.SetGlobalSetting("provider", prov)
m.session.SetProvider(hawkconfig.NormalizeProviderForEngine(prov))
}
m.syncSessionSelection()
next, cmd := m.rebuildSessionTransport()
next.invalidateConnStatus()
next = next.stopConfigModelSearch(true)
Expand Down
25 changes: 20 additions & 5 deletions cmd/chat_status.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"strings"

"github.com/GrayCodeAI/eyrie/runtime"
"github.com/charmbracelet/lipgloss"

hawkconfig "github.com/GrayCodeAI/hawk/internal/config"
Expand Down Expand Up @@ -76,11 +77,25 @@ func (m chatModel) sessionGatewayModel() (gateway, model string) {
}
if gateway == "" || model == "" {
ctx := context.Background()
if gateway == "" {
gateway = hawkconfig.ActiveGateway(ctx)
}
if model == "" {
model = strings.TrimSpace(hawkconfig.ActiveModel(ctx))
explicitGateway, explicitModel := explicitSelection(ctx)
if explicitGateway != "" || explicitModel != "" {
if gateway == "" {
gateway = explicitGateway
}
if model == "" {
model = explicitModel
}
} else {
selection := runtime.EffectiveSelection(ctx, runtime.SelectionOpts{
ProviderOverride: gateway,
ModelOverride: model,
})
if gateway == "" {
gateway = strings.TrimSpace(selection.Provider)
}
if model == "" {
model = strings.TrimSpace(selection.Model)
}
}
}
return gateway, model
Expand Down
7 changes: 5 additions & 2 deletions cmd/chat_status_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,11 @@ func TestChatConnectionStatus_NoGatewayNoModel(t *testing.T) {

m := chatModel{session: &engine.Session{}}
got := m.chatConnectionStatus()
if got != "pick model" {
t.Fatalf("status = %q", got)
if !strings.Contains(got, "Anthropic · ") {
t.Fatalf("expected runtime-selected anthropic status, got %q", got)
}
if strings.Contains(got, "pick model") {
t.Fatalf("expected auto-selected model when only anthropic credentials exist, got %q", got)
}
}

Expand Down
6 changes: 3 additions & 3 deletions cmd/chat_subcommand_simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,10 @@ func init() {
handler: func(m *chatModel, args []string, text string) (tea.Model, tea.Cmd) {
savedModel := hawkconfig.ActiveModel(context.Background())
if m.session.Model() == savedModel {
norm := hawkconfig.NormalizeProviderForEngine(m.session.Provider())
fastModel := hawkconfig.CheapestModelForProvider(norm, m.session.Model())
providerName := strings.TrimSpace(m.session.Provider())
fastModel := hawkconfig.CheapestModelForProvider(providerName, m.session.Model())
if strings.TrimSpace(fastModel) == "" {
fastModel = hawkconfig.DefaultModelForProvider(norm)
fastModel = hawkconfig.DefaultModelForProvider(providerName)
}
if strings.TrimSpace(fastModel) == "" {
m.messages = append(m.messages, displayMsg{role: "error", content: "Fast mode: no catalog model resolved for this provider"})
Expand Down
1 change: 1 addition & 0 deletions cmd/completions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ func TestAllProvidersInFishCompletion(t *testing.T) {
}

func TestCompletionJSONCommand(t *testing.T) {
preserveCLICompilerVersionState(t)
SetVersion("test-version")

buf := new(bytes.Buffer)
Expand Down
2 changes: 2 additions & 0 deletions cmd/dx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
)

func TestDoctorOutputContainsSections(t *testing.T) {
preserveCLICompilerVersionState(t)
version = "test-dx-version"
settings := hawkconfig.Settings{
Provider: "openai",
Expand Down Expand Up @@ -47,6 +48,7 @@ func TestDoctorOutputContainsSections(t *testing.T) {
}

func TestDoctorOutputWithMCPServers(t *testing.T) {
preserveCLICompilerVersionState(t)
version = "test-dx-version"
settings := hawkconfig.Settings{
Provider: "anthropic",
Expand Down
6 changes: 5 additions & 1 deletion cmd/errors.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cmd

import (
"context"
"fmt"
"net"
"os"
Expand Down Expand Up @@ -376,7 +377,10 @@ func validateStartup(settings hawkconfig.Settings) []StartupWarning {
var warnings []StartupWarning

// 1. Check API key for configured provider
providerName := hawkconfig.NormalizeProviderForEngine(settings.Provider)
providerName := strings.TrimSpace(settings.Provider)
if providerName == "" {
providerName = strings.TrimSpace(hawkconfig.ActiveProvider(context.Background()))
}
if providerName != "" && providerName != "ollama" {
envKey := hawkconfig.ProviderAPIKeyEnv(providerName)
if envKey != "" && os.Getenv(envKey) == "" {
Expand Down
2 changes: 2 additions & 0 deletions cmd/manpage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
)

func TestGenerateManPage(t *testing.T) {
preserveCLICompilerVersionState(t)
version = "1.0.0"
page := GenerateManPage()

Expand Down Expand Up @@ -48,6 +49,7 @@ func TestGenerateManPage(t *testing.T) {
}

func TestGenerateManPage_EmptyVersion(t *testing.T) {
preserveCLICompilerVersionState(t)
version = ""
page := GenerateManPage()
if !strings.Contains(page, "dev") {
Expand Down
2 changes: 1 addition & 1 deletion cmd/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ var modelsListCmd = &cobra.Command{
if providerName == "" {
return fmt.Errorf("provider required with --live (e.g. hawk models list canopywave --live --json)")
}
models, err = catalog.FetchLiveModelEntriesForProvider(eyriecfg.DiscoveryEnvMap(ctx), hawkconfig.NormalizeProviderForEngine(providerName))
models, err = catalog.FetchLiveModelEntriesForProvider(eyriecfg.DiscoveryEnvMap(ctx), hawkconfig.ActiveProviderID(providerName))
} else {
models, err = hawkconfig.FetchModelsForProvider(providerName)
}
Expand Down
Loading
Loading