diff --git a/cmd/chat.go b/cmd/chat.go index 754c1a29..acc4d707 100644 --- a/cmd/chat.go +++ b/cmd/chat.go @@ -138,7 +138,7 @@ func newChatModel(ref *progRef, systemPrompt string, settings hawkconfig.Setting startup.EndPhase("newChatModel:newHawkSession") startup.MarkPhase("newChatModel:configureSession") - syncSessionFromPersistedSelection(sess, settings) + syncSessionFromPersistedSelection(sess) sess.SetLogger(logger.New(io.Discard, logger.Error)) if cfgErr := configureSession(sess, settings); cfgErr != nil { return chatModel{}, cfgErr diff --git a/cmd/chat_commands_config.go b/cmd/chat_commands_config.go index 7c0e813f..dba9f0fa 100644 --- a/cmd/chat_commands_config.go +++ b/cmd/chat_commands_config.go @@ -17,11 +17,10 @@ func (m *chatModel) handleConfigCommand(parts []string, text string) (tea.Model, m.messages = append(m.messages, displayMsg{role: "error", content: err.Error()}) return m, nil } - engineProvider := hawkconfig.NormalizeProviderForEngine(value) - m.session.SetProvider(engineProvider) + m.syncSessionSelection() // Use cached model or set first from cache modelCacheMu.RLock() - cached, cacheHit := modelCache[engineProvider] + cached, cacheHit := modelCache[m.session.Provider()] modelCacheMu.RUnlock() if cacheHit && len(cached) > 0 { m.session.SetModel(cached[0].ID) @@ -52,8 +51,8 @@ func (m *chatModel) handleConfigCommand(parts []string, text string) (tea.Model, m.messages = append(m.messages, displayMsg{role: "error", content: err.Error()}) return m, nil } - m.session.SetModel(value) - m.messages = append(m.messages, displayMsg{role: "system", content: fmt.Sprintf("Model switched to: %s\nSaved in eyrie (provider.json).", value)}) + m.syncSessionSelection() + m.messages = append(m.messages, displayMsg{role: "system", content: fmt.Sprintf("Model switched to: %s\nSaved in eyrie (provider.json).", m.session.Model())}) return m, nil } if len(parts) >= 2 && parts[1] == "keys" { @@ -95,9 +94,9 @@ func (m *chatModel) handleConfigCommand(parts []string, text string) (tea.Model, normalizedKey := strings.ToLower(strings.ReplaceAll(strings.ReplaceAll(key, "-", ""), "_", "")) switch normalizedKey { case "model": - m.session.SetModel(value) + m.syncSessionSelection() case "provider": - m.session.SetProvider(hawkconfig.NormalizeProviderForEngine(value)) + m.syncSessionSelection() } m.messages = append(m.messages, displayMsg{role: "system", content: fmt.Sprintf("Updated %s = %s", key, value)}) return m, nil diff --git a/cmd/chat_commands_test.go b/cmd/chat_commands_test.go index cf5053c3..e513d7e3 100644 --- a/cmd/chat_commands_test.go +++ b/cmd/chat_commands_test.go @@ -52,6 +52,7 @@ func TestHandleCommandAddDir(t *testing.T) { } func TestLocalSlashCommands(t *testing.T) { + preserveCLICompilerVersionState(t) version = "test-version" sess := engine.NewSession("openai", "gpt-4o", "base", tool.NewRegistry(tool.LSTool{})) m := &chatModel{ @@ -80,6 +81,7 @@ func TestLocalSlashCommands(t *testing.T) { } func TestDiagnosticSummaries(t *testing.T) { + preserveCLICompilerVersionState(t) version = "test-version" settings := hawkconfig.Settings{ Provider: "openai", diff --git a/cmd/chat_config_cache.go b/cmd/chat_config_cache.go deleted file mode 100644 index faa36cf1..00000000 --- a/cmd/chat_config_cache.go +++ /dev/null @@ -1,11 +0,0 @@ -package cmd - -import hawkconfig "github.com/GrayCodeAI/hawk/internal/config" - -func configuredGatewayKeys() map[string]bool { - out := map[string]bool{} - for _, p := range hawkconfig.ConfiguredCredentialProviders() { - out[p] = true - } - return out -} diff --git a/cmd/chat_config_deployment.go b/cmd/chat_config_deployment.go index 3f563a20..8b612272 100644 --- a/cmd/chat_config_deployment.go +++ b/cmd/chat_config_deployment.go @@ -202,7 +202,7 @@ func (m chatModel) handleConfigApplyCredentialsMsg(msg configApplyCredentialsMsg } if msg.providerID == configProviderOllama { _ = hawkconfig.SetGlobalSetting("provider", configProviderOllama) - next.session.SetProvider(hawkconfig.NormalizeProviderForEngine(configProviderOllama)) + next.syncSessionSelection() } next.configGuideAfterKey = false if len(msg.modelOptions) == 0 { @@ -221,10 +221,14 @@ func (m chatModel) handleConfigApplyCredentialsMsg(msg configApplyCredentialsMsg } func (m chatModel) rebuildSessionTransport() (chatModel, tea.Cmd) { - if err := engine.RebuildSessionTransport(context.Background(), m.session, hawkconfig.DeploymentRoutingEnabled(m.settings), m.session.Provider()); err != nil { + selection := runtime.EffectiveSelection(context.Background(), runtime.SelectionOpts{ + ProviderOverride: firstNonEmptyTrimmed(m.session.Provider(), m.settings.Provider), + ModelOverride: firstNonEmptyTrimmed(m.session.Model(), m.settings.Model), + }) + if err := engine.RebuildSessionTransport(context.Background(), m.session, selection, m.session.Provider()); err != nil { m.configNotice = sanitizeConfigNotice(err.Error()) } - syncSessionFromPersistedSelection(m.session, m.settings) + syncSessionFromPersistedSelection(m.session) m.invalidateConnStatus() return m, nil } diff --git a/cmd/chat_config_gateways.go b/cmd/chat_config_gateways.go index 546b1cda..cd1f52aa 100644 --- a/cmd/chat_config_gateways.go +++ b/cmd/chat_config_gateways.go @@ -12,11 +12,13 @@ import ( ) type configGatewayRow struct { - ID string - DisplayName string - HasKey bool - ModelCount int - Active bool + ID string + DisplayName string + HasKey bool + ModelCount int + Active bool + RegionLabel string + RegionRequired bool } type configGatewayRefreshMsg struct { @@ -27,47 +29,52 @@ type configGatewayRefreshMsg struct { func (m chatModel) configGatewayRows() []configGatewayRow { ctx := context.Background() - providers := hawkconfig.AllSetupGateways() - configured := configuredGatewayKeys() active := strings.TrimSpace(m.configModelProvider) + activeModel := "" if active == "" && m.session != nil { active = strings.TrimSpace(m.session.Provider()) } - var rows []configGatewayRow - for _, id := range providers { - if id == "" { + if m.session != nil { + activeModel = strings.TrimSpace(m.session.Model()) + } + statuses := hawkconfig.GatewayStatuses(ctx, active, activeModel) + rows := make([]configGatewayRow, 0, len(statuses)) + for _, status := range statuses { + if status.ID == "" { continue } - count := hawkconfig.CachedModelCountForProvider(id) + count := status.ModelCount if count == 0 { modelCacheMu.RLock() - if cached, ok := modelCache[id]; ok { + if cached, ok := modelCache[status.ID]; ok { count = len(cached) } modelCacheMu.RUnlock() } - hasKey := configured[id] || hawkconfig.HasStoredCredentialForProvider(ctx, id) - display := hawkconfig.GatewayDisplayName(id) - if id == hawkconfig.ProviderXiaomiTokenPlan { - if reg := hawkconfig.XiaomiTokenPlanRegionLabel(); reg != "" { + hasKey := status.HasStoredCredential + display := status.DisplayName + if status.ID == hawkconfig.ProviderXiaomiTokenPlan { + if reg := status.RegionLabel; reg != "" { display += " · " + reg } else { display += " · region required" } } - if id == hawkconfig.ProviderZAICoding { - if reg := hawkconfig.ZAIRegionLabel(id); reg != "" { + if status.ID == hawkconfig.ProviderZAICoding { + if reg := status.RegionLabel; reg != "" { display += " · " + reg } else { display += " · region" } } rows = append(rows, configGatewayRow{ - ID: id, - DisplayName: display, - HasKey: hasKey, - ModelCount: count, - Active: hawkconfig.NormalizeProviderForEngine(id) == hawkconfig.NormalizeProviderForEngine(active), + ID: status.ID, + DisplayName: display, + HasKey: hasKey, + ModelCount: count, + Active: status.Active || hawkconfig.ActiveProviderID(status.ID) == hawkconfig.ActiveProviderID(active), + RegionLabel: status.RegionLabel, + RegionRequired: status.RegionRequired, }) } return rows @@ -117,11 +124,11 @@ func (m chatModel) refreshConfigGateway() (chatModel, tea.Cmd) { } idx := m.configGatewayRefreshTargetIndex(rows) row := rows[idx] - if row.ID == hawkconfig.ProviderXiaomiTokenPlan && hawkconfig.NeedsXiaomiTokenPlanRegion(row.ID) { + if row.ID == hawkconfig.ProviderXiaomiTokenPlan && row.RegionRequired { m.configNotice = "Pick Token Plan region (cn / sgp / ams) before refresh" return m.startConfigXiaomiTokenPlanRegion(), nil } - if row.ID == hawkconfig.ProviderZAICoding && hawkconfig.NeedsZAIRegion(row.ID) { + if row.ID == hawkconfig.ProviderZAICoding && row.RegionRequired { m.configNotice = "Pick Coding Plan region (international / cn) before refresh" return m.startConfigZAIRegion(row.ID), nil } @@ -274,12 +281,12 @@ func (m chatModel) handleConfigGatewaysSelect() (chatModel, tea.Cmd) { } row := rows[m.configSel] if row.ID == hawkconfig.ProviderXiaomiTokenPlan { - if !row.HasKey || hawkconfig.NeedsXiaomiTokenPlanRegion(row.ID) { + if !row.HasKey || row.RegionRequired { m.configGatewayFocus = m.configSel return m.startConfigXiaomiTokenPlanRegion(), nil } } - if row.ID == hawkconfig.ProviderZAICoding && (!row.HasKey || hawkconfig.NeedsZAIRegion(row.ID)) { + if row.ID == hawkconfig.ProviderZAICoding && (!row.HasKey || row.RegionRequired) { m.configGatewayFocus = m.configSel return m.startConfigZAIRegion(row.ID), nil } @@ -295,7 +302,9 @@ func (m chatModel) handleConfigGatewaysSelect() (chatModel, tea.Cmd) { m.configGatewayFocus = m.configSel m.configModelProvider = gw _ = hawkconfig.SetGlobalSetting("provider", gw) - m.session.SetProvider(hawkconfig.NormalizeProviderForEngine(gw)) + if active := hawkconfig.ActiveProvider(context.Background()); active != "" { + m.session.SetProvider(active) + } m.configTab = configTabModels m.configSel = 0 m.configScroll = 0 @@ -343,7 +352,15 @@ func (m chatModel) trackConfigGatewayFocus() chatModel { if m.configTab != configTabGateways { return m } - rows := len(hawkconfig.AllSetupGateways()) + active := strings.TrimSpace(m.configModelProvider) + activeModel := "" + if m.session != nil { + if active == "" { + active = strings.TrimSpace(m.session.Provider()) + } + activeModel = strings.TrimSpace(m.session.Model()) + } + rows := len(hawkconfig.GatewayStatuses(context.Background(), active, activeModel)) if m.configSel >= 0 && m.configSel < rows { m.configGatewayFocus = m.configSel } diff --git a/cmd/chat_config_panel.go b/cmd/chat_config_panel.go index 98ac13b1..e79a3719 100644 --- a/cmd/chat_config_panel.go +++ b/cmd/chat_config_panel.go @@ -697,11 +697,10 @@ func (m chatModel) selectConfigModelFromOptions(opts []configModelOption) (chatM m.session.SetModel(modelID) if gw := strings.TrimSpace(m.configModelProvider); gw != "" { _ = hawkconfig.SetGlobalSetting("provider", gw) - m.session.SetProvider(hawkconfig.NormalizeProviderForEngine(gw)) } else if prov := hawkconfig.ProviderOfModel(modelID); prov != "" { _ = hawkconfig.SetGlobalSetting("provider", prov) - m.session.SetProvider(hawkconfig.NormalizeProviderForEngine(prov)) } + m.syncSessionSelection() next, cmd := m.rebuildSessionTransport() next.invalidateConnStatus() next = next.stopConfigModelSearch(true) diff --git a/cmd/chat_status.go b/cmd/chat_status.go index 7d647c57..95b444ff 100644 --- a/cmd/chat_status.go +++ b/cmd/chat_status.go @@ -5,6 +5,7 @@ import ( "fmt" "strings" + "github.com/GrayCodeAI/eyrie/runtime" "github.com/charmbracelet/lipgloss" hawkconfig "github.com/GrayCodeAI/hawk/internal/config" @@ -76,11 +77,25 @@ func (m chatModel) sessionGatewayModel() (gateway, model string) { } if gateway == "" || model == "" { ctx := context.Background() - if gateway == "" { - gateway = hawkconfig.ActiveGateway(ctx) - } - if model == "" { - model = strings.TrimSpace(hawkconfig.ActiveModel(ctx)) + explicitGateway, explicitModel := explicitSelection(ctx) + if explicitGateway != "" || explicitModel != "" { + if gateway == "" { + gateway = explicitGateway + } + if model == "" { + model = explicitModel + } + } else { + selection := runtime.EffectiveSelection(ctx, runtime.SelectionOpts{ + ProviderOverride: gateway, + ModelOverride: model, + }) + if gateway == "" { + gateway = strings.TrimSpace(selection.Provider) + } + if model == "" { + model = strings.TrimSpace(selection.Model) + } } } return gateway, model diff --git a/cmd/chat_status_test.go b/cmd/chat_status_test.go index 9e0b93e6..d7b3dc72 100644 --- a/cmd/chat_status_test.go +++ b/cmd/chat_status_test.go @@ -163,8 +163,11 @@ func TestChatConnectionStatus_NoGatewayNoModel(t *testing.T) { m := chatModel{session: &engine.Session{}} got := m.chatConnectionStatus() - if got != "pick model" { - t.Fatalf("status = %q", got) + if !strings.Contains(got, "Anthropic · ") { + t.Fatalf("expected runtime-selected anthropic status, got %q", got) + } + if strings.Contains(got, "pick model") { + t.Fatalf("expected auto-selected model when only anthropic credentials exist, got %q", got) } } diff --git a/cmd/chat_subcommand_simple.go b/cmd/chat_subcommand_simple.go index 14b2681c..65a89892 100644 --- a/cmd/chat_subcommand_simple.go +++ b/cmd/chat_subcommand_simple.go @@ -121,10 +121,10 @@ func init() { handler: func(m *chatModel, args []string, text string) (tea.Model, tea.Cmd) { savedModel := hawkconfig.ActiveModel(context.Background()) if m.session.Model() == savedModel { - norm := hawkconfig.NormalizeProviderForEngine(m.session.Provider()) - fastModel := hawkconfig.CheapestModelForProvider(norm, m.session.Model()) + providerName := strings.TrimSpace(m.session.Provider()) + fastModel := hawkconfig.CheapestModelForProvider(providerName, m.session.Model()) if strings.TrimSpace(fastModel) == "" { - fastModel = hawkconfig.DefaultModelForProvider(norm) + fastModel = hawkconfig.DefaultModelForProvider(providerName) } if strings.TrimSpace(fastModel) == "" { m.messages = append(m.messages, displayMsg{role: "error", content: "Fast mode: no catalog model resolved for this provider"}) diff --git a/cmd/completions_test.go b/cmd/completions_test.go index 8e19f897..0a516d6f 100644 --- a/cmd/completions_test.go +++ b/cmd/completions_test.go @@ -418,6 +418,7 @@ func TestAllProvidersInFishCompletion(t *testing.T) { } func TestCompletionJSONCommand(t *testing.T) { + preserveCLICompilerVersionState(t) SetVersion("test-version") buf := new(bytes.Buffer) diff --git a/cmd/dx_test.go b/cmd/dx_test.go index 3d1a0118..63e58125 100644 --- a/cmd/dx_test.go +++ b/cmd/dx_test.go @@ -13,6 +13,7 @@ import ( ) func TestDoctorOutputContainsSections(t *testing.T) { + preserveCLICompilerVersionState(t) version = "test-dx-version" settings := hawkconfig.Settings{ Provider: "openai", @@ -47,6 +48,7 @@ func TestDoctorOutputContainsSections(t *testing.T) { } func TestDoctorOutputWithMCPServers(t *testing.T) { + preserveCLICompilerVersionState(t) version = "test-dx-version" settings := hawkconfig.Settings{ Provider: "anthropic", diff --git a/cmd/errors.go b/cmd/errors.go index 1657b02e..9ca070f3 100644 --- a/cmd/errors.go +++ b/cmd/errors.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "net" "os" @@ -376,7 +377,10 @@ func validateStartup(settings hawkconfig.Settings) []StartupWarning { var warnings []StartupWarning // 1. Check API key for configured provider - providerName := hawkconfig.NormalizeProviderForEngine(settings.Provider) + providerName := strings.TrimSpace(settings.Provider) + if providerName == "" { + providerName = strings.TrimSpace(hawkconfig.ActiveProvider(context.Background())) + } if providerName != "" && providerName != "ollama" { envKey := hawkconfig.ProviderAPIKeyEnv(providerName) if envKey != "" && os.Getenv(envKey) == "" { diff --git a/cmd/manpage_test.go b/cmd/manpage_test.go index ea3f86e4..d2edcc49 100644 --- a/cmd/manpage_test.go +++ b/cmd/manpage_test.go @@ -6,6 +6,7 @@ import ( ) func TestGenerateManPage(t *testing.T) { + preserveCLICompilerVersionState(t) version = "1.0.0" page := GenerateManPage() @@ -48,6 +49,7 @@ func TestGenerateManPage(t *testing.T) { } func TestGenerateManPage_EmptyVersion(t *testing.T) { + preserveCLICompilerVersionState(t) version = "" page := GenerateManPage() if !strings.Contains(page, "dev") { diff --git a/cmd/models.go b/cmd/models.go index 234c9548..b20e61cd 100644 --- a/cmd/models.go +++ b/cmd/models.go @@ -98,7 +98,7 @@ var modelsListCmd = &cobra.Command{ if providerName == "" { return fmt.Errorf("provider required with --live (e.g. hawk models list canopywave --live --json)") } - models, err = catalog.FetchLiveModelEntriesForProvider(eyriecfg.DiscoveryEnvMap(ctx), hawkconfig.NormalizeProviderForEngine(providerName)) + models, err = catalog.FetchLiveModelEntriesForProvider(eyriecfg.DiscoveryEnvMap(ctx), hawkconfig.ActiveProviderID(providerName)) } else { models, err = hawkconfig.FetchModelsForProvider(providerName) } diff --git a/cmd/options.go b/cmd/options.go index 6f1d4bad..d8da3cc3 100644 --- a/cmd/options.go +++ b/cmd/options.go @@ -10,6 +10,7 @@ import ( "strings" "time" + "github.com/GrayCodeAI/eyrie/runtime" hawkconfig "github.com/GrayCodeAI/hawk/internal/config" ctxrepomap "github.com/GrayCodeAI/hawk/internal/context/repomap" "github.com/GrayCodeAI/hawk/internal/engine" @@ -162,48 +163,28 @@ func loadEffectiveSettings() (hawkconfig.Settings, error) { } func effectiveModelAndProvider(settings hawkconfig.Settings) (string, string) { - ctx := context.Background() - hawkconfig.SyncSelectionWithCredentials(ctx) - if !hawkconfig.HasConfiguredDeployment(ctx) { - return "", "" - } - effectiveModel := hawkconfig.ActiveModel(ctx) - if strings.TrimSpace(model) != "" { - effectiveModel = strings.TrimSpace(model) - } - if strings.TrimSpace(settings.Model) != "" { - effectiveModel = strings.TrimSpace(settings.Model) - } - effectiveProvider := hawkconfig.ActiveProvider(ctx) - if strings.TrimSpace(provider) != "" { - effectiveProvider = strings.TrimSpace(provider) - } - if strings.TrimSpace(settings.Provider) != "" { - effectiveProvider = strings.TrimSpace(settings.Provider) - } - // If the configured provider's API key is missing, fall back to auto-detection - // so users with ANTHROPIC_API_KEY don't get confusing errors about canopywave. - normalized := hawkconfig.NormalizeProviderForEngine(effectiveProvider) - if normalized != "" && hawkconfig.APIKeyForProvider(normalized) == "" { - detected := types.DetectProvider() - if detected != "" && hawkconfig.APIKeyForProvider(detected) != "" { - normalized = detected - effectiveModel = "" - } - } - if normalized != "" && strings.TrimSpace(effectiveModel) == "" { - if resolved := hawkconfig.DefaultModelForProvider(normalized); resolved != "" { - effectiveModel = resolved - } - } - if hawkconfig.DeploymentRoutingEnabled(settings) && strings.TrimSpace(effectiveModel) != "" { - effectiveModel = hawkconfig.ResolveCanonicalModel(effectiveModel) - } - return effectiveModel, normalized + selection := runtime.EffectiveSelection(context.Background(), runtime.SelectionOpts{ + ProviderOverride: firstNonEmptyTrimmed(provider, settings.Provider), + ModelOverride: firstNonEmptyTrimmed(model, settings.Model), + }) + return selection.Model, selection.Provider } func newHawkSession(settings hawkconfig.Settings, effectiveProvider, effectiveModel, systemPrompt string, registry *tool.Registry) *engine.Session { - return engine.NewHawkSession(context.Background(), hawkconfig.DeploymentRoutingEnabled(settings), effectiveProvider, effectiveModel, systemPrompt, registry) + selection := runtime.EffectiveSelection(context.Background(), runtime.SelectionOpts{ + ProviderOverride: firstNonEmptyTrimmed(provider, settings.Provider), + ModelOverride: firstNonEmptyTrimmed(model, settings.Model), + }) + return engine.NewHawkSession(context.Background(), selection, effectiveProvider, effectiveModel, systemPrompt, registry) +} + +func firstNonEmptyTrimmed(values ...string) string { + for _, value := range values { + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + } + return "" } func configureSession(sess *engine.Session, settings hawkconfig.Settings, maxTurnsOverride ...int) error { @@ -226,10 +207,9 @@ func configureSession(sess *engine.Session, settings hawkconfig.Settings, maxTur enhancedMem.StartSession(fmt.Sprintf("session_%d", time.Now().UnixNano())) } // Hawk: API keys from OS secret store only - normalizedProvider := hawkconfig.NormalizeProviderForEngine(settings.Provider) - if normalizedProvider != "" { - if key := hawkconfig.APIKeyForProvider(normalizedProvider); key != "" { - sess.SetAPIKey(normalizedProvider, key) + if providerName := strings.TrimSpace(sess.Provider()); providerName != "" { + if key := hawkconfig.APIKeyForProvider(providerName); key != "" { + sess.SetAPIKey(providerName, key) } } sess.SetAPIKeys(hawkconfig.LoadAPIKeysFromStore()) diff --git a/cmd/options_welcome_test.go b/cmd/options_welcome_test.go index 18d02b8f..760da4e5 100644 --- a/cmd/options_welcome_test.go +++ b/cmd/options_welcome_test.go @@ -14,8 +14,10 @@ import ( func isolateCredentialHome(t *testing.T) { t.Helper() home := t.TempDir() - _ = os.MkdirAll(filepath.Join(home, ".hawk"), 0o700) + hawkDir := filepath.Join(home, ".hawk") + _ = os.MkdirAll(hawkDir, 0o700) t.Setenv("HOME", home) + t.Setenv("HAWK_CONFIG_DIR", hawkDir) } func TestEffectiveModelAndProvider_ClearsWithoutCredentials(t *testing.T) { diff --git a/cmd/review_analyze.go b/cmd/review_analyze.go index 1844c235..2ee55f9d 100644 --- a/cmd/review_analyze.go +++ b/cmd/review_analyze.go @@ -8,6 +8,7 @@ import ( "strings" "time" + "github.com/GrayCodeAI/eyrie/runtime" reviewcontracts "github.com/GrayCodeAI/hawk-core-contracts/review" hawkSight "github.com/GrayCodeAI/hawk/internal/bridge/sight" "github.com/GrayCodeAI/hawk/internal/types" @@ -123,12 +124,19 @@ func runReviewAnalyze(_ *cobra.Command, args []string) error { return nil } - // Build sight bridge for analysis. - prov := provider - if prov == "" { - prov = types.DetectProvider() + // Build sight bridge from the runtime-owned transport resolution. + transport, err := runtime.ResolveChatTransport(context.Background(), runtime.ChatTransportOpts{ + Selection: runtime.SelectionOpts{ + ProviderOverride: strings.TrimSpace(provider), + ModelOverride: analyzeModel, + }, + }) + if err != nil { + return fmt.Errorf("resolve runtime transport: %w", err) + } + if transport.Provider == nil { + return fmt.Errorf("runtime transport unavailable for provider %q", transport.Selection.Provider) } - eyrieClient := types.NewClient(&types.ClientConfig{Provider: prov}) var opts []sightLib.Option if analyzeModel != "" { @@ -136,7 +144,7 @@ func runReviewAnalyze(_ *cobra.Command, args []string) error { } opts = append(opts, sightLib.WithConcerns(analysisType)) - bridge := hawkSight.NewBridge(eyrieClient, prov, opts...) + bridge := hawkSight.NewBridge(types.WrapClientProvider(transport.Provider), transport.Selection.Provider, opts...) if !bridge.Ready() { return fmt.Errorf("sight bridge not ready (check API key)") } diff --git a/cmd/review_run.go b/cmd/review_run.go index f9783b21..b9b8ea56 100644 --- a/cmd/review_run.go +++ b/cmd/review_run.go @@ -8,6 +8,7 @@ import ( "strings" "time" + "github.com/GrayCodeAI/eyrie/runtime" reviewcontracts "github.com/GrayCodeAI/hawk-core-contracts/review" hawkSight "github.com/GrayCodeAI/hawk/internal/bridge/sight" "github.com/GrayCodeAI/hawk/internal/types" @@ -85,12 +86,21 @@ func runReviewRun(_ *cobra.Command, args []string) error { return nil } - // Build sight bridge. - prov := provider - if prov == "" { - prov = types.DetectProvider() + // Build sight bridge from the runtime-owned transport resolution. + transport, err := runtime.ResolveChatTransport(context.Background(), runtime.ChatTransportOpts{ + Selection: runtime.SelectionOpts{ + ProviderOverride: strings.TrimSpace(provider), + ModelOverride: reviewRunModel, + }, + }) + if err != nil { + _ = store.SetStatus(id, ReviewStatusFailed) + return silentErr(fmt.Errorf("resolve runtime transport: %w", err), "init bridge") + } + if transport.Provider == nil { + _ = store.SetStatus(id, ReviewStatusFailed) + return silentErr(fmt.Errorf("runtime transport unavailable for provider %q", transport.Selection.Provider), "init bridge") } - eyrieClient := types.NewClient(&types.ClientConfig{Provider: prov}) var opts []sightLib.Option if reviewRunModel != "" { @@ -104,7 +114,7 @@ func runReviewRun(_ *cobra.Command, args []string) error { opts = append(opts, sightLib.WithConcerns(concerns...)) } - bridge := hawkSight.NewBridge(eyrieClient, prov, opts...) + bridge := hawkSight.NewBridge(types.WrapClientProvider(transport.Provider), transport.Selection.Provider, opts...) if !bridge.Ready() { _ = store.SetStatus(id, ReviewStatusFailed) return silentErr(fmt.Errorf("sight bridge not ready"), "init bridge") diff --git a/cmd/review_test.go b/cmd/review_test.go index 42bfc76a..cf250521 100644 --- a/cmd/review_test.go +++ b/cmd/review_test.go @@ -9,11 +9,19 @@ import ( reviewcontracts "github.com/GrayCodeAI/hawk-core-contracts/review" contracts "github.com/GrayCodeAI/hawk-core-contracts/types" + "github.com/GrayCodeAI/hawk/internal/storage" "github.com/GrayCodeAI/hawk/internal/ui/icons" ) +func setReviewTestDirs(t *testing.T) string { + t.Helper() + root := t.TempDir() + storage.SetTestDirs(t, root) + return root +} + func TestReviewStore_CreateAndGet(t *testing.T) { - dir := t.TempDir() + dir := setReviewTestDirs(t) os.MkdirAll(filepath.Join(dir, ".hawk"), 0o755) store, err := OpenReviewStore(dir) @@ -43,7 +51,7 @@ func TestReviewStore_CreateAndGet(t *testing.T) { } func TestReviewStore_Update(t *testing.T) { - dir := t.TempDir() + dir := setReviewTestDirs(t) os.MkdirAll(filepath.Join(dir, ".hawk"), 0o755) store, err := OpenReviewStore(dir) @@ -83,7 +91,7 @@ func TestReviewStore_Update(t *testing.T) { } func TestReviewStore_GetBySHA(t *testing.T) { - dir := t.TempDir() + dir := setReviewTestDirs(t) os.MkdirAll(filepath.Join(dir, ".hawk"), 0o755) store, err := OpenReviewStore(dir) @@ -105,7 +113,7 @@ func TestReviewStore_GetBySHA(t *testing.T) { } func TestReviewStore_ListOpen(t *testing.T) { - dir := t.TempDir() + dir := setReviewTestDirs(t) os.MkdirAll(filepath.Join(dir, ".hawk"), 0o755) store, err := OpenReviewStore(dir) @@ -132,7 +140,7 @@ func TestReviewStore_ListOpen(t *testing.T) { } func TestReviewStore_Summary(t *testing.T) { - dir := t.TempDir() + dir := setReviewTestDirs(t) os.MkdirAll(filepath.Join(dir, ".hawk"), 0o755) store, err := OpenReviewStore(dir) @@ -162,7 +170,7 @@ func TestReviewStore_Summary(t *testing.T) { } func TestReviewStore_SetStatus(t *testing.T) { - dir := t.TempDir() + dir := setReviewTestDirs(t) os.MkdirAll(filepath.Join(dir, ".hawk"), 0o755) store, err := OpenReviewStore(dir) diff --git a/cmd/session_sync.go b/cmd/session_sync.go index b6d2a691..7c1e0479 100644 --- a/cmd/session_sync.go +++ b/cmd/session_sync.go @@ -9,35 +9,38 @@ import ( "github.com/GrayCodeAI/hawk/internal/engine" ) -// syncSessionFromPersistedSelection copies eyrie provider.json selection into the -// live session when the session fields are empty (status bar can show ActiveModel -// while the model field is still unset, which breaks deployment routing). -func syncSessionFromPersistedSelection(sess *engine.Session, settings hawkconfig.Settings) { +func explicitSelection(ctx context.Context) (provider, model string) { + if ctx == nil { + ctx = context.Background() + } + return strings.TrimSpace(hawkconfig.ActiveGateway(ctx)), strings.TrimSpace(hawkconfig.ActiveModel(ctx)) +} + +// syncSessionFromPersistedSelection copies explicit eyrie provider.json +// selection into the live session when the session fields are empty. +// It intentionally avoids runtime defaults so Hawk can preserve the +// "gateway selected, model still missing" setup state. +func syncSessionFromPersistedSelection(sess *engine.Session) { if sess == nil { return } - ctx := context.Background() - hawkconfig.SyncSelectionWithCredentials(ctx) + provider, model := explicitSelection(context.Background()) if strings.TrimSpace(sess.Model()) == "" { - model := strings.TrimSpace(hawkconfig.ActiveModel(ctx)) - if model != "" && hawkconfig.DeploymentRoutingEnabled(settings) { - model = hawkconfig.ResolveCanonicalModel(model) - } if model != "" { sess.SetModel(model) } } if strings.TrimSpace(sess.Provider()) == "" { - if provider := strings.TrimSpace(hawkconfig.ActiveProvider(ctx)); provider != "" { - sess.SetProvider(hawkconfig.NormalizeProviderForEngine(provider)) + if provider != "" { + sess.SetProvider(provider) } } } func (m *chatModel) syncSessionSelection() { - syncSessionFromPersistedSelection(m.session, m.settings) + syncSessionFromPersistedSelection(m.session) if m.session != nil { gw, model := m.sessionGatewayModel() applyLiveModelMetadata(m.session, gw, model) diff --git a/cmd/session_sync_test.go b/cmd/session_sync_test.go index c08499cb..c09659dd 100644 --- a/cmd/session_sync_test.go +++ b/cmd/session_sync_test.go @@ -27,8 +27,7 @@ func TestSyncSessionFromPersistedSelection_FillsEmptySessionModel(t *testing.T) _ = hawkconfig.SetActiveModel(ctx, "moonshotai/kimi-k2.6") sess := engine.NewSession("", "", "test", nil) - settings := hawkconfig.Settings{} - syncSessionFromPersistedSelection(sess, settings) + syncSessionFromPersistedSelection(sess) if got := sess.Model(); got != "moonshotai/kimi-k2.6" { t.Fatalf("model = %q, want moonshotai/kimi-k2.6", got) diff --git a/cmd/version_state_test.go b/cmd/version_state_test.go new file mode 100644 index 00000000..b7ba2037 --- /dev/null +++ b/cmd/version_state_test.go @@ -0,0 +1,27 @@ +package cmd + +import "testing" + +func preserveCLICompilerVersionState(t *testing.T) { + t.Helper() + + oldVersion := version + oldBuildDate := buildDate + t.Cleanup(func() { + version = oldVersion + buildDate = oldBuildDate + }) +} + +func preserveLibraryVersionState(t *testing.T) { + t.Helper() + + oldVersion := Version + oldCommit := Commit + oldDate := Date + t.Cleanup(func() { + Version = oldVersion + Commit = oldCommit + Date = oldDate + }) +} diff --git a/cmd/version_test.go b/cmd/version_test.go index 880f570f..e87d1ce1 100644 --- a/cmd/version_test.go +++ b/cmd/version_test.go @@ -3,6 +3,7 @@ package cmd import "testing" func TestSetVersion(t *testing.T) { + preserveCLICompilerVersionState(t) SetVersion("1.2.3") if version != "1.2.3" { t.Errorf("version = %q, want %q", version, "1.2.3") @@ -10,6 +11,7 @@ func TestSetVersion(t *testing.T) { } func TestSetBuildDate(t *testing.T) { + preserveCLICompilerVersionState(t) SetBuildDate("2026-01-01") if buildDate != "2026-01-01" { t.Errorf("buildDate = %q, want %q", buildDate, "2026-01-01") @@ -17,6 +19,7 @@ func TestSetBuildDate(t *testing.T) { } func TestVersionString(t *testing.T) { + preserveLibraryVersionState(t) Version = "test-ver" Commit = "abc123" Date = "2026-05-15" @@ -27,6 +30,7 @@ func TestVersionString(t *testing.T) { } func TestShortVersion(t *testing.T) { + preserveLibraryVersionState(t) Version = "0.1.0" got := ShortVersion() if got != "0.1.0" { diff --git a/external/eyrie b/external/eyrie index 7546bb6c..967ab850 160000 --- a/external/eyrie +++ b/external/eyrie @@ -1 +1 @@ -Subproject commit 7546bb6c282132402a9d66d8e6174dca98331f12 +Subproject commit 967ab850444727e9dd6c73f1be15ebd7b4320610 diff --git a/internal/bridge/sight/bridge.go b/internal/bridge/sight/bridge.go index 44c638f5..a85667a7 100644 --- a/internal/bridge/sight/bridge.go +++ b/internal/bridge/sight/bridge.go @@ -12,13 +12,13 @@ import ( // EyrieAdapter implements sight's Provider interface using hawk's eyrie client. // It translates between sight.Message/sight.ChatOpts and Hawk runtime DTOs. type EyrieAdapter struct { - client *types.EyrieClient + client types.ChatProvider provider string } // NewEyrieAdapter creates an adapter that satisfies sight.Provider using // the given eyrie client and provider name (e.g. "anthropic", "openai"). -func NewEyrieAdapter(c *types.EyrieClient, provider string) *EyrieAdapter { +func NewEyrieAdapter(c types.ChatProvider, provider string) *EyrieAdapter { return &EyrieAdapter{client: c, provider: provider} } @@ -77,13 +77,13 @@ type Bridge struct { // NewBridge creates a bridge to the sight library using the given Hawk // transport client and provider name. Additional sight options (model, // concerns, etc.) are applied to all operations. -func NewBridge(c *types.EyrieClient, provider string, opts ...sightLib.Option) *Bridge { +func NewBridge(c types.ChatProvider, provider string, opts ...sightLib.Option) *Bridge { b := &Bridge{} b.init(c, provider, opts...) return b } -func (b *Bridge) init(c *types.EyrieClient, provider string, opts ...sightLib.Option) { +func (b *Bridge) init(c types.ChatProvider, provider string, opts ...sightLib.Option) { if c == nil { return } diff --git a/internal/config/catalog_api.go b/internal/config/catalog_api.go index 489689d5..e9127506 100644 --- a/internal/config/catalog_api.go +++ b/internal/config/catalog_api.go @@ -6,10 +6,11 @@ import ( "strings" "github.com/GrayCodeAI/eyrie/catalog" - "github.com/GrayCodeAI/eyrie/catalog/registry" "github.com/GrayCodeAI/eyrie/runtime" ) +type GatewayStatus = runtime.GatewayStatus + // CompiledCatalogV1 loads the eyrie catalog from cache or bootstrap wiring (no network). func CompiledCatalogV1() *catalog.CompiledCatalogV1 { return compiledCatalogOrBootstrap() @@ -42,7 +43,7 @@ func AllCatalogProviders() []string { seen := map[string]bool{} var out []string for _, id := range catalog.ProviderIDsFromCompiled(compiled) { - p := catalogProviderID(id) + p := runtime.CatalogProviderID(id) if p == "" || seen[p] { continue } @@ -56,59 +57,21 @@ func AllCatalogProviders() []string { // AllSetupGateways returns gateway IDs where users paste API keys (eyrie registry only). // Aggregator owner slugs from OpenRouter/CanopyWave catalogs (ai21, alibaba, …) are excluded. func AllSetupGateways() []string { - specs := registry.CredentialRegistry() - out := make([]string, len(specs)) - for i, s := range specs { - out[i] = s.ProviderID - } - return out -} - -// setupGatewayRegistryID maps catalog/engine aliases to credential registry gateway ids. -// Most registry IDs use underscores for multi-word plans (e.g. xiaomi_mimo_token_plan). -// Z.AI uses underscore naming for uniformity with Xiaomi/MiniMax plan splits: zai_payg and zai_coding (no legacy aliases). -func setupGatewayRegistryID(provider string) string { - p := strings.ToLower(strings.TrimSpace(provider)) - switch p { - case "google": - return "gemini" - case "xai": - return "grok" - case "zai_payg": - return "zai_payg" - case "zai_coding": - return "zai_coding" - case "xiaomi_mimo", "xiaomi-mimo": - return "xiaomi_mimo_payg" - case "xiaomi_mimo_token_plan", "xiaomi-mimo-token-plan": - return "xiaomi_mimo_token_plan" - case "xiaomi_mimo_payg", "xiaomi-mimo-payg": - return "xiaomi_mimo_payg" - default: - return p - } + return runtime.SetupGateways() } // SetupGatewayCredentialEnv returns the registry env var for a setup gateway (e.g. XIAOMI_MIMO_PAYG_API_KEY). func SetupGatewayCredentialEnv(providerID string) string { - spec, ok := registry.DefaultRegistry.Get(setupGatewayRegistryID(providerID)) - if !ok || !spec.RequiresKey { - return "" - } - return strings.TrimSpace(spec.CredentialEnv) + return runtime.SetupGatewayCredentialEnv(providerID) } // IsSetupGateway reports whether id is a registered setup gateway. func IsSetupGateway(providerID string) bool { - return catalog.IsSetupGateway(setupGatewayRegistryID(providerID)) + return runtime.IsSetupGateway(providerID) } func GatewayDisplayName(gatewayID string) string { - gatewayID = setupGatewayRegistryID(gatewayID) - if name := registry.DisplayName(gatewayID); name != gatewayID { - return name - } - return gatewayID + return runtime.GatewayDisplayName(gatewayID) } // ActiveGateway returns the user's setup gateway (never an aggregator owner slug like moonshotai). @@ -116,15 +79,17 @@ func ActiveGateway(ctx context.Context) string { if ctx == nil { ctx = context.Background() } - if p := catalogProviderID(ActiveProvider(ctx)); catalog.IsSetupGateway(p) { - return setupGatewayRegistryID(p) - } - if m := strings.TrimSpace(ActiveModel(ctx)); m != "" { - if gw := GatewayForModel(m); gw != "" { - return setupGatewayRegistryID(gw) - } + return runtime.ActiveGateway(ctx) +} + +func GatewayStatuses(ctx context.Context, activeProvider, activeModel string) []GatewayStatus { + if ctx == nil { + ctx = context.Background() } - return "" + return runtime.GatewayStatuses(ctx, runtime.GatewayStatusOpts{ + ActiveProvider: activeProvider, + ActiveModel: activeModel, + }) } // GatewayForModel resolves the setup gateway for a model id. @@ -134,20 +99,7 @@ func GatewayForModel(modelID string) string { // ShouldClearSelectionAfterCredentialRemove reports whether provider/model should reset. func ShouldClearSelectionAfterCredentialRemove(ctx context.Context, removedProvider string) bool { - if ctx == nil { - ctx = context.Background() - } - removedProvider = catalogProviderID(removedProvider) - if !HasConfiguredDeployment(ctx) { - return true - } - if gw := ActiveGateway(ctx); gw == removedProvider { - return true - } - if m := strings.TrimSpace(ActiveModel(ctx)); m != "" && GatewayForModel(m) == removedProvider { - return true - } - return false + return runtime.ShouldClearSelectionAfterCredentialRemove(ctx, removedProvider) } // ClearActiveSelection removes persisted provider/model from provider.json. @@ -163,63 +115,16 @@ func SyncSelectionWithCredentials(ctx context.Context) { if ctx == nil { ctx = context.Background() } - if !HasConfiguredDeployment(ctx) { - if HasSelectedModel() || strings.TrimSpace(ActiveProvider(ctx)) != "" { - _ = ClearActiveSelection(ctx) - } - return - } - gw := ActiveGateway(ctx) - if gw == "" { - return - } - if !credentialConfiguredForGateway(ctx, gw) { - _ = ClearActiveSelection(ctx) - } -} - -func credentialConfiguredForGateway(ctx context.Context, gateway string) bool { - ensureCredSnapshot(ctx) - uiCacheMu.RLock() - defer uiCacheMu.RUnlock() - if !credValid { - return false - } - gateway = setupGatewayRegistryID(gateway) - return credConfigured[gateway] + runtime.SyncSelectionWithCredentials(ctx) } func DefaultModelForProvider(provider string) string { - compiled := CompiledCatalogV1() - if compiled != nil { - if id := catalog.FirstModelForProvider(compiled, provider); id != "" { - return id - } - } - // All providers are fully dynamic — try live API if credentials are available. - if APIKeyForProvider(provider) != "" { - models, err := runtime.ListModels(context.Background(), runtime.ListModelsOpts{ - ProviderID: provider, - Source: runtime.ListSourceAuto, - }) - if err == nil && len(models) > 0 { - return models[0].ID - } - } - return "" + return runtime.DefaultModelForProvider(context.Background(), provider) } // CachedModelCountForProvider returns model count from the on-disk catalog only (no network). func CachedModelCountForProvider(provider string) int { - provider = setupGatewayRegistryID(provider) - if provider == "" { - return 0 - } - compiled := CompiledCatalogV1() - if compiled == nil { - return 0 - } - return len(catalog.ModelEntriesForProvider(compiled, provider)) + return runtime.CachedModelCountForProvider(context.Background(), provider) } func ModelIDsForProvider(provider string) ([]string, error) { @@ -262,7 +167,7 @@ func ProviderOfModel(modelName string) string { } if canonical, ok := compiled.CanonicalModelForAliasOrID(modelName); ok { if model := compiled.ModelsByID[canonical]; model.ID != "" { - return catalogProviderID(model.ProviderID) + return runtime.CatalogProviderID(model.ProviderID) } } return "" @@ -323,7 +228,7 @@ func ProviderIDForDeployment(deploymentID string) string { if !ok { return "" } - return catalogProviderID(dep.ProviderID) + return runtime.CatalogProviderID(dep.ProviderID) } // PrimaryAPIKeyEnvForDeployment returns the env var name for a deployment's API key. @@ -334,27 +239,3 @@ func PrimaryAPIKeyEnvForDeployment(deploymentID string) string { } return catalog.PrimaryAPIKeyEnvForDeployment(compiled, deploymentID) } - -// ConfigProviderList returns provider names for the /config UI from catalog + custom providers. -func ConfigProviderList(custom []CustomProviderConfig) []string { - seen := map[string]bool{} - var out []string - for _, p := range AllCatalogProviders() { - engine := NormalizeProviderForEngine(p) - if engine == "" || seen[engine] { - continue - } - seen[engine] = true - out = append(out, engine) - } - for _, cp := range custom { - name := strings.TrimSpace(cp.Name) - if name == "" || seen[name] { - continue - } - seen[name] = true - out = append(out, name) - } - sort.Strings(out) - return out -} diff --git a/internal/config/catalog_gateways_test.go b/internal/config/catalog_gateways_test.go index 51fc00b7..ecd4dec3 100644 --- a/internal/config/catalog_gateways_test.go +++ b/internal/config/catalog_gateways_test.go @@ -50,19 +50,25 @@ func containsString(list []string, s string) bool { return false } -func TestSetupGatewayRegistryID_PreservesUnderscores(t *testing.T) { - if got := setupGatewayRegistryID("xiaomi_mimo_payg"); got != "xiaomi_mimo_payg" { +func TestSetupGatewayID_PreservesUnderscores(t *testing.T) { + if got := ActiveProviderID("xiaomi_mimo_payg"); got != "xiaomi_mimo_payg" { t.Fatalf("xiaomi_mimo_payg = %q", got) } - if got := setupGatewayRegistryID("xiaomi_mimo"); got != "xiaomi_mimo_payg" { + if got := ActiveProviderID("xiaomi_mimo"); got != "xiaomi_mimo_payg" { t.Fatalf("legacy xiaomi_mimo = %q", got) } - if got := setupGatewayRegistryID("zai_payg"); got != "zai_payg" { + if got := ActiveProviderID("zai_payg"); got != "zai_payg" { t.Fatalf("zai_payg = %q", got) } - if got := setupGatewayRegistryID("zai_coding"); got != "zai_coding" { + if got := ActiveProviderID("z-ai-payg"); got != "zai_payg" { + t.Fatalf("z-ai-payg = %q", got) + } + if got := ActiveProviderID("zai_coding"); got != "zai_coding" { t.Fatalf("zai_coding = %q", got) } + if got := ActiveProviderID("z-ai-coding"); got != "zai_coding" { + t.Fatalf("z-ai-coding = %q", got) + } } func TestCredentialInferenceForProvider_XiaomiPayg(t *testing.T) { diff --git a/internal/config/catalog_health.go b/internal/config/catalog_health.go index fc549d0b..7f467786 100644 --- a/internal/config/catalog_health.go +++ b/internal/config/catalog_health.go @@ -3,12 +3,12 @@ package config import ( "context" "fmt" - "os" "strings" "sync" "time" "github.com/GrayCodeAI/eyrie/catalog" + "github.com/GrayCodeAI/hawk/internal/env" ) var ( @@ -147,7 +147,7 @@ func EnsureCatalogAvailable(ctx context.Context) error { // CatalogCachePathForDisplay returns the path users should care about. func CatalogCachePathForDisplay() string { - if p := strings.TrimSpace(os.Getenv("EYRIE_MODEL_CATALOG_PATH")); p != "" { + if p := strings.TrimSpace(env.Getenv("EYRIE_MODEL_CATALOG_PATH")); p != "" { return p } return catalog.DefaultCachePath() diff --git a/internal/config/catalog_startup.go b/internal/config/catalog_startup.go index eec77d12..2c9462e5 100644 --- a/internal/config/catalog_startup.go +++ b/internal/config/catalog_startup.go @@ -10,6 +10,7 @@ import ( "time" "github.com/GrayCodeAI/eyrie/credentials" + "github.com/GrayCodeAI/hawk/internal/env" ) type gatewayModelCount struct { @@ -222,7 +223,7 @@ func catalogRefreshFailureHint(ctx context.Context) string { } func autoRefreshCatalogEnabled() bool { - switch strings.ToLower(strings.TrimSpace(os.Getenv("HAWK_AUTO_REFRESH_CATALOG"))) { + switch strings.ToLower(strings.TrimSpace(env.Getenv("HAWK_AUTO_REFRESH_CATALOG"))) { case "0", "false", "no", "off": return false default: @@ -231,7 +232,7 @@ func autoRefreshCatalogEnabled() bool { } func catalogRefreshAlways() bool { - switch strings.ToLower(strings.TrimSpace(os.Getenv("HAWK_CATALOG_REFRESH_ALWAYS"))) { + switch strings.ToLower(strings.TrimSpace(env.Getenv("HAWK_CATALOG_REFRESH_ALWAYS"))) { case "1", "true", "yes", "on": return true default: diff --git a/internal/config/credentials_store.go b/internal/config/credentials_store.go index 34127a2f..4b006775 100644 --- a/internal/config/credentials_store.go +++ b/internal/config/credentials_store.go @@ -33,7 +33,7 @@ func PrepareCredentialDiscovery(ctx context.Context) { if ctx == nil { ctx = context.Background() } - ApplyXiaomiTokenPlanRegionEnv(ctx) + runtime.PrepareCredentialDiscovery(ctx) } // ModelOption is one hawk /config model row. @@ -116,15 +116,7 @@ func SaveCredential(ctx context.Context, inference CredentialInference, secret s // HasStoredCredentialForProvider reports whether the OS secret store has a key for this gateway. func HasStoredCredentialForProvider(ctx context.Context, providerID string) bool { - if ctx == nil { - ctx = context.Background() - } - for _, envKey := range credentialEnvKeysForTarget(providerID) { - if credentials.HasSecret(ctx, envKey) { - return true - } - } - return false + return runtime.HasStoredCredential(ctx, providerID) } // ConfiguredCredentialProviders returns setup gateways with a stored API key. @@ -212,7 +204,7 @@ func maskCredentialSecret(secret string) string { // CredentialInferenceForProvider returns save metadata for a gateway chosen in /config. func CredentialInferenceForProvider(providerID string) (CredentialInference, error) { - providerID = setupGatewayRegistryID(providerID) + providerID = runtime.SetupGatewayID(providerID) inf, err := runtime.InferenceForProvider(providerID) if err != nil { return CredentialInference{}, err @@ -229,30 +221,7 @@ func credentialEnvKeysForTarget(target string) []string { if strings.Contains(target, "_") && strings.ToUpper(target) == target { return []string{strings.TrimSpace(target)} } - provider := setupGatewayRegistryID(target) - seen := map[string]struct{}{} - var keys []string - add := func(k string) { - k = strings.TrimSpace(k) - if k == "" { - return - } - if _, ok := seen[k]; ok { - return - } - seen[k] = struct{}{} - keys = append(keys, k) - } - if env := SetupGatewayCredentialEnv(provider); env != "" { - add(env) - } - if primary := ProviderAPIKeyEnv(provider); primary != "" { - add(primary) - } - for _, alt := range providerCredentialEnvAliases(provider) { - add(alt) - } - return keys + return runtime.CredentialEnvKeys(target) } // LocalCredentialInference returns setup metadata for no-key providers (e.g. Ollama). diff --git a/internal/config/deployment.go b/internal/config/deployment.go index 165d68ea..04ed40a9 100644 --- a/internal/config/deployment.go +++ b/internal/config/deployment.go @@ -1,26 +1,12 @@ package config import ( - "os" - "strings" + "context" - eyriecfg "github.com/GrayCodeAI/eyrie/config" - "github.com/GrayCodeAI/eyrie/setup" + "github.com/GrayCodeAI/eyrie/runtime" ) -// DeploymentRoutingEnabled decides whether hawk uses catalog-backed deployment routing -// (same rules as eyrie CLI). HAWK_DEPLOYMENT_ROUTING overrides; otherwise settings flag, -// otherwise provider.json shape via eyrie/setup. +// DeploymentRoutingEnabled delegates deployment-routing policy ownership to Eyrie runtime. func DeploymentRoutingEnabled(s Settings) bool { - switch strings.ToLower(strings.TrimSpace(os.Getenv("HAWK_DEPLOYMENT_ROUTING"))) { - case "1", "true", "yes", "on": - return true - case "0", "false", "no", "off": - return false - } - if s.DeploymentRouting != nil { - return *s.DeploymentRouting - } - cfg := eyriecfg.LoadProviderConfig("") - return setup.UseDeploymentRouting(cfg) + return runtime.DeploymentRoutingEnabled(context.Background(), s.DeploymentRouting) } diff --git a/internal/config/deployment_status.go b/internal/config/deployment_status.go index a1110557..89ffbc55 100644 --- a/internal/config/deployment_status.go +++ b/internal/config/deployment_status.go @@ -6,40 +6,26 @@ import ( "strings" eyriecfg "github.com/GrayCodeAI/eyrie/config" - "github.com/GrayCodeAI/eyrie/setup" + "github.com/GrayCodeAI/eyrie/runtime" ) // ResolveCanonicalModel maps aliases and native IDs to catalog canonical model IDs. func ResolveCanonicalModel(model string) string { - model = strings.TrimSpace(model) - if model == "" { - return "" - } - compiled, err := loadEyrieCatalogV1(context.Background(), false) - if err != nil || compiled == nil { - return model - } - if canonical, ok := compiled.CanonicalModelForAliasOrID(model); ok { - return canonical - } - if strings.Contains(model, "/") { - return model - } - return model + return runtime.ResolveCanonicalModel(context.Background(), strings.TrimSpace(model)) } // DeploymentStatusReport returns hawk deployment routing diagnostics. func DeploymentStatusReport(ctx context.Context, activeModel string) (string, error) { - report, err := setup.DeploymentStatus(ctx, activeModel) + report, err := runtime.DeploymentStatus(ctx, activeModel) if err != nil { return "", err } - return setup.FormatStatus(report), nil + return runtime.FormatDeploymentStatus(report), nil } // RoutingPreviewJSON returns effective routing for a model (eyrie routing JSON preview). func RoutingPreviewJSON(ctx context.Context, model string) (string, error) { - return setup.RoutingPreview(ctx, model) + return runtime.RoutingPreview(ctx, model) } // MigrateProviderConfig upgrades ~/.hawk/provider.json to deployment v2 in place. diff --git a/internal/config/eyrie_apply.go b/internal/config/eyrie_apply.go index f52981ae..44debcdb 100644 --- a/internal/config/eyrie_apply.go +++ b/internal/config/eyrie_apply.go @@ -2,11 +2,10 @@ package config import ( "context" - "fmt" "time" - "github.com/GrayCodeAI/eyrie/catalog" eyriecfg "github.com/GrayCodeAI/eyrie/config" + "github.com/GrayCodeAI/eyrie/runtime" "github.com/GrayCodeAI/eyrie/setup" ) @@ -14,14 +13,7 @@ import ( func ApplyEyrieCredentialsForProvider(ctx context.Context, providerID string) (*setup.ApplyCredentialsResult, error) { ctx, cancel := context.WithTimeout(ctx, 90*time.Second) defer cancel() - PrepareCredentialDiscovery(ctx) - if providerID == ProviderXiaomiTokenPlan { - ApplyXiaomiTokenPlanRegionEnv(ctx) - } - if providerID == ProviderZAICoding { - ApplyZAIRegionEnv(ctx) - } - result, err := setup.ApplyCredentialsForProvider(ctx, providerID, eyriecfg.DiscoveryCredentials(ctx)) + result, err := runtime.ApplyCredentialsForProvider(ctx, providerID) if err != nil { return nil, err } @@ -46,33 +38,9 @@ func ApplyEyrieCredentials(ctx context.Context) (*setup.ApplyCredentialsResult, func RefreshGatewayCatalog(ctx context.Context, providerID string) (string, error) { ctx, cancel := context.WithTimeout(ctx, 90*time.Second) defer cancel() - PrepareCredentialDiscovery(ctx) - if providerID == ProviderXiaomiTokenPlan { - ApplyXiaomiTokenPlanRegionEnv(ctx) - } - if providerID == ProviderZAICoding { - ApplyZAIRegionEnv(ctx) - } - result, err := setup.DiscoverProviderCatalog(ctx, providerID, eyriecfg.DiscoveryCredentials(ctx)) - if err != nil { - return "", err - } - n := 0 - if result.Compiled != nil { - n = len(catalog.ModelEntriesForProvider(result.Compiled, providerID)) - } - return fmt.Sprintf("Refreshed %s (%d models)", providerID, n), nil + return runtime.RefreshGatewayCatalog(ctx, providerID) } func FormatApplyCredentialsSummary(result *setup.ApplyCredentialsResult) string { - if result == nil || result.Catalog == nil || result.Catalog.Compiled == nil { - return "Eyrie credentials applied" - } - nModels := len(result.Catalog.Compiled.ModelsByID) - nDeps := 0 - if result.ProviderConfig != nil { - nDeps = len(result.ProviderConfig.Deployments) - } - return fmt.Sprintf("Eyrie: %d models, %d deployments configured, routing updated → %s", - nModels, nDeps, result.ProviderConfigPath) + return runtime.FormatApplyCredentialsSummary(result) } diff --git a/internal/config/eyrie_selection.go b/internal/config/eyrie_selection.go index 3a12ecc1..420dbb89 100644 --- a/internal/config/eyrie_selection.go +++ b/internal/config/eyrie_selection.go @@ -20,7 +20,12 @@ func ActiveProvider(ctx context.Context) string { if ctx == nil { ctx = context.Background() } - return runtime.ActiveProvider(ctx) + return runtime.ActiveProviderID(runtime.ActiveProvider(ctx)) +} + +// ActiveProviderID canonicalizes a host-facing provider/gateway id through Eyrie runtime. +func ActiveProviderID(provider string) string { + return runtime.ActiveProviderID(provider) } // SetActiveModel persists model selection to eyrie provider.json. @@ -36,7 +41,7 @@ func SetActiveProvider(ctx context.Context, provider string) error { if ctx == nil { ctx = context.Background() } - return runtime.SetActiveProvider(ctx, provider) + return runtime.SetActiveProvider(ctx, runtime.ActiveProviderID(provider)) } // migrateLegacyModelProvider moves model/provider from ~/.hawk/settings.json into eyrie once. diff --git a/internal/config/settings.go b/internal/config/settings.go index 00bbdfe4..9cd5e569 100644 --- a/internal/config/settings.go +++ b/internal/config/settings.go @@ -491,7 +491,7 @@ func ProviderAPIKeyEnv(provider string) string { if compiled == nil { return "" } - return catalog.PrimaryAPIKeyEnvForProvider(compiled, catalogProviderID(provider)) + return catalog.PrimaryAPIKeyEnvForProvider(compiled, runtime.CatalogProviderID(provider)) } // EnvKeyStatus returns set, empty, or local from the OS credential store. @@ -500,7 +500,7 @@ func EnvKeyStatus(provider string) string { if compiled == nil { return "empty" } - provider = catalogProviderID(provider) + provider = runtime.CatalogProviderID(provider) envs := catalog.APIKeyEnvsForProvider(compiled, provider) if len(envs) == 0 { return "local" @@ -546,14 +546,14 @@ func APIKeyForProvider(provider string) string { if compiled == nil { return "" } - provider = catalogProviderID(provider) + provider = runtime.CatalogProviderID(provider) ctx := context.Background() for _, env := range catalog.APIKeyEnvsForProvider(compiled, provider) { if v := credentials.LookupSecret(ctx, env); v != "" { return v } } - for _, env := range providerCredentialEnvAliases(provider) { + for _, env := range runtime.CredentialEnvKeys(provider) { if v := credentials.LookupSecret(ctx, env); v != "" { return v } @@ -562,33 +562,18 @@ func APIKeyForProvider(provider string) string { } func providerCredentialEnvAliases(provider string) []string { - switch strings.ToLower(provider) { - case "anthropic": - return []string{"CLAUDE_API_KEY"} - case "gemini", "google": - return []string{"GOOGLE_API_KEY"} - case "grok", "xai": - return nil - case "xiaomi_mimo", "xiaomi_mimo_payg": - return []string{"XIAOMI_MIMO_PAYG_API_KEY"} - default: - return nil - } -} - -// NormalizeProviderForEngine maps hawk provider aliases to eyrie canonical names. -// This is the boundary where hawk names become engine/eyrie names. -func NormalizeProviderForEngine(provider string) string { - if gw := setupGatewayRegistryID(provider); catalog.IsSetupGateway(gw) { - return gw - } - p := normalizeProviderName(provider) - switch p { - case "xai": - return "grok" // eyrie calls it "grok", env var is XAI_API_KEY - default: - return p + primary := strings.TrimSpace(ProviderAPIKeyEnv(provider)) + seen := map[string]bool{} + var out []string + for _, env := range runtime.CredentialEnvKeys(provider) { + env = strings.TrimSpace(env) + if env == "" || env == primary || seen[env] { + continue + } + seen[env] = true + out = append(out, env) } + return out } // ───────────────────────────────────────────────────────────── @@ -598,7 +583,7 @@ func NormalizeProviderForEngine(provider string) string { // FetchModelsForProvider returns models from the eyrie catalog (dynamic; no hawk hardcoded lists). // RefreshModelCatalogV1 is the explicit network refresh boundary. func FetchModelsForProvider(provider string) ([]catalog.ModelCatalogEntry, error) { - provider = catalogProviderID(provider) + provider = runtime.CatalogProviderID(provider) if provider == "" { return nil, fmt.Errorf("no provider specified") } @@ -615,7 +600,7 @@ func FetchModelsForProvider(provider string) ([]catalog.ModelCatalogEntry, error } // Custom OpenAI-compatible providers: single model from settings, not hawk catalog data. for _, cp := range LoadSettings().CustomProviders { - if NormalizeProviderForEngine(cp.Name) != provider { + if runtime.ActiveProviderID(cp.Name) != provider { continue } if id := strings.TrimSpace(cp.Model); id != "" { @@ -659,17 +644,3 @@ func loadEyrieCatalogV1(ctx context.Context, refreshRemote bool) (*catalog.Compi RequireCache: false, }) } - -func catalogProviderID(provider string) string { - if gw := setupGatewayRegistryID(provider); catalog.IsSetupGateway(gw) { - return gw - } - switch NormalizeProviderForEngine(provider) { - case "gemini": - return "google" - case "grok": - return "xai" - default: - return NormalizeProviderForEngine(provider) - } -} diff --git a/internal/config/settings_test.go b/internal/config/settings_test.go index 96f37a15..85c84ad3 100644 --- a/internal/config/settings_test.go +++ b/internal/config/settings_test.go @@ -181,7 +181,7 @@ func TestNormalizeProviderName(t *testing.T) { } } -func TestNormalizeProviderForEngine_XiaomiAliases(t *testing.T) { +func TestActiveProviderID_XiaomiAliases(t *testing.T) { t.Parallel() tests := []struct { input string @@ -196,9 +196,9 @@ func TestNormalizeProviderForEngine_XiaomiAliases(t *testing.T) { } for _, tt := range tests { - got := NormalizeProviderForEngine(tt.input) + got := ActiveProviderID(tt.input) if got != tt.expected { - t.Errorf("NormalizeProviderForEngine(%q) = %q, want %q", tt.input, got, tt.expected) + t.Errorf("ActiveProviderID(%q) = %q, want %q", tt.input, got, tt.expected) } } } @@ -279,7 +279,7 @@ func TestSettings_UnmarshalJSON_SnakeCase(t *testing.T) { } } -func TestNormalizeProviderForEngine(t *testing.T) { +func TestActiveProviderID(t *testing.T) { t.Parallel() tests := []struct { input string @@ -289,12 +289,16 @@ func TestNormalizeProviderForEngine(t *testing.T) { {"openai", "openai"}, {"xai", "grok"}, {"XAI", "grok"}, + {"zai_coding", "zai_coding"}, + {"z-ai-coding", "zai_coding"}, + {"zai_payg", "zai_payg"}, + {"z-ai-payg", "zai_payg"}, } for _, tt := range tests { - got := NormalizeProviderForEngine(tt.input) + got := ActiveProviderID(tt.input) if got != tt.expected { - t.Errorf("NormalizeProviderForEngine(%q) = %q, want %q", tt.input, got, tt.expected) + t.Errorf("ActiveProviderID(%q) = %q, want %q", tt.input, got, tt.expected) } } } diff --git a/internal/config/setup_status.go b/internal/config/setup_status.go index fdcd5352..6df7ac28 100644 --- a/internal/config/setup_status.go +++ b/internal/config/setup_status.go @@ -4,7 +4,7 @@ import ( "context" "strings" - eyriecfg "github.com/GrayCodeAI/eyrie/config" + "github.com/GrayCodeAI/eyrie/runtime" ) // SetupState is a single evaluation of first-run /config requirements. @@ -21,7 +21,7 @@ func EvaluateSetup(ctx context.Context) SetupState { ctx = context.Background() } PrepareCredentialDiscovery(ctx) - return evaluateSetupFrom(hasConfiguredDeployment(ctx), HasSelectedModel()) + return evaluateSetupFrom(runtime.HasConfiguredDeployment(ctx), HasSelectedModel()) } // EvaluateSetupCached uses the in-memory credential snapshot (fast; for TUI hot paths). @@ -40,7 +40,8 @@ func evaluateSetupFrom(hasCreds, hasModel bool) SetupState { } switch { case !hasCreds: - // Splash uses footer "Press Enter to set up and start" only — no duplicate line here. + // Splash uses footer guidance only until credentials exist. + st.Hint = "" case !hasModel: st.Hint = "Almost ready: /config → finish setup" } @@ -52,24 +53,7 @@ func HasConfiguredDeployment(ctx context.Context) bool { if ctx == nil { ctx = context.Background() } - PrepareCredentialDiscovery(ctx) - return hasConfiguredDeployment(ctx) -} - -func hasConfiguredDeployment(ctx context.Context) bool { - RefreshConfigCredSnapshot(ctx) - if hasConfiguredDeploymentCached(ctx) { - return true - } - env := eyriecfg.DiscoveryEnvMap(ctx) - if len(env) > 0 { - for _, v := range env { - if strings.TrimSpace(v) != "" { - return true - } - } - } - return false + return runtime.HasConfiguredDeployment(ctx) } // HasSelectedModel reports whether eyrie provider.json has a selected model. diff --git a/internal/config/xiaomi_setup.go b/internal/config/xiaomi_setup.go index 3819c0b4..4b93ba99 100644 --- a/internal/config/xiaomi_setup.go +++ b/internal/config/xiaomi_setup.go @@ -2,76 +2,28 @@ package config import ( "context" - "os" - "strings" - "github.com/GrayCodeAI/eyrie/catalog/xiaomi" - eyriecfg "github.com/GrayCodeAI/eyrie/config" + "github.com/GrayCodeAI/eyrie/runtime" ) const ProviderXiaomiTokenPlan = "xiaomi_mimo_token_plan" // NeedsXiaomiTokenPlanRegion reports whether the Token Plan gateway still needs a cluster pick. func NeedsXiaomiTokenPlanRegion(providerID string) bool { - if strings.TrimSpace(providerID) != ProviderXiaomiTokenPlan { - return false - } - cfg := eyriecfg.LoadProviderConfig("") - if cfg == nil { - return true - } - _, err := xiaomi.NormalizeRegion(cfg.XiaomiMimoTokenPlanRegion) - return err != nil + return runtime.GatewayNeedsRegion(providerID) } // SetXiaomiTokenPlanRegion persists region (cn, sgp, ams) and syncs env for probe/discovery. func SetXiaomiTokenPlanRegion(region string) error { - normalized, err := xiaomi.NormalizeRegion(region) - if err != nil { - return err - } - cfg := eyriecfg.LoadProviderConfig("") - if cfg == nil { - cfg = &eyriecfg.ProviderConfig{} - } - cfg.XiaomiMimoTokenPlanRegion = string(normalized) - if saveErr := eyriecfg.SaveProviderConfig(cfg, ""); saveErr != nil { - return saveErr - } - _ = os.Setenv(eyriecfg.EnvXiaomiTokenPlanRegion, string(normalized)) - base, err := eyriecfg.ResolveXiaomiOpenAIBase(ProviderXiaomiTokenPlan, cfg) - if err == nil && base != "" { - _ = os.Setenv(eyriecfg.EnvXiaomiTokenPlanBaseURL, base) - cfg.XiaomiMimoTokenPlanBaseURL = base - _ = eyriecfg.SaveProviderConfig(cfg, "") - } - return nil + return runtime.SetGatewayRegion(ProviderXiaomiTokenPlan, region) } // XiaomiTokenPlanRegionLabel returns the saved cluster id for UI (cn, sgp, ams) or "" if unset. func XiaomiTokenPlanRegionLabel() string { - cfg := eyriecfg.LoadProviderConfig("") - if cfg == nil { - return "" - } - r, err := xiaomi.NormalizeRegion(cfg.XiaomiMimoTokenPlanRegion) - if err != nil { - return "" - } - return string(r) + return runtime.GatewayRegionLabel(ProviderXiaomiTokenPlan) } // ApplyXiaomiTokenPlanRegionEnv sets process env from provider.json before credential probe. func ApplyXiaomiTokenPlanRegionEnv(ctx context.Context) { - _ = ctx - cfg := eyriecfg.LoadProviderConfig("") - if cfg == nil { - return - } - if r := strings.TrimSpace(cfg.XiaomiMimoTokenPlanRegion); r != "" { - _ = os.Setenv(eyriecfg.EnvXiaomiTokenPlanRegion, r) - } - if base, err := eyriecfg.ResolveXiaomiOpenAIBase(ProviderXiaomiTokenPlan, cfg); err == nil && base != "" { - _ = os.Setenv(eyriecfg.EnvXiaomiTokenPlanBaseURL, base) - } + runtime.ApplyGatewayEnv(ctx, ProviderXiaomiTokenPlan) } diff --git a/internal/config/zai_setup.go b/internal/config/zai_setup.go index ae8a4b23..4d19793f 100644 --- a/internal/config/zai_setup.go +++ b/internal/config/zai_setup.go @@ -2,11 +2,8 @@ package config import ( "context" - "os" - "strings" - "github.com/GrayCodeAI/eyrie/catalog/zai" - eyriecfg "github.com/GrayCodeAI/eyrie/config" + "github.com/GrayCodeAI/eyrie/runtime" ) const ( @@ -16,107 +13,21 @@ const ( // NeedsZAIRegion reports whether the Z.AI gateway still needs a region pick for the chosen plan. func NeedsZAIRegion(providerID string) bool { - p := strings.TrimSpace(providerID) - if p != ProviderZAICoding { - return false - } - cfg := eyriecfg.LoadProviderConfig("") - if cfg == nil { - return true - } - region := zaiRegionFromConfig(cfg, p) - _, err := zai.NormalizeRegion(region) - return err != nil -} - -func zaiRegionFromConfig(cfg *eyriecfg.ProviderConfig, providerID string) string { - if cfg == nil { - return "" - } - if providerID == ProviderZAICoding { - return cfg.ZAICodingRegion - } - return cfg.ZAIRegion + return runtime.GatewayNeedsRegion(providerID) } // SetZAIRegion persists the region (international or cn) for the given Z.AI gateway and syncs env + derived base. func SetZAIRegion(providerID, region string) error { - normalized, err := zai.NormalizeRegion(region) - if err != nil { - return err - } - - cfg := eyriecfg.LoadProviderConfig("") - if cfg == nil { - cfg = &eyriecfg.ProviderConfig{} - } - - if providerID == ProviderZAICoding { - cfg.ZAICodingRegion = string(normalized) - } else { - cfg.ZAIRegion = string(normalized) - } - - if saveErr := eyriecfg.SaveProviderConfig(cfg, ""); saveErr != nil { - return saveErr - } - - _ = os.Setenv("ZAI_REGION", string(normalized)) - - plan, _ := zai.PlanForProvider(providerID) - base, err := zai.ResolveOpenAIBase(plan, normalized, "") - if err == nil && base != "" { - if providerID == ProviderZAICoding { - _ = os.Setenv("ZAI_CODING_BASE_URL", base) - cfg.ZAICodingBaseURL = base - } else { - _ = os.Setenv("ZAI_BASE_URL", base) - cfg.ZAIBaseURL = base - } - _ = eyriecfg.SaveProviderConfig(cfg, "") - } - return nil + return runtime.SetGatewayRegion(providerID, region) } // ZAIRegionLabel returns the saved region label or "". func ZAIRegionLabel(providerID string) string { - cfg := eyriecfg.LoadProviderConfig("") - if cfg == nil { - return "" - } - r := zaiRegionFromConfig(cfg, providerID) - norm, err := zai.NormalizeRegion(r) - if err != nil { - return "" - } - return string(norm) + return runtime.GatewayRegionLabel(providerID) } // ApplyZAIRegionEnv sets process envs from provider.json before probe/fetch/chat. func ApplyZAIRegionEnv(ctx context.Context) { - _ = ctx - cfg := eyriecfg.LoadProviderConfig("") - if cfg == nil { - return - } - - // General - if r := strings.TrimSpace(cfg.ZAIRegion); r != "" { - _ = os.Setenv("ZAI_REGION", r) - plan := zai.PlanGeneral - norm, _ := zai.NormalizeRegion(r) - if base, err := zai.ResolveOpenAIBase(plan, norm, cfg.ZAIBaseURL); err == nil && base != "" { - _ = os.Setenv("ZAI_BASE_URL", base) - } - } - - // Coding Plan - if r := strings.TrimSpace(cfg.ZAICodingRegion); r != "" { - _ = os.Setenv("ZAI_CODING_REGION", r) - plan := zai.PlanCoding - norm, _ := zai.NormalizeRegion(r) - if base, err := zai.ResolveOpenAIBase(plan, norm, cfg.ZAICodingBaseURL); err == nil && base != "" { - _ = os.Setenv("ZAI_CODING_BASE_URL", base) - } - } + runtime.ApplyGatewayEnv(ctx, ProviderZAIPayg) + runtime.ApplyGatewayEnv(ctx, ProviderZAICoding) } diff --git a/internal/engine/session_factory.go b/internal/engine/session_factory.go index 9029135e..afc19013 100644 --- a/internal/engine/session_factory.go +++ b/internal/engine/session_factory.go @@ -2,40 +2,60 @@ package engine import ( "context" - "fmt" + "errors" + "strings" - eyriecfg "github.com/GrayCodeAI/eyrie/config" - "github.com/GrayCodeAI/eyrie/setup" + "github.com/GrayCodeAI/eyrie/runtime" "github.com/GrayCodeAI/hawk/internal/tool" "github.com/GrayCodeAI/hawk/internal/types" ) +func transportBoolPtr(v bool) *bool { return &v } + // BuildChatClient returns an LLM client and whether deployment routing is active. -func BuildChatClient(ctx context.Context, useDeploymentRouting bool, legacyProvider string) (ChatClient, string, bool) { - cfg := eyriecfg.LoadProviderConfig("") - if useDeploymentRouting { - p, err := setup.DeploymentProvider(ctx, cfg) - if err == nil { - return NewProviderChatClient(types.WrapClientProvider(p)), legacyProvider, true +func BuildChatClient(ctx context.Context, selection runtime.SelectionState, legacyProvider string) (ChatClient, string, bool) { + provider := strings.TrimSpace(selection.Provider) + if provider == "" { + provider = legacyProvider + } + transport, err := runtime.ResolveChatTransport(ctx, runtime.ChatTransportOpts{ + Selection: runtime.SelectionOpts{ + ProviderOverride: provider, + ModelOverride: selection.Model, + // Preserve the engine-resolved routing decision. Transport + // construction itself remains in Eyrie. + DeploymentRoutingOverride: transportBoolPtr(selection.DeploymentRouting), + }, + }) + if err == nil && transport.Provider != nil { + label := strings.TrimSpace(transport.Selection.Provider) + if label == "" { + label = provider } + return NewProviderChatClient(types.WrapClientProvider(transport.Provider)), label, transport.Selection.DeploymentRouting } - c := types.NewClient(&types.ClientConfig{Provider: legacyProvider}) - return c, legacyProvider, false + return NewProviderChatClient(types.NewLazyChatProvider(&types.ClientConfig{ + Provider: provider, + })), provider, false } -// NewHawkSession constructs a Session using deployment routing when configured. -func NewHawkSession(ctx context.Context, useDeploymentRouting bool, provider, model, systemPrompt string, registry *tool.Registry) *Session { - chat, label, deploy := BuildChatClient(ctx, useDeploymentRouting, provider) - return NewSessionWithClient(chat, label, model, systemPrompt, registry, deploy) +// NewHawkSession constructs a Session using an engine-resolved selection. +func NewHawkSession(ctx context.Context, selection runtime.SelectionState, provider, model, systemPrompt string, registry *tool.Registry) *Session { + chat, label, deploy := BuildChatClient(ctx, selection, provider) + resolvedModel := strings.TrimSpace(selection.Model) + if resolvedModel == "" { + resolvedModel = model + } + return NewSessionWithClient(chat, label, resolvedModel, systemPrompt, registry, deploy) } -// RebuildSessionTransport rebuilds the LLM client from current settings and provider.json. -func RebuildSessionTransport(ctx context.Context, s *Session, useDeploymentRouting bool, legacyProvider string) error { +// RebuildSessionTransport rebuilds the LLM client from the engine-resolved selection. +func RebuildSessionTransport(ctx context.Context, s *Session, selection runtime.SelectionState, legacyProvider string) error { if s == nil { - return fmt.Errorf("session is nil") + return errors.New("session is nil") } - chat, label, deploy := BuildChatClient(ctx, useDeploymentRouting, legacyProvider) + chat, label, deploy := BuildChatClient(ctx, selection, legacyProvider) s.ReattachTransport(chat, label, deploy) return nil } diff --git a/internal/engine/session_factory_test.go b/internal/engine/session_factory_test.go new file mode 100644 index 00000000..73c8fca7 --- /dev/null +++ b/internal/engine/session_factory_test.go @@ -0,0 +1,36 @@ +package engine + +import ( + "context" + "testing" + + "github.com/GrayCodeAI/eyrie/runtime" +) + +func TestNewHawkSession_UsesResolvedSelectionModel(t *testing.T) { + selection := runtime.SelectionState{ + Provider: "openrouter", + Model: "openrouter/auto", + DeploymentRouting: false, + } + + sess := NewHawkSession(context.Background(), selection, "openrouter", "", "system", nil) + if got := sess.Provider(); got != "openrouter" { + t.Fatalf("provider = %q, want openrouter", got) + } + if got := sess.Model(); got != "openrouter/auto" { + t.Fatalf("model = %q, want openrouter/auto", got) + } +} + +func TestNewHawkSession_FallsBackToCallerModelWhenSelectionEmpty(t *testing.T) { + selection := runtime.SelectionState{ + Provider: "openrouter", + DeploymentRouting: false, + } + + sess := NewHawkSession(context.Background(), selection, "openrouter", "openrouter/fallback", "system", nil) + if got := sess.Model(); got != "openrouter/fallback" { + t.Fatalf("model = %q, want openrouter/fallback", got) + } +} diff --git a/internal/multiagent/worker.go b/internal/multiagent/worker.go index 7b019ece..b8d68d9c 100644 --- a/internal/multiagent/worker.go +++ b/internal/multiagent/worker.go @@ -7,7 +7,7 @@ import ( "os/exec" "strings" - hawkconfig "github.com/GrayCodeAI/hawk/internal/config" + "github.com/GrayCodeAI/eyrie/runtime" "github.com/GrayCodeAI/hawk/internal/engine" "github.com/GrayCodeAI/hawk/internal/tool" ) @@ -37,8 +37,11 @@ func EngineWorker(provider, model, systemPrompt string) WorkerFunc { // Create engine session with tools registry := tool.NewRegistry(baseWorkerTools()...) - settings := hawkconfig.LoadSettings() - sess := engine.NewHawkSession(ctx, hawkconfig.DeploymentRoutingEnabled(settings), provider, model, systemPrompt, registry) + selection := runtime.EffectiveSelection(ctx, runtime.SelectionOpts{ + ProviderOverride: provider, + ModelOverride: model, + }) + sess := engine.NewHawkSession(ctx, selection, provider, model, systemPrompt, registry) // Configure for autonomous operation level := engine.AutonomyLevel(cfg.AutonomyLevel) @@ -146,8 +149,11 @@ func ReadOnlyValidationWorker(provider, model, systemPrompt string) WorkerFunc { ) registry := tool.NewRegistry(readOnlyWorkerTools()...) - settings := hawkconfig.LoadSettings() - sess := engine.NewHawkSession(ctx, hawkconfig.DeploymentRoutingEnabled(settings), provider, model, systemPrompt, registry) + selection := runtime.EffectiveSelection(ctx, runtime.SelectionOpts{ + ProviderOverride: provider, + ModelOverride: model, + }) + sess := engine.NewHawkSession(ctx, selection, provider, model, systemPrompt, registry) level := engine.AutonomyLevel(cfg.AutonomyLevel) if level < engine.AutonomyFull { diff --git a/internal/types/client.go b/internal/types/client.go index 6498ed09..751645de 100644 --- a/internal/types/client.go +++ b/internal/types/client.go @@ -29,6 +29,13 @@ type ChatProvider interface { Name() string } +// ChatClient is the session-level agent-loop client interface. +type ChatClient interface { + Chat(ctx context.Context, messages []EyrieMessage, opts ChatOptions) (*EyrieResponse, error) + StreamChatContinue(ctx context.Context, messages []EyrieMessage, opts ChatOptions, cfg ContinuationConfig) (*StreamResult, error) + SetAPIKey(provider, apiKey string) +} + // ResponseFormat specifies the desired output format for a Hawk runtime request. type ResponseFormat struct { Type string `json:"type"` @@ -220,6 +227,14 @@ func WrapClientProvider(p client.Provider) ChatProvider { return &providerAdapter{inner: p} } +// NewLazyChatProvider builds a lazy Eyrie provider behind Hawk's transport seam. +func NewLazyChatProvider(cfg *ClientConfig) ChatProvider { + if cfg == nil { + return nil + } + return WrapClientProvider(client.NewLazyProvider(ToClientConfig(cfg))) +} + func StreamChatWithContinuation(ctx context.Context, p ChatProvider, messages []EyrieMessage, opts ChatOptions, cfg ContinuationConfig) (*StreamResult, error) { if p == nil { return nil, nil @@ -302,6 +317,12 @@ func (p *providerAdapter) Name() string { return p.inner.Name() } +func (p *providerAdapter) SetAPIKey(provider, apiKey string) { + if setter, ok := p.inner.(interface{ SetAPIKey(string, string) }); ok { + setter.SetAPIKey(provider, apiKey) + } +} + // ToClientConfig converts Hawk-owned transport config into the provider-runtime shape. func ToClientConfig(cfg *ClientConfig) *client.EyrieConfig { if cfg == nil {