diff --git a/client/lazy_provider.go b/client/lazy_provider.go new file mode 100644 index 0000000..c2cc4bb --- /dev/null +++ b/client/lazy_provider.go @@ -0,0 +1,50 @@ +package client + +import "context" + +// LazyProvider adapts EyrieClient to the Provider interface without eagerly +// resolving credentials or constructing a concrete provider. +type LazyProvider struct { + client *EyrieClient + provider string +} + +// NewLazyProvider creates a provider wrapper that resolves the concrete +// provider only when chat or ping operations are invoked. +func NewLazyProvider(cfg *EyrieConfig) *LazyProvider { + c := Client(cfg) + provider := c.defaultProvider + if cfg != nil && cfg.Provider != "" { + provider = cfg.Provider + } + return &LazyProvider{ + client: c, + provider: provider, + } +} + +func (p *LazyProvider) Chat(ctx context.Context, messages []EyrieMessage, opts ChatOptions) (*EyrieResponse, error) { + if opts.Provider == "" { + opts.Provider = p.provider + } + return p.client.Chat(ctx, messages, opts) +} + +func (p *LazyProvider) StreamChat(ctx context.Context, messages []EyrieMessage, opts ChatOptions) (*StreamResult, error) { + if opts.Provider == "" { + opts.Provider = p.provider + } + return p.client.StreamChat(ctx, messages, opts) +} + +func (p *LazyProvider) Ping(ctx context.Context) error { + return p.client.Ping(ctx, p.provider) +} + +func (p *LazyProvider) Name() string { + return p.provider +} + +func (p *LazyProvider) SetAPIKey(provider, apiKey string) { + p.client.SetAPIKey(provider, apiKey) +} diff --git a/runtime/default_provider.go b/runtime/default_provider.go index 96a692c..7926614 100644 --- a/runtime/default_provider.go +++ b/runtime/default_provider.go @@ -2,12 +2,38 @@ package runtime import ( "context" + "os" "sort" + "strings" + "time" "github.com/GrayCodeAI/eyrie/catalog" "github.com/GrayCodeAI/eyrie/config" + "github.com/GrayCodeAI/eyrie/credentials" ) +var chatProviderPreferenceOrder = []string{ + "openai", + "anthropic", + 
"openrouter", + "grok", + "gemini", + "vertex", + "bedrock", + "zai_coding", + "zai_payg", + "canopywave", + "deepseek", + "azure", + "opencodego", + "kimi", + "xiaomi_mimo_payg", + "xiaomi_mimo_token_plan", + "minimax_token_plan", + "minimax_payg", + "ollama", +} + // DefaultModelProviderFilter returns the catalog provider id to use when listing models // with no explicit provider (e.g. /config model picker after paste-key). // Order: provider.json default → first configured deployment (stable sort by id). @@ -35,3 +61,143 @@ func DefaultModelProviderFilter(ctx context.Context) string { } return "" } + +// PreferredProvider returns the runtime-owned provider choice when a host has +// not pinned one explicitly. Active selection wins first, then inferred model +// ownership, then configured providers ordered by runtime preference, and +// finally credential detection as a last resort. +func PreferredProvider(ctx context.Context) string { + if ctx == nil { + ctx = context.Background() + } + if provider := normalizeRuntimeProviderID(ActiveProvider(ctx)); provider != "" && providerConfigured(ctx, provider) { + return provider + } + if model := ActiveModel(ctx); model != "" { + if provider := inferProviderForModel(ctx, model); provider != "" && providerConfigured(ctx, provider) { + return provider + } + } + if provider := preferredConfiguredProvider(ctx); provider != "" { + return provider + } + return preferredDetectedProvider() +} + +func preferredConfiguredProvider(ctx context.Context) string { + rt, err := Load(ctx) + if err != nil || rt == nil { + return "" + } + rows, err := rt.DeploymentRows() + if err != nil || len(rows) == 0 { + return "" + } + configured := make(map[string]struct{}, len(rows)) + for _, row := range rows { + if !row.Configured { + continue + } + if provider := catalog.CanonicalProviderID(row.ProviderID); provider != "" { + configured[provider] = struct{}{} + } + } + for _, provider := range chatProviderPreferenceOrder { + if _, ok := 
configured[provider]; ok { + return provider + } + } + + ordered := make([]string, 0, len(configured)) + for provider := range configured { + ordered = append(ordered, provider) + } + sort.Strings(ordered) + if len(ordered) == 0 { + return "" + } + return ordered[0] +} + +func preferredDetectedProvider() string { + for _, provider := range chatProviderPreferenceOrder { + switch provider { + case "ollama": + if runtimeEnvValue("OLLAMA_BASE_URL") != "" { + return provider + } + default: + profile, ok := runtimeProfileForProvider(provider) + if !ok { + continue + } + ready := true + for _, envKey := range profile.DetectionEnv { + if runtimeEnvValue(envKey) == "" { + ready = false + break + } + } + if ready { + return provider + } + } + } + return "" +} + +func runtimeProfileForProvider(provider string) (config.RuntimeProviderProfile, bool) { + switch provider { + case "anthropic": + return config.AnthropicRuntimeProfile, true + case "openai": + return config.OpenAIRuntimeProfile, true + case "openrouter": + return config.OpenRouterRuntimeProfile, true + case "grok": + return config.GrokRuntimeProfile, true + case "gemini": + return config.GeminiRuntimeProfile, true + case "vertex": + return config.VertexRuntimeProfile, true + case "bedrock": + return config.BedrockRuntimeProfile, true + case "zai_coding": + return config.ZAICodingRuntimeProfile, true + case "zai_payg": + return config.ZAIPaygRuntimeProfile, true + case "canopywave": + return config.CanopyWaveRuntimeProfile, true + case "deepseek": + return config.DeepSeekRuntimeProfile, true + case "azure": + return config.AzureRuntimeProfile, true + case "opencodego": + return config.OpenCodeGoRuntimeProfile, true + case "kimi": + return config.KimiRuntimeProfile, true + case "xiaomi_mimo_payg": + return config.XiaomiPaygRuntimeProfile, true + case "xiaomi_mimo_token_plan": + return config.XiaomiTokenPlanRuntimeProfile, true + case "minimax_token_plan": + return config.MiniMaxTokenPlanRuntimeProfile, true + case 
"minimax_payg": + return config.MiniMaxPaygRuntimeProfile, true + default: + return config.RuntimeProviderProfile{}, false + } +} + +func runtimeEnvValue(key string) string { + key = strings.TrimSpace(key) + if key == "" { + return "" + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if value := credentials.LookupSecret(ctx, key); value != "" { + return value + } + return strings.TrimSpace(os.Getenv(key)) +} diff --git a/runtime/gateways.go b/runtime/gateways.go new file mode 100644 index 0000000..848e4ad --- /dev/null +++ b/runtime/gateways.go @@ -0,0 +1,415 @@ +package runtime + +import ( + "context" + "fmt" + "os" + "strings" + + "github.com/GrayCodeAI/eyrie/catalog" + "github.com/GrayCodeAI/eyrie/catalog/registry" + "github.com/GrayCodeAI/eyrie/catalog/xiaomi" + "github.com/GrayCodeAI/eyrie/catalog/zai" + "github.com/GrayCodeAI/eyrie/config" + "github.com/GrayCodeAI/eyrie/credentials" + "github.com/GrayCodeAI/eyrie/setup" +) + +const ( + GatewayXiaomiTokenPlan = "xiaomi_mimo_token_plan" + gatewayZAIPayg = "zai_payg" + gatewayZAICoding = "zai_coding" +) + +// GatewayStatusOpts controls runtime-owned gateway summaries for host setup UIs. +type GatewayStatusOpts struct { + ActiveProvider string + ActiveModel string +} + +// GatewayStatus is one setup-gateway row for host UIs. +type GatewayStatus struct { + ID string `json:"id"` + DisplayName string `json:"display_name"` + HasStoredCredential bool `json:"has_stored_credential"` + HasConfiguredDeployment bool `json:"has_configured_deployment"` + ModelCount int `json:"model_count"` + Active bool `json:"active"` + RegionLabel string `json:"region_label,omitempty"` + RegionRequired bool `json:"region_required,omitempty"` +} + +// SetupGateways returns the registry-backed gateway ids users configure in setup UIs. 
+func SetupGateways() []string { + specs := registry.CredentialRegistry() + out := make([]string, 0, len(specs)) + for _, spec := range specs { + if spec.ProviderID != "" { + out = append(out, spec.ProviderID) + } + } + return out +} + +// SetupGatewayID canonicalizes a host-facing setup-gateway id through runtime rules. +func SetupGatewayID(provider string) string { + provider = normalizeRuntimeProviderID(provider) + if catalog.IsSetupGateway(provider) { + return provider + } + return provider +} + +// CatalogProviderID canonicalizes a host-facing provider into the catalog-facing id. +// Setup gateways stay on their registry ids; provider aliases like gemini/grok map to +// their catalog owners google/xai. +func CatalogProviderID(provider string) string { + if gw := SetupGatewayID(provider); catalog.IsSetupGateway(gw) { + return gw + } + switch normalizeRuntimeProviderID(provider) { + case "gemini": + return "google" + case "grok": + return "xai" + default: + return normalizeRuntimeProviderID(provider) + } +} + +// SetupGatewayCredentialEnv returns the primary credential env var for a setup gateway. +func SetupGatewayCredentialEnv(providerID string) string { + spec, ok := registry.SpecByProviderID(SetupGatewayID(providerID)) + if !ok || !spec.RequiresKey { + return "" + } + return strings.TrimSpace(spec.CredentialEnv) +} + +// IsSetupGateway reports whether the provider id resolves to a registered setup gateway. +func IsSetupGateway(providerID string) bool { + return catalog.IsSetupGateway(SetupGatewayID(providerID)) +} + +// GatewayDisplayName returns the registry display label for a setup gateway. +func GatewayDisplayName(providerID string) string { + providerID = normalizeRuntimeProviderID(providerID) + if name := registry.DisplayName(providerID); name != providerID { + return name + } + return providerID +} + +// ActiveGateway returns the selected setup gateway derived from active provider/model state. 
+func ActiveGateway(ctx context.Context) string { + if ctx == nil { + ctx = context.Background() + } + return activeGateway(ctx) +} + +// GatewayStatuses returns runtime-owned setup-gateway summaries for host /config UIs. +func GatewayStatuses(ctx context.Context, opts GatewayStatusOpts) []GatewayStatus { + if ctx == nil { + ctx = context.Background() + } + active := normalizedGatewaySelection(ctx, opts) + rt, _ := Load(ctx) + + statuses := make([]GatewayStatus, 0) + for _, providerID := range SetupGateways() { + count := 0 + if rt != nil && rt.Catalog != nil { + count = len(catalog.ModelEntriesForProvider(rt.Catalog, providerID)) + } + statuses = append(statuses, GatewayStatus{ + ID: providerID, + DisplayName: GatewayDisplayName(providerID), + HasStoredCredential: HasStoredCredential(ctx, providerID), + HasConfiguredDeployment: providerConfigured(ctx, providerID), + ModelCount: count, + Active: providerID == active, + RegionLabel: GatewayRegionLabel(providerID), + RegionRequired: GatewayNeedsRegion(providerID), + }) + } + return statuses +} + +func normalizedGatewaySelection(ctx context.Context, opts GatewayStatusOpts) string { + if provider := normalizeRuntimeProviderID(opts.ActiveProvider); provider != "" { + return provider + } + if model := strings.TrimSpace(opts.ActiveModel); model != "" { + if gateway := inferProviderForModel(ctx, model); gateway != "" { + return gateway + } + } + return activeGateway(ctx) +} + +// CachedModelCountForProvider returns the on-disk catalog model count for a gateway. +func CachedModelCountForProvider(ctx context.Context, provider string) int { + provider = normalizeRuntimeProviderID(provider) + if provider == "" { + return 0 + } + rt, err := Load(ctx) + if err != nil || rt == nil || rt.Catalog == nil { + return 0 + } + return len(catalog.ModelEntriesForProvider(rt.Catalog, provider)) +} + +// ShouldClearSelectionAfterCredentialRemove reports whether removing a gateway key invalidates the active selection. 
+func ShouldClearSelectionAfterCredentialRemove(ctx context.Context, removedProvider string) bool { + if ctx == nil { + ctx = context.Background() + } + removedProvider = normalizeRuntimeProviderID(removedProvider) + if !HasConfiguredDeployment(ctx) { + return true + } + if gw := activeGateway(ctx); gw == removedProvider { + return true + } + if m := strings.TrimSpace(ActiveModel(ctx)); m != "" && inferProviderForModel(ctx, m) == removedProvider { + return true + } + return false +} + +// DeploymentRoutingEnabled resolves runtime routing policy plus an optional host override. +func DeploymentRoutingEnabled(ctx context.Context, override *bool) bool { + _ = ctx + cfg := config.LoadProviderConfig("") + return useDeploymentRouting(cfg, override) +} + +// HasStoredCredential reports whether the OS secret store has a usable secret for a gateway. +func HasStoredCredential(ctx context.Context, providerID string) bool { + if ctx == nil { + ctx = context.Background() + } + for _, envKey := range gatewayCredentialEnvKeys(providerID) { + if credentials.HasSecret(ctx, envKey) { + return true + } + } + return false +} + +// CredentialEnvKeys returns the credential env vars associated with a provider, +// including registry fallbacks and compatibility aliases. 
+func CredentialEnvKeys(providerID string) []string { + return gatewayCredentialEnvKeys(providerID) +} + +func gatewayCredentialEnvKeys(providerID string) []string { + providerID = normalizeRuntimeProviderID(providerID) + spec, ok := registry.SpecByProviderID(providerID) + if !ok { + return nil + } + seen := map[string]bool{} + var out []string + add := func(value string) { + value = strings.TrimSpace(value) + if value == "" || seen[value] { + return + } + seen[value] = true + out = append(out, value) + } + add(spec.CredentialEnv) + for _, env := range spec.CredentialEnvFallbacks { + add(env) + } + for _, env := range providerCredentialAliases(providerID) { + add(env) + } + return out +} + +func providerCredentialAliases(providerID string) []string { + switch providerID { + case "anthropic": + return []string{"CLAUDE_API_KEY"} + case "gemini": + return []string{"GOOGLE_API_KEY"} + case "xiaomi_mimo_payg": + return []string{"XIAOMI_MIMO_API_KEY"} + default: + return nil + } +} + +// PrepareCredentialDiscovery applies runtime-owned gateway env derivations before probe/discovery. +func PrepareCredentialDiscovery(ctx context.Context) { + ApplyGatewayEnv(ctx, GatewayXiaomiTokenPlan) + ApplyGatewayEnv(ctx, gatewayZAIPayg) + ApplyGatewayEnv(ctx, gatewayZAICoding) +} + +// ApplyGatewayEnv applies derived env settings from provider.json for gateways that need them. 
+func ApplyGatewayEnv(ctx context.Context, providerID string) { + _ = ctx + cfg := config.LoadProviderConfig("") + if cfg == nil { + return + } + switch normalizeRuntimeProviderID(providerID) { + case GatewayXiaomiTokenPlan: + if region := strings.TrimSpace(cfg.XiaomiMimoTokenPlanRegion); region != "" { + _ = os.Setenv(config.EnvXiaomiTokenPlanRegion, region) + } + if base, err := config.ResolveXiaomiOpenAIBase(GatewayXiaomiTokenPlan, cfg); err == nil && base != "" { + _ = os.Setenv(config.EnvXiaomiTokenPlanBaseURL, base) + } + case gatewayZAIPayg: + if region := strings.TrimSpace(cfg.ZAIRegion); region != "" { + _ = os.Setenv("ZAI_REGION", region) + norm, _ := zai.NormalizeRegion(region) + if base, err := zai.ResolveOpenAIBase(zai.PlanGeneral, norm, cfg.ZAIBaseURL); err == nil && base != "" { + _ = os.Setenv("ZAI_BASE_URL", base) + } + } + case gatewayZAICoding: + if region := strings.TrimSpace(cfg.ZAICodingRegion); region != "" { + _ = os.Setenv("ZAI_CODING_REGION", region) + norm, _ := zai.NormalizeRegion(region) + if base, err := zai.ResolveOpenAIBase(zai.PlanCoding, norm, cfg.ZAICodingBaseURL); err == nil && base != "" { + _ = os.Setenv("ZAI_CODING_BASE_URL", base) + } + } + } +} + +// GatewayNeedsRegion reports whether a gateway still requires a region selection. +func GatewayNeedsRegion(providerID string) bool { + cfg := config.LoadProviderConfig("") + switch normalizeRuntimeProviderID(providerID) { + case GatewayXiaomiTokenPlan: + if cfg == nil { + return true + } + _, err := xiaomi.NormalizeRegion(cfg.XiaomiMimoTokenPlanRegion) + return err != nil + case gatewayZAICoding: + if cfg == nil { + return true + } + _, err := zai.NormalizeRegion(cfg.ZAICodingRegion) + return err != nil + default: + return false + } +} + +// GatewayRegionLabel returns the saved normalized region for gateways that require one. 
+func GatewayRegionLabel(providerID string) string { + cfg := config.LoadProviderConfig("") + if cfg == nil { + return "" + } + switch normalizeRuntimeProviderID(providerID) { + case GatewayXiaomiTokenPlan: + region, err := xiaomi.NormalizeRegion(cfg.XiaomiMimoTokenPlanRegion) + if err != nil { + return "" + } + return string(region) + case gatewayZAICoding: + region, err := zai.NormalizeRegion(cfg.ZAICodingRegion) + if err != nil { + return "" + } + return string(region) + default: + return "" + } +} + +// SetGatewayRegion persists a normalized gateway region and updates derived env/base-url state. +func SetGatewayRegion(providerID, region string) error { + cfg := config.LoadProviderConfig("") + if cfg == nil { + cfg = &config.ProviderConfig{} + } + switch normalizeRuntimeProviderID(providerID) { + case GatewayXiaomiTokenPlan: + normalized, err := xiaomi.NormalizeRegion(region) + if err != nil { + return err + } + cfg.XiaomiMimoTokenPlanRegion = string(normalized) + if base, err := config.ResolveXiaomiOpenAIBase(GatewayXiaomiTokenPlan, cfg); err == nil && base != "" { + cfg.XiaomiMimoTokenPlanBaseURL = base + } + if err := config.SaveProviderConfig(cfg, ""); err != nil { + return err + } + ApplyGatewayEnv(context.Background(), GatewayXiaomiTokenPlan) + return nil + case gatewayZAIPayg, gatewayZAICoding: + normalized, err := zai.NormalizeRegion(region) + if err != nil { + return err + } + if normalizeRuntimeProviderID(providerID) == gatewayZAICoding { + cfg.ZAICodingRegion = string(normalized) + if base, err := zai.ResolveOpenAIBase(zai.PlanCoding, normalized, cfg.ZAICodingBaseURL); err == nil && base != "" { + cfg.ZAICodingBaseURL = base + } + } else { + cfg.ZAIRegion = string(normalized) + if base, err := zai.ResolveOpenAIBase(zai.PlanGeneral, normalized, cfg.ZAIBaseURL); err == nil && base != "" { + cfg.ZAIBaseURL = base + } + } + if err := config.SaveProviderConfig(cfg, ""); err != nil { + return err + } + ApplyGatewayEnv(context.Background(), providerID) + 
return nil + default: + return fmt.Errorf("runtime: gateway %q does not use a selectable region", providerID) + } +} + +// ApplyCredentialsForProvider refreshes runtime catalog/provider state for one gateway after saving credentials. +func ApplyCredentialsForProvider(ctx context.Context, providerID string) (*setup.ApplyCredentialsResult, error) { + PrepareCredentialDiscovery(ctx) + return setup.ApplyCredentialsForProvider(ctx, normalizeRuntimeProviderID(providerID), config.DiscoveryCredentials(ctx)) +} + +// RefreshGatewayCatalog fetches live models for one gateway and updates the cached catalog. +func RefreshGatewayCatalog(ctx context.Context, providerID string) (string, error) { + PrepareCredentialDiscovery(ctx) + providerID = normalizeRuntimeProviderID(providerID) + result, err := setup.DiscoverProviderCatalog(ctx, providerID, config.DiscoveryCredentials(ctx)) + if err != nil { + return "", err + } + count := 0 + if result.Compiled != nil { + count = len(catalog.ModelEntriesForProvider(result.Compiled, providerID)) + } + return fmt.Sprintf("Refreshed %s (%d models)", providerID, count), nil +} + +// FormatApplyCredentialsSummary summarizes provider apply results for host UIs. 
+func FormatApplyCredentialsSummary(result *setup.ApplyCredentialsResult) string { + if result == nil || result.Catalog == nil || result.Catalog.Compiled == nil { + return "Eyrie credentials applied" + } + models := len(result.Catalog.Compiled.ModelsByID) + deployments := 0 + if result.ProviderConfig != nil { + deployments = len(result.ProviderConfig.Deployments) + } + return fmt.Sprintf( + "Eyrie: %d models, %d deployments configured, routing updated -> %s", + models, deployments, result.ProviderConfigPath, + ) +} diff --git a/runtime/runtime.go b/runtime/runtime.go index 90d82b4..7f3d3ff 100644 --- a/runtime/runtime.go +++ b/runtime/runtime.go @@ -26,8 +26,6 @@ import ( "context" "encoding/json" "fmt" - "os" - "path/filepath" "strings" "time" @@ -118,9 +116,22 @@ func (r *Runtime) ChatProvider(ctx context.Context) (client.Provider, error) { return p, nil } +// ChatProvider builds the configured chat provider without requiring callers to +// load runtime state first. Host applications should prefer this over reaching +// into lower-level setup/config packages. +func ChatProvider(ctx context.Context) (client.Provider, error) { + cfg := config.LoadProviderConfig("") + return setup.DeploymentProvider(ctx, cfg) +} + +// AvailableProviders lists the currently registered provider IDs. +func AvailableProviders() []string { + return client.Client(nil).GetProviders() +} + // RoutingPreviewJSON returns effective routing for a model ID. func (r *Runtime) RoutingPreviewJSON(model string) (string, error) { - return setup.RoutingPreview(ctxWithBackground(), model) + return RoutingPreview(ctxWithBackground(), model) } func ctxWithBackground() context.Context { @@ -252,7 +263,5 @@ func configuredDeploymentIDsForProvider(compiled *catalog.CompiledCatalogV1, pro // DefaultPaths reports standard eyrie paths on disk. 
func DefaultPaths() (catalogPath, providerPath string) { - home, _ := os.UserHomeDir() - return filepath.Join(home, ".eyrie", "model_catalog.json"), - config.GetProviderConfigPath() + return catalog.DefaultCachePath(), config.GetProviderConfigPath() } diff --git a/runtime/selection.go b/runtime/selection.go index 020caad..720ec6a 100644 --- a/runtime/selection.go +++ b/runtime/selection.go @@ -6,20 +6,182 @@ import ( "strings" "github.com/GrayCodeAI/eyrie/catalog" + "github.com/GrayCodeAI/eyrie/catalog/registry" + "github.com/GrayCodeAI/eyrie/client" "github.com/GrayCodeAI/eyrie/config" + "github.com/GrayCodeAI/eyrie/setup" ) // ActiveModel returns the selected model from provider.json. func ActiveModel(ctx context.Context) string { cfg := config.LoadProviderConfig("") - return config.ActiveModel(cfg) + if cfg == nil { + return "" + } + return strings.TrimSpace(cfg.ActiveModel) } // ActiveProvider returns the selected provider from provider.json. func ActiveProvider(ctx context.Context) string { _ = ctx cfg := config.LoadProviderConfig("") - return config.ActiveProvider(cfg) + if cfg == nil { + return "" + } + return catalog.CanonicalProviderID(strings.TrimSpace(cfg.ActiveProvider)) +} + +// NormalizeProviderID resolves catalog aliases and host-facing variants to the +// runtime provider identifier used by Eyrie adapters and setup gateways. +func NormalizeProviderID(provider string) string { + return normalizeRuntimeProviderID(provider) +} + +// ActiveProviderID canonicalizes a host-facing provider/gateway id through the +// runtime provider-id rules used by Eyrie. +func ActiveProviderID(provider string) string { + return NormalizeProviderID(provider) +} + +// SelectionState is the engine-resolved provider/model selection for chat. 
+type SelectionState struct { + Provider string `json:"provider"` + Model string `json:"model"` + HasConfiguredDeployment bool `json:"has_configured_deployment"` + DeploymentRouting bool `json:"deployment_routing"` +} + +// SelectionOpts supplies optional host overrides for provider/model selection. +// Empty overrides mean "use engine-persisted/default selection". +type SelectionOpts struct { + ProviderOverride string + ModelOverride string + // DeploymentRoutingOverride lets a host force the routing mode while the + // migration to pure engine-owned policy is in progress. + DeploymentRoutingOverride *bool +} + +// HasConfiguredDeployment reports whether the engine has at least one usable deployment. +func HasConfiguredDeployment(ctx context.Context) bool { + return config.HasAnyConfiguredDeployment(ctx) +} + +// ResolveCanonicalModel maps aliases/native IDs to canonical catalog model IDs. +func ResolveCanonicalModel(ctx context.Context, model string) string { + model = strings.TrimSpace(model) + if model == "" { + return "" + } + rt, err := Load(ctx) + if err == nil && rt != nil && rt.Catalog != nil { + if canonical, ok := rt.Catalog.CanonicalModelForAliasOrID(model); ok { + return canonical + } + } + if strings.Contains(model, "/") { + return model + } + return model +} + +// DefaultModelForProvider returns the preferred model for a provider using cache +// first, then live discovery when credentials are configured. 
+func DefaultModelForProvider(ctx context.Context, provider string) string { + provider = normalizeRuntimeProviderID(provider) + if provider == "" { + return "" + } + if rt, err := Load(ctx); err == nil && rt != nil && rt.Catalog != nil { + if id := catalog.FirstModelForProvider(rt.Catalog, provider); id != "" { + return id + } + } + if !providerConfigured(ctx, provider) { + return "" + } + models, err := ListModels(ctx, ListModelsOpts{ProviderID: provider, Source: ListSourceAuto}) + if err == nil && len(models) > 0 { + return strings.TrimSpace(models[0].ID) + } + return "" +} + +// SyncSelectionWithCredentials clears stale persisted selection when the selected +// gateway no longer has usable credentials. +func SyncSelectionWithCredentials(ctx context.Context) { + if ctx == nil { + ctx = context.Background() + } + if !HasConfiguredDeployment(ctx) { + if strings.TrimSpace(ActiveModel(ctx)) != "" || strings.TrimSpace(ActiveProvider(ctx)) != "" { + _ = ClearActiveSelection(ctx) + } + return + } + gateway := activeGateway(ctx) + if gateway == "" { + return + } + if !providerConfigured(ctx, gateway) { + _ = ClearActiveSelection(ctx) + } +} + +// EffectiveSelection resolves the provider/model that chat should use after +// applying persisted selection, optional host overrides, credential-aware +// fallback, and optional canonicalization. 
+func EffectiveSelection(ctx context.Context, opts SelectionOpts) SelectionState { + if ctx == nil { + ctx = context.Background() + } + SyncSelectionWithCredentials(ctx) + cfg := config.LoadProviderConfig("") + state := SelectionState{ + HasConfiguredDeployment: HasConfiguredDeployment(ctx), + DeploymentRouting: useDeploymentRouting(cfg, opts.DeploymentRoutingOverride), + } + + provider := normalizeRuntimeProviderID(ActiveProvider(ctx)) + if override := normalizeRuntimeProviderID(opts.ProviderOverride); override != "" { + provider = override + } + + model := strings.TrimSpace(ActiveModel(ctx)) + if override := strings.TrimSpace(opts.ModelOverride); override != "" { + model = override + } + + if provider == "" && model != "" { + provider = inferProviderForModel(ctx, model) + } + + if provider == "" { + provider = PreferredProvider(ctx) + } + + if !state.HasConfiguredDeployment { + state.Provider = provider + state.Model = model + return state + } + + if provider != "" && !providerConfigured(ctx, provider) { + if detected := normalizeRuntimeProviderID(client.DetectProvider()); detected != "" && providerConfigured(ctx, detected) { + provider = detected + model = "" + } + } + + if provider != "" && model == "" { + model = DefaultModelForProvider(ctx, provider) + } + if state.DeploymentRouting && model != "" { + model = ResolveCanonicalModel(ctx, model) + } + + state.Provider = provider + state.Model = model + return state } // SetActiveModel persists the user's model choice to provider.json. 
@@ -74,12 +236,105 @@ func inferProviderForModel(ctx context.Context, modelID string) string {
 	rt, err := Load(ctx)
 	if err != nil || rt == nil || rt.Catalog == nil {
 		if prefix, _, ok := strings.Cut(strings.TrimSpace(modelID), "/"); ok && catalog.IsSetupGateway(prefix) {
-			return catalog.CanonicalProviderID(prefix)
+			return normalizeRuntimeProviderID(prefix)
 		}
 		return ""
 	}
 	if gw := catalog.GatewayForModel(rt.Catalog, modelID); gw != "" {
-		return gw
+		return normalizeRuntimeProviderID(gw)
+	}
+	return ""
+}
+
+// activeGateway returns the normalized active provider when the catalog
+// recognizes it as a setup gateway; otherwise it returns the provider inferred
+// from the active model (possibly ""), or "" when no model is selected.
+func activeGateway(ctx context.Context) string {
+	if provider := normalizeRuntimeProviderID(ActiveProvider(ctx)); catalog.IsSetupGateway(provider) {
+		return provider
+	}
+	if model := strings.TrimSpace(ActiveModel(ctx)); model != "" {
+		return inferProviderForModel(ctx, model)
 	}
 	return ""
 }
+
+// providerConfigured reports whether the deployment that the registry maps to
+// provider is configured according to config.DeploymentConfigured, using the
+// discovery environment. It returns false for an empty or unregistered
+// provider, a catalog that cannot be loaded, or a deployment the catalog lacks.
+func providerConfigured(ctx context.Context, provider string) bool {
+	id := normalizeRuntimeProviderID(provider)
+	if id == "" {
+		return false
+	}
+	spec, known := registry.SpecByProviderID(id)
+	if !known {
+		return false
+	}
+	compiled, err := catalog.LoadCatalogForDiscovery(ctx)
+	if err != nil || compiled == nil || compiled.Catalog == nil {
+		return false
+	}
+	deployment, found := compiled.Catalog.Deployments[spec.DeploymentID]
+	if !found {
+		return false
+	}
+	deploymentCfg := config.DeploymentConfigFromEnv(deployment, config.DiscoveryEnvMap(ctx))
+	return config.DeploymentConfigured(spec.DeploymentID, deployment, deploymentCfg)
+}
+
+// normalizeRuntimeProviderID resolves a provider spelling to the ProviderID the
+// registry knows. The input is trimmed and lowercased, a few hyphenated aliases
+// are rewritten, and then hyphen/underscore variants plus their catalog-canonical
+// forms are tried against the registry in order. If none is registered, the
+// catalog-canonical form of the underscore variant is returned.
+func normalizeRuntimeProviderID(provider string) string {
+	base := strings.ToLower(strings.TrimSpace(provider))
+	switch base {
+	case "":
+		return ""
+	case "z-ai-payg":
+		base = "zai_payg"
+	case "z-ai-coding":
+		base = "zai_coding"
+	case "xiaomi-mimo", "xiaomi-mimo-payg":
+		base = "xiaomi_mimo_payg"
+	case "xiaomi-mimo-token-plan":
+		base = "xiaomi_mimo_token_plan"
+	}
+	// Precompute both separator variants; they are reused below.
+	underscored := strings.ReplaceAll(base, "-", "_")
+	hyphenated := strings.ReplaceAll(base, "_", "-")
+	tried := make(map[string]struct{}, 6)
+	for _, raw := range []string{
+		base,
+		underscored,
+		hyphenated,
+		catalog.CanonicalProviderID(base),
+		catalog.CanonicalProviderID(underscored),
+		catalog.CanonicalProviderID(hyphenated),
+	} {
+		candidate := strings.TrimSpace(raw)
+		if candidate == "" {
+			continue
+		}
+		if _, dup := tried[candidate]; dup {
+			continue
+		}
+		tried[candidate] = struct{}{}
+		if spec, ok := registry.SpecByProviderID(candidate); ok {
+			return spec.ProviderID
+		}
+	}
+	return catalog.CanonicalProviderID(underscored)
+}
+
+// useDeploymentRouting returns *override when the host supplied one; otherwise
+// it defers to setup.UseDeploymentRouting for cfg.
+func useDeploymentRouting(cfg *config.ProviderConfig, override *bool) bool {
+	if override == nil {
+		return setup.UseDeploymentRouting(cfg)
+	}
+	return *override
+}
diff --git a/runtime/status.go b/runtime/status.go
new file mode 100644
index 0000000..5ef725c
--- /dev/null
+++ b/runtime/status.go
@@ -0,0 +1,22 @@
+package runtime
+
+import (
+	"context"
+
+	"github.com/GrayCodeAI/eyrie/setup"
+)
+
+// DeploymentStatus returns deployment-routing diagnostics for host UIs.
+func DeploymentStatus(ctx context.Context, activeModel string) (setup.StatusReport, error) {
+	return setup.DeploymentStatus(ctx, activeModel)
+}
+
+// FormatDeploymentStatus renders a deployment-routing diagnostics report for host UIs.
+func FormatDeploymentStatus(report setup.StatusReport) string {
+	return setup.FormatStatus(report)
+}
+
+// RoutingPreview returns the effective routing preview JSON for a model ID.
+func RoutingPreview(ctx context.Context, model string) (string, error) {
+	return setup.RoutingPreview(ctx, model)
+}
diff --git a/runtime/transport.go b/runtime/transport.go
new file mode 100644
index 0000000..bb34c98
--- /dev/null
+++ b/runtime/transport.go
@@ -0,0 +1,78 @@
+package runtime
+
+import (
+	"context"
+
+	"github.com/GrayCodeAI/eyrie/client"
+)
+
+// ChatTransportOpts supplies host-side overrides while transport ownership
+// moves into the runtime package.
+type ChatTransportOpts struct {
+	// Selection carries the host overrides forwarded to EffectiveSelection.
+	Selection SelectionOpts
+}
+
+// ChatTransport pairs the resolved selection with the provider that should
+// execute chat requests for it; host apps adapt it into their own
+// session/client abstractions.
+type ChatTransport struct {
+	// Selection is the effective selection; ResolveChatTransport clears its
+	// DeploymentRouting flag when it falls back to a direct provider.
+	Selection SelectionState
+	// Provider may be nil when the selection names no provider.
+	Provider client.Provider
+}
+
+// ResolveChatTransport computes the effective selection and picks the provider
+// that will serve it. Without deployment routing it returns the direct provider
+// for the selected provider ID (nil, with a nil error, when that ID normalizes
+// to empty). With deployment routing it returns ChatProvider's result; if
+// ChatProvider fails it falls back to the direct provider and reports
+// DeploymentRouting as false, returning ChatProvider's error only when no
+// direct provider exists.
+func ResolveChatTransport(ctx context.Context, opts ChatTransportOpts) (ChatTransport, error) {
+	if ctx == nil {
+		ctx = context.Background()
+	}
+	transport := ChatTransport{Selection: EffectiveSelection(ctx, opts.Selection)}
+	if transport.Selection.DeploymentRouting {
+		routed, err := ChatProvider(ctx)
+		if err == nil {
+			transport.Provider = routed
+			return transport, nil
+		}
+		direct := directChatProvider(ctx, transport.Selection.Provider)
+		transport.Provider = direct
+		if direct == nil {
+			return transport, err
+		}
+		transport.Selection.DeploymentRouting = false
+		return transport, nil
+	}
+	transport.Provider = directChatProvider(ctx, transport.Selection.Provider)
+	return transport, nil
+}
+
+// directChatProvider builds a lazily-resolved provider for primary. When
+// directFallbackProviderIDs yields alternates, the result is a fallback chain
+// with primary first. It returns nil if primary normalizes to "".
+func directChatProvider(ctx context.Context, primary string) client.Provider {
+	primary = NormalizeProviderID(primary)
+	if primary == "" {
+		return nil
+	}
+	lazy := func(id string) client.Provider {
+		return client.NewLazyProvider(&client.EyrieConfig{Provider: id})
+	}
+	chain := []client.Provider{lazy(primary)}
+	for _, id := range directFallbackProviderIDs(ctx, primary) {
+		chain = append(chain, lazy(id))
+	}
+	if len(chain) > 1 {
+		return client.NewFallbackProvider(chain...)
+	}
+	return chain[0]
+}
+
+// directFallbackProviderIDs returns the alternate provider to try after
+// primary: anthropic for openai and openai for anthropic, and only when that
+// alternate is configured. Every other primary gets no fallback.
+func directFallbackProviderIDs(ctx context.Context, primary string) []string {
+	var alternate string
+	switch NormalizeProviderID(primary) {
+	case "openai":
+		alternate = "anthropic"
+	case "anthropic":
+		alternate = "openai"
+	default:
+		return nil
+	}
+	if !providerConfigured(ctx, alternate) {
+		return nil
+	}
+	return []string{alternate}
+}
diff --git a/runtime/transport_policy_test.go b/runtime/transport_policy_test.go
new file mode 100644
index 0000000..287f755
--- /dev/null
+++ b/runtime/transport_policy_test.go
@@ -0,0 +1,99 @@
+package runtime
+
+import (
+	"context"
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/GrayCodeAI/eyrie/credentials"
+)
+
+// boolPtr returns a pointer to a copy of v.
+func boolPtr(v bool) *bool { return &v }
+
+// setupTransportPolicyTest points HAWK_CONFIG_DIR at a temp dir containing an
+// empty provider.json, installs an in-memory credential store as the default
+// (reset to nil on cleanup), and stores each {env var, secret} pair in order.
+// It returns the background context used for the store writes.
+func setupTransportPolicyTest(t *testing.T, envKeys ...[2]string) context.Context {
+	t.Helper()
+	configDir := t.TempDir()
+	t.Setenv("HAWK_CONFIG_DIR", configDir)
+	if err := os.WriteFile(filepath.Join(configDir, "provider.json"), []byte("{}\n"), 0o600); err != nil {
+		t.Fatal(err)
+	}
+
+	store := &credentials.MapStore{}
+	credentials.SetDefaultStore(store)
+	t.Cleanup(func() { credentials.SetDefaultStore(nil) })
+
+	ctx := context.Background()
+	for _, kv := range envKeys {
+		if err := store.Set(ctx, credentials.AccountForEnv(kv[0]), kv[1]); err != nil {
+			t.Fatal(err)
+		}
+	}
+	return ctx
+}
+
+// With both OpenAI and Anthropic keys stored and nothing pinned,
+// PreferredProvider is expected to pick openai.
+func TestPreferredProvider_PrefersOpenAIOverAnthropicWhenBothConfigured(t *testing.T) {
+	ctx := setupTransportPolicyTest(t,
+		[2]string{"OPENAI_API_KEY", "sk-openai-test"},
+		[2]string{"ANTHROPIC_API_KEY", "sk-ant-test"},
+	)
+
+	if got := PreferredProvider(ctx); got != "openai" {
+		t.Fatalf("PreferredProvider() = %q, want openai", got)
+	}
+}
+
+// A model override with a gateway prefix and no pinned provider should yield
+// that gateway as the provider and leave the model ID unchanged.
+func TestEffectiveSelection_InfersProviderFromModelOverride(t *testing.T) {
+	ctx := setupTransportPolicyTest(t, [2]string{"OPENAI_API_KEY", "sk-openai-test"})
+
+	state := EffectiveSelection(ctx, SelectionOpts{
+		ModelOverride: "openai/gpt-4o",
+	})
+	if state.Provider != "openai" {
+		t.Fatalf("provider = %q, want openai", state.Provider)
+	}
+	if state.Model != "openai/gpt-4o" {
+		t.Fatalf("model = %q, want openai/gpt-4o", state.Model)
+	}
+}
+
+// With deployment routing forced off and both keys stored, an openai override
+// should produce a direct fallback chain openai->anthropic while the selection
+// itself stays on openai.
+func TestResolveChatTransport_DirectOpenAIFallsBackToAnthropic(t *testing.T) {
+	ctx := setupTransportPolicyTest(t,
+		[2]string{"OPENAI_API_KEY", "sk-openai-test"},
+		[2]string{"ANTHROPIC_API_KEY", "sk-ant-test"},
+	)
+
+	transport, err := ResolveChatTransport(ctx, ChatTransportOpts{
+		Selection: SelectionOpts{
+			ProviderOverride:          "openai",
+			DeploymentRoutingOverride: boolPtr(false),
+		},
+	})
+	if err != nil {
+		t.Fatalf("ResolveChatTransport() error = %v", err)
+	}
+	if transport.Provider == nil {
+		t.Fatal("expected transport provider")
+	}
+	if got := transport.Provider.Name(); got != "fallback(openai->anthropic)" {
+		t.Fatalf("provider name = %q, want fallback(openai->anthropic)", got)
+	}
+	if transport.Selection.Provider != "openai" {
+		t.Fatalf("selection provider = %q, want openai", transport.Selection.Provider)
+	}
+	if transport.Selection.DeploymentRouting {
+		t.Fatal("expected deployment routing disabled")
+	}
+}