diff --git a/internal/adapters/dashboard/account_client.go b/internal/adapters/dashboard/account_client.go index db0f02a..6888f9b 100644 --- a/internal/adapters/dashboard/account_client.go +++ b/internal/adapters/dashboard/account_client.go @@ -77,7 +77,7 @@ func (c *AccountClient) Login(ctx context.Context, email, password, orgPublicID raw, err := c.doPostRaw(ctx, "/auth/cli/login", body, nil, "") if err != nil { - return nil, nil, fmt.Errorf("%w", domain.ErrDashboardLoginFailed) + return nil, nil, fmt.Errorf("%w: %w", domain.ErrDashboardLoginFailed, err) } // Check if response contains userToken (success) or totpFactor (MFA required) @@ -120,7 +120,7 @@ func (c *AccountClient) LoginMFA(ctx context.Context, userPublicID, code, orgPub var result domain.DashboardAuthResponse if err := c.doPost(ctx, "/auth/cli/login/mfa", body, nil, "", &result); err != nil { - return nil, fmt.Errorf("%w", domain.ErrDashboardLoginFailed) + return nil, fmt.Errorf("%w: %w", domain.ErrDashboardLoginFailed, err) } return &result, nil } @@ -130,7 +130,7 @@ func (c *AccountClient) Refresh(ctx context.Context, userToken, orgToken string) headers := bearerHeaders(userToken, orgToken) var result domain.DashboardRefreshResponse if err := c.doPost(ctx, "/auth/cli/refresh", nil, headers, userToken, &result); err != nil { - return nil, fmt.Errorf("%w", domain.ErrDashboardSessionExpired) + return nil, fmt.Errorf("failed to refresh session: %w", err) } return &result, nil } diff --git a/internal/adapters/dashboard/account_client_test.go b/internal/adapters/dashboard/account_client_test.go new file mode 100644 index 0000000..0704aa5 --- /dev/null +++ b/internal/adapters/dashboard/account_client_test.go @@ -0,0 +1,503 @@ +package dashboard + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/nylas/cli/internal/domain" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newAccountClientTestServer(t *testing.T, handler func(t *testing.T, w http.ResponseWriter, r *http.Request, rawBody []byte, body map[string]any)) *httptest.Server { + t.Helper() + + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Helper() + + rawBody, err := io.ReadAll(r.Body) + require.NoError(t, err) + + var body map[string]any + if len(rawBody) > 0 { + require.NoError(t, json.Unmarshal(rawBody, &body)) + } + + handler(t, w, r, rawBody, body) + })) +} + +func writeDashboardEnvelope(t *testing.T, w http.ResponseWriter, data any) { + t.Helper() + + w.Header().Set("Content-Type", "application/json") + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + "request_id": "req-123", + "success": true, + "data": data, + })) +} + +func TestAccountClientPublicEndpoints(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + handler func(t *testing.T, w http.ResponseWriter, r *http.Request, rawBody []byte, body map[string]any) + run func(t *testing.T, client *AccountClient) + }{ + { + name: "register", + handler: func(t *testing.T, w http.ResponseWriter, r *http.Request, _ []byte, body map[string]any) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "/auth/cli/register", r.URL.Path) + assert.Equal(t, "user@example.com", body["email"]) + assert.Equal(t, "secret", body["password"]) + assert.Equal(t, true, body["privacyPolicyAccepted"]) + assert.NotEmpty(t, r.Header.Get("DPoP")) + + writeDashboardEnvelope(t, w, map[string]any{ + "verificationChannel": "email", + "expiresAt": "2026-04-20T12:00:00Z", + }) + }, + run: func(t *testing.T, client *AccountClient) { + resp, err := client.Register(context.Background(), "user@example.com", "secret", true) + require.NoError(t, err) + assert.Equal(t, "email", resp.VerificationChannel) + }, + }, + { + name: "verify email code", + handler: func(t *testing.T, w http.ResponseWriter, r *http.Request, _ []byte, body map[string]any) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "/auth/cli/verify-email-code", r.URL.Path) + assert.Equal(t, "user@example.com", body["email"]) + assert.Equal(t, "123456", body["code"]) + assert.Equal(t, "us", body["region"]) + + writeDashboardEnvelope(t, w, map[string]any{ + "userToken": "user-token", + "orgToken": "org-token", + "user": map[string]any{ + "publicId": "user-1", + }, + }) + }, + run: func(t *testing.T, client *AccountClient) { + resp, err := client.VerifyEmailCode(context.Background(), "user@example.com", "123456", "us") + require.NoError(t, err) + assert.Equal(t, "user-token", resp.UserToken) + }, + }, + { + name: "resend verification code", + handler: func(t *testing.T, w http.ResponseWriter, r *http.Request, _ []byte, body map[string]any) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "/auth/cli/resend-verification-code", r.URL.Path) + assert.Equal(t, "user@example.com", body["email"]) + + writeDashboardEnvelope(t, w, map[string]any{}) + }, + run: func(t *testing.T, client *AccountClient) { + require.NoError(t, client.ResendVerificationCode(context.Background(), "user@example.com")) + }, + }, + { + name: "login MFA completion", + handler: func(t *testing.T, w http.ResponseWriter, r *http.Request, _ []byte, body map[string]any) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "/auth/cli/login/mfa", r.URL.Path) + assert.Equal(t, "user-1", body["userPublicId"]) + assert.Equal(t, "654321", body["code"]) + assert.Equal(t, "org-1", body["orgPublicId"]) + + writeDashboardEnvelope(t, w, map[string]any{ + "userToken": "user-token", + "orgToken": "org-token", + "user": map[string]any{ + "publicId": "user-1", + }, + }) + }, + run: func(t *testing.T, client *AccountClient) { + resp, err := client.LoginMFA(context.Background(), "user-1", "654321", "org-1") + require.NoError(t, err) + assert.Equal(t, "org-token", resp.OrgToken) + }, + }, + { + name: "refresh", + handler: func(t *testing.T, w http.ResponseWriter, r *http.Request, rawBody []byte, _ map[string]any) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "/auth/cli/refresh", r.URL.Path) + assert.Empty(t, rawBody) + assert.Equal(t, "Bearer user-token", r.Header.Get("Authorization")) + assert.Equal(t, "org-token", r.Header.Get("X-Nylas-Org")) + + writeDashboardEnvelope(t, w, map[string]any{ + "userToken": "user-token-new", + "orgToken": "org-token-new", + }) + }, + run: func(t *testing.T, client *AccountClient) { + resp, err := client.Refresh(context.Background(), "user-token", "org-token") + require.NoError(t, err) + assert.Equal(t, "user-token-new", resp.UserToken) + assert.Equal(t, "org-token-new", resp.OrgToken) + }, + }, + { + name: "logout", + handler: func(t *testing.T, w http.ResponseWriter, r *http.Request, rawBody []byte, _ map[string]any) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "/auth/cli/logout", r.URL.Path) + assert.Empty(t, rawBody) + assert.Equal(t, "Bearer user-token", r.Header.Get("Authorization")) + assert.Equal(t, "org-token", r.Header.Get("X-Nylas-Org")) + + writeDashboardEnvelope(t, w, map[string]any{}) + }, + run: func(t *testing.T, client *AccountClient) { + require.NoError(t, client.Logout(context.Background(), "user-token", "org-token")) + }, + }, + { + name: "sso start register", + handler: func(t *testing.T, w http.ResponseWriter, r *http.Request, _ []byte, body map[string]any) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "/auth/cli/sso/start", r.URL.Path) + assert.Equal(t, "google_SSO", body["loginType"]) + assert.Equal(t, "register", body["mode"]) + assert.Equal(t, true, body["privacyPolicyAccepted"]) + + writeDashboardEnvelope(t, w, map[string]any{ + "flowId": "flow-1", + "verificationUri": "https://example.com/device", + "verificationUriComplete": "https://example.com/device?code=abc", + "userCode": "ABCDEF", + "expiresIn": 300, + "interval": 5, + }) + }, + run: func(t *testing.T, client *AccountClient) { + resp, err := client.SSOStart(context.Background(), domain.SSOLoginTypeGoogle, "register", true) + require.NoError(t, err) + assert.Equal(t, "flow-1", resp.FlowID) + }, + }, + { + name: "get current session", + handler: func(t *testing.T, w http.ResponseWriter, r *http.Request, rawBody []byte, _ map[string]any) { + assert.Equal(t, http.MethodGet, r.Method) + assert.Equal(t, "/sessions/current", r.URL.Path) + assert.Empty(t, rawBody) + assert.Equal(t, "Bearer user-token", r.Header.Get("Authorization")) + assert.Equal(t, "org-token", r.Header.Get("X-Nylas-Org")) + + writeDashboardEnvelope(t, w, map[string]any{ + "user": map[string]any{ + "publicId": "user-1", + }, + "currentOrg": "org-1", + "relations": []map[string]any{ + {"orgPublicId": "org-1", "orgName": "Acme"}, + }, + }) + }, + run: func(t *testing.T, client *AccountClient) { + resp, err := client.GetCurrentSession(context.Background(), "user-token", "org-token") + require.NoError(t, err) + assert.Equal(t, "org-1", resp.CurrentOrg) + require.Len(t, resp.Relations, 1) + }, + }, + { + name: "switch org", + handler: func(t *testing.T, w http.ResponseWriter, r *http.Request, _ []byte, body map[string]any) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "/sessions/switch-org", r.URL.Path) + assert.Equal(t, "Bearer user-token", r.Header.Get("Authorization")) + assert.Equal(t, "org-token", r.Header.Get("X-Nylas-Org")) + assert.Equal(t, "org-2", body["orgPublicId"]) + + writeDashboardEnvelope(t, w, map[string]any{ + "orgToken": "org-token-new", + "orgSessionId": "session-2", + "org": map[string]any{ + "publicId": "org-2", + "name": "Beta", + }, + }) + }, + run: func(t *testing.T, client *AccountClient) { + resp, err := client.SwitchOrg(context.Background(), "org-2", "user-token", "org-token") + require.NoError(t, err) + assert.Equal(t, "org-token-new", resp.OrgToken) + assert.Equal(t, "Beta", resp.Org.Name) + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + server := newAccountClientTestServer(t, tt.handler) + defer server.Close() + + client := &AccountClient{ + baseURL: server.URL, + httpClient: server.Client(), + dpop: &mockDPoP{proof: "test-proof"}, + } + + tt.run(t, client) + }) + } +} + +func TestAccountClientLoginVariants(t *testing.T) { + t.Parallel() + + t.Run("success response", func(t *testing.T) { + t.Parallel() + + server := newAccountClientTestServer(t, func(t *testing.T, w http.ResponseWriter, r *http.Request, _ []byte, body map[string]any) { + assert.Equal(t, "/auth/cli/login", r.URL.Path) + assert.Equal(t, "user@example.com", body["email"]) + assert.Equal(t, "secret", body["password"]) + assert.Equal(t, "org-1", body["orgPublicId"]) + + writeDashboardEnvelope(t, w, map[string]any{ + "userToken": "user-token", + "orgToken": "org-token", + "user": map[string]any{ + "publicId": "user-1", + }, + "organizations": []map[string]any{ + {"publicId": "org-1"}, + }, + }) + }) + defer server.Close() + + client := &AccountClient{ + baseURL: server.URL, + httpClient: server.Client(), + dpop: &mockDPoP{proof: "test-proof"}, + } + + auth, mfa, err := client.Login(context.Background(), "user@example.com", "secret", "org-1") + require.NoError(t, err) + assert.NotNil(t, auth) + assert.Nil(t, mfa) + assert.Equal(t, "user-token", auth.UserToken) + }) + + t.Run("mfa required response", func(t *testing.T) { + t.Parallel() + + server := newAccountClientTestServer(t, func(t *testing.T, w http.ResponseWriter, _ *http.Request, _ []byte, _ map[string]any) { + writeDashboardEnvelope(t, w, map[string]any{ + "user": map[string]any{ + "publicId": "user-1", + }, + "organizations": []map[string]any{ + {"publicId": "org-1"}, + }, + "totpFactor": map[string]any{ + "factorSid": "factor-1", + }, + }) + }) + defer server.Close() + + client := &AccountClient{ + baseURL: server.URL, + httpClient: server.Client(), + dpop: &mockDPoP{proof: "test-proof"}, + } + + auth, mfa, err := client.Login(context.Background(), "user@example.com", "secret", "") + require.NoError(t, err) + assert.Nil(t, auth) + require.NotNil(t, mfa) + assert.Equal(t, "factor-1", mfa.TOTPFactor.FactorSID) + }) + + t.Run("unexpected payload returns login failed", func(t *testing.T) { + t.Parallel() + + server := newAccountClientTestServer(t, func(t *testing.T, w http.ResponseWriter, _ *http.Request, _ []byte, _ map[string]any) { + writeDashboardEnvelope(t, w, map[string]any{"status": "unknown"}) + }) + defer server.Close() + + client := &AccountClient{ + baseURL: server.URL, + httpClient: server.Client(), + dpop: &mockDPoP{proof: "test-proof"}, + } + + auth, mfa, err := client.Login(context.Background(), "user@example.com", "secret", "") + require.Error(t, err) + assert.Nil(t, auth) + assert.Nil(t, mfa) + assert.ErrorIs(t, err, domain.ErrDashboardLoginFailed) + }) + + t.Run("transport and API errors are wrapped", func(t *testing.T) { + t.Parallel() + + server := newAccountClientTestServer(t, func(t *testing.T, w http.ResponseWriter, _ *http.Request, _ []byte, _ map[string]any) { + w.WriteHeader(http.StatusUnauthorized) + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{ + "code": "INVALID_CREDENTIALS", + "message": "Invalid email or password", + }, + })) + }) + defer server.Close() + + client := &AccountClient{ + baseURL: server.URL, + httpClient: server.Client(), + dpop: &mockDPoP{proof: "test-proof"}, + } + + auth, mfa, err := client.Login(context.Background(), "user@example.com", "secret", "") + require.Error(t, err) + assert.Nil(t, auth) + assert.Nil(t, mfa) + assert.ErrorIs(t, err, domain.ErrDashboardLoginFailed) + assert.Contains(t, err.Error(), "INVALID_CREDENTIALS") + }) +} + +func TestAccountClientLoginMFAWrapsUnderlyingError(t *testing.T) { + t.Parallel() + + server := newAccountClientTestServer(t, func(t *testing.T, w http.ResponseWriter, _ *http.Request, _ []byte, _ map[string]any) { + w.WriteHeader(http.StatusUnauthorized) + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{ + "code": "INVALID_TOTP", + "message": "Invalid MFA code", + }, + })) + }) + defer server.Close() + + client := &AccountClient{ + baseURL: server.URL, + httpClient: server.Client(), + dpop: &mockDPoP{proof: "test-proof"}, + } + + resp, err := client.LoginMFA(context.Background(), "user-1", "654321", "org-1") + require.Error(t, err) + assert.Nil(t, resp) + assert.ErrorIs(t, err, domain.ErrDashboardLoginFailed) + assert.Contains(t, err.Error(), "INVALID_TOTP") +} + +func TestAccountClientSSOPollVariants(t *testing.T) { + t.Parallel() + + t.Run("complete response populates auth", func(t *testing.T) { + t.Parallel() + + server := newAccountClientTestServer(t, func(t *testing.T, w http.ResponseWriter, r *http.Request, _ []byte, body map[string]any) { + assert.Equal(t, "/auth/cli/sso/poll", r.URL.Path) + assert.Equal(t, "flow-1", body["flowId"]) + assert.Equal(t, "org-1", body["orgPublicId"]) + + writeDashboardEnvelope(t, w, map[string]any{ + "status": "complete", + "userToken": "user-token", + "orgToken": "org-token", + "user": map[string]any{ + "publicId": "user-1", + }, + }) + }) + defer server.Close() + + client := &AccountClient{ + baseURL: server.URL, + httpClient: server.Client(), + dpop: &mockDPoP{proof: "test-proof"}, + } + + resp, err := client.SSOPoll(context.Background(), "flow-1", "org-1") + require.NoError(t, err) + require.NotNil(t, resp.Auth) + assert.Equal(t, "user-token", resp.Auth.UserToken) + }) + + t.Run("mfa response populates MFA payload", func(t *testing.T) { + t.Parallel() + + server := newAccountClientTestServer(t, func(t *testing.T, w http.ResponseWriter, _ *http.Request, _ []byte, _ map[string]any) { + writeDashboardEnvelope(t, w, map[string]any{ + "status": "mfa_required", + "user": map[string]any{ + "publicId": "user-1", + }, + "organizations": []map[string]any{ + {"publicId": "org-1"}, + }, + "totpFactor": map[string]any{ + "factorSid": "factor-1", + }, + }) + }) + defer server.Close() + + client := &AccountClient{ + baseURL: server.URL, + httpClient: server.Client(), + dpop: &mockDPoP{proof: "test-proof"}, + } + + resp, err := client.SSOPoll(context.Background(), "flow-1", "") + require.NoError(t, err) + require.NotNil(t, resp.MFA) + assert.Equal(t, "factor-1", resp.MFA.TOTPFactor.FactorSID) + }) +} + +func TestAccountClientRefreshPropagatesUnderlyingError(t *testing.T) { + t.Parallel() + + server := newAccountClientTestServer(t, func(t *testing.T, w http.ResponseWriter, _ *http.Request, _ []byte, _ map[string]any) { + w.WriteHeader(http.StatusUnauthorized) + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{ + "code": "INVALID_SESSION", + "message": "Invalid or expired session", + }, + })) + }) + defer server.Close() + + client := &AccountClient{ + baseURL: server.URL, + httpClient: server.Client(), + dpop: &mockDPoP{proof: "test-proof"}, + } + + resp, err := client.Refresh(context.Background(), "user-token", "org-token") + require.Error(t, err) + assert.Nil(t, resp) + assert.True(t, errors.Is(err, domain.ErrDashboardSessionExpired)) + assert.Contains(t, err.Error(), "failed to refresh session") +} diff --git a/internal/adapters/dashboard/gateway_client.go b/internal/adapters/dashboard/gateway_client.go index b4331d3..55384ac 100644 --- a/internal/adapters/dashboard/gateway_client.go +++ b/internal/adapters/dashboard/gateway_client.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "os" + "strings" "github.com/nylas/cli/internal/domain" "github.com/nylas/cli/internal/ports" @@ -68,7 +69,7 @@ func (c *GatewayClient) ListApplications(ctx context.Context, orgPublicID, regio } if len(resp.Errors) > 0 { - return nil, fmt.Errorf("GraphQL error: %s", formatGraphQLError(resp.Errors[0])) + return nil, fmt.Errorf("failed to list applications: %w", graphQLErrorAsError(resp.Errors[0])) } return resp.Data.Applications.Applications, nil @@ -115,7 +116,7 @@ func (c *GatewayClient) CreateApplication(ctx context.Context, orgPublicID, regi } if len(resp.Errors) > 0 { - return nil, fmt.Errorf("GraphQL error: %s", formatGraphQLError(resp.Errors[0])) + return nil, fmt.Errorf("failed to create application: %w", graphQLErrorAsError(resp.Errors[0])) } return &resp.Data.CreateApplication, nil @@ -156,7 +157,7 @@ func (c *GatewayClient) ListAPIKeys(ctx context.Context, appID, region, userToke } if len(resp.Errors) > 0 { - return nil, fmt.Errorf("GraphQL error: %s", formatGraphQLError(resp.Errors[0])) + return nil, fmt.Errorf("failed to list API keys: %w", graphQLErrorAsError(resp.Errors[0])) } return resp.Data.APIKeys, nil @@ -206,7 +207,7 @@ func (c *GatewayClient) CreateAPIKey(ctx context.Context, appID, region, name st } if len(resp.Errors) > 0 { - return nil, fmt.Errorf("GraphQL error: %s", formatGraphQLError(resp.Errors[0])) + return nil, fmt.Errorf("failed to create API key: %w", graphQLErrorAsError(resp.Errors[0])) } return &resp.Data.CreateAPIKey, nil @@ -259,6 +260,9 @@ func (c *GatewayClient) doGraphQL(ctx context.Context, url, query string, variab } if resp.StatusCode < 200 || resp.StatusCode >= 300 { + if gqlErr := parseGraphQLErrorResponse(resp.StatusCode, respBody); gqlErr != nil { + return nil, gqlErr + } return nil, parseErrorResponse(resp.StatusCode, respBody) } @@ -309,3 +313,58 @@ func formatGraphQLError(e graphQLError) string { } return e.Message } + +func graphQLErrorAsError(e graphQLError) error { + return graphQLErrorAsErrorWithStatus(http.StatusOK, e) +} + +func parseGraphQLErrorResponse(statusCode int, body []byte) error { + var resp struct { + Errors []graphQLError `json:"errors"` + } + if err := json.Unmarshal(body, &resp); err != nil || len(resp.Errors) == 0 { + return nil + } + return graphQLErrorAsErrorWithStatus(statusCode, resp.Errors[0]) +} + +func graphQLErrorAsErrorWithStatus(statusCode int, e graphQLError) error { + message := formatGraphQLError(e) + if isGraphQLInvalidSession(statusCode, e) { + return domain.NewDashboardAPIError(http.StatusUnauthorized, "INVALID_SESSION", invalidSessionMessage(e)) + } + + if e.Extensions == nil || e.Extensions.Code == "" { + return fmt.Errorf("GraphQL error: %s", message) + } + + return domain.NewDashboardAPIError(statusCode, e.Extensions.Code, message) +} + +func isGraphQLInvalidSession(statusCode int, e graphQLError) bool { + if e.Extensions == nil { + return false + } + if e.Extensions.Code == "INVALID_SESSION" { + return true + } + if statusCode != http.StatusUnauthorized || e.Extensions.Code != "UNAUTHENTICATED" { + return false + } + + topLevel := strings.TrimSpace(e.Message) + extensionMsg := strings.TrimSpace(e.Extensions.Message) + return strings.EqualFold(topLevel, "INVALID_SESSION") || strings.EqualFold(extensionMsg, "INVALID_SESSION") +} + +func invalidSessionMessage(e graphQLError) string { + if e.Extensions != nil { + if msg := strings.TrimSpace(e.Extensions.Message); msg != "" && !strings.EqualFold(msg, "INVALID_SESSION") { + return msg + } + } + if msg := strings.TrimSpace(e.Message); msg != "" && !strings.EqualFold(msg, "INVALID_SESSION") { + return msg + } + return "Invalid or expired session" +} diff --git a/internal/adapters/dashboard/gateway_client_test.go b/internal/adapters/dashboard/gateway_client_test.go new file mode 100644 index 0000000..baec5c9 --- /dev/null +++ b/internal/adapters/dashboard/gateway_client_test.go @@ -0,0 +1,345 @@ +package dashboard + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/nylas/cli/internal/domain" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGatewayClientOperations(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T, client *GatewayClient) + handler func(t *testing.T, w http.ResponseWriter, r *http.Request, body map[string]any) + }{ + { + name: "list applications", + handler: func(t *testing.T, w http.ResponseWriter, r *http.Request, body map[string]any) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "Bearer user-token", r.Header.Get("Authorization")) + assert.Equal(t, "org-token", r.Header.Get("X-Nylas-Org")) + assert.NotEmpty(t, r.Header.Get("DPoP")) + + assert.Contains(t, body["query"].(string), "applications(filter: $filter)") + variables := body["variables"].(map[string]any) + filter := variables["filter"].(map[string]any) + assert.Equal(t, "org-1", filter["orgPublicId"]) + + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + "data": map[string]any{ + "applications": map[string]any{ + "applications": []map[string]any{ + { + "applicationId": "app-1", + "organizationId": "org-1", + "region": "us", + "environment": "sandbox", + "branding": map[string]any{ + "name": "App One", + "description": "Primary", + }, + }, + }, + }, + }, + })) + }, + run: func(t *testing.T, client *GatewayClient) { + apps, err := client.ListApplications(context.Background(), "org-1", "us", "user-token", "org-token") + require.NoError(t, err) + require.Len(t, apps, 1) + assert.Equal(t, "app-1", apps[0].ApplicationID) + require.NotNil(t, apps[0].Branding) + assert.Equal(t, "App One", apps[0].Branding.Name) + }, + }, + { + name: "create application", + handler: func(t *testing.T, w http.ResponseWriter, _ *http.Request, body map[string]any) { + assert.Contains(t, body["query"].(string), "createApplication(orgPublicId: $orgPublicId, options: $options)") + variables := body["variables"].(map[string]any) + assert.Equal(t, "org-1", variables["orgPublicId"]) + options := variables["options"].(map[string]any) + assert.Equal(t, "eu", options["region"]) + branding := options["branding"].(map[string]any) + assert.Equal(t, "Created App", branding["name"]) + + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + "data": map[string]any{ + "createApplication": map[string]any{ + "applicationId": "app-new", + "clientSecret": "secret", + "organizationId": "org-1", + "region": "eu", + "environment": "production", + "branding": map[string]any{ + "name": "Created App", + }, + }, + }, + })) + }, + run: func(t *testing.T, client *GatewayClient) { + app, err := client.CreateApplication(context.Background(), "org-1", "eu", "Created App", "user-token", "org-token") + require.NoError(t, err) + assert.Equal(t, "app-new", app.ApplicationID) + assert.Equal(t, "secret", app.ClientSecret) + }, + }, + { + name: "list api keys", + handler: func(t *testing.T, w http.ResponseWriter, _ *http.Request, body map[string]any) { + assert.Contains(t, body["query"].(string), "apiKeys(appId: $appId)") + variables := body["variables"].(map[string]any) + assert.Equal(t, "app-1", variables["appId"]) + + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + "data": map[string]any{ + "apiKeys": []map[string]any{ + { + "id": "key-1", + "name": "CI", + "status": "active", + "permissions": []string{"send"}, + "expiresAt": 10.0, + "createdAt": 5.0, + }, + }, + }, + })) + }, + run: func(t *testing.T, client *GatewayClient) { + keys, err := client.ListAPIKeys(context.Background(), "app-1", "us", "user-token", "org-token") + require.NoError(t, err) + require.Len(t, keys, 1) + assert.Equal(t, "key-1", keys[0].ID) + }, + }, + { + name: "create api key includes expiresIn when set", + handler: func(t *testing.T, w http.ResponseWriter, _ *http.Request, body map[string]any) { + assert.Contains(t, body["query"].(string), "createApiKey(appId: $appId, options: $options)") + variables := body["variables"].(map[string]any) + assert.Equal(t, "app-1", variables["appId"]) + options := variables["options"].(map[string]any) + assert.Equal(t, "Nightly", options["name"]) + assert.Equal(t, float64(30), options["expiresIn"]) + + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + "data": map[string]any{ + "createApiKey": map[string]any{ + "id": "key-2", + "name": "Nightly", + "apiKey": "secret-key", + "status": "active", + "permissions": []string{"send"}, + "expiresAt": 123.0, + "createdAt": 100.0, + }, + }, + })) + }, + run: func(t *testing.T, client *GatewayClient) { + key, err := client.CreateAPIKey(context.Background(), "app-1", "us", "Nightly", 30, "user-token", "org-token") + require.NoError(t, err) + assert.Equal(t, "secret-key", key.APIKey) + }, + }, + { + name: "create api key omits expiresIn when zero", + handler: func(t *testing.T, w http.ResponseWriter, _ *http.Request, body map[string]any) { + variables := body["variables"].(map[string]any) + options := variables["options"].(map[string]any) + assert.Equal(t, "Default", options["name"]) + _, ok := options["expiresIn"] + assert.False(t, ok) + + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + "data": map[string]any{ + "createApiKey": map[string]any{ + "id": "key-3", + "name": "Default", + "apiKey": "key-default", + "status": "active", + "permissions": []string{"send"}, + "expiresAt": 0.0, + "createdAt": 100.0, + }, + }, + })) + }, + run: func(t *testing.T, client *GatewayClient) { + key, err := client.CreateAPIKey(context.Background(), "app-1", "us", "Default", 0, "user-token", "org-token") + require.NoError(t, err) + assert.Equal(t, "key-default", key.APIKey) + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]any + require.NoError(t, json.NewDecoder(r.Body).Decode(&body)) + tt.handler(t, w, r, body) + })) + defer server.Close() + + origGatewayURL := os.Getenv("NYLAS_DASHBOARD_GATEWAY_URL") + t.Cleanup(func() { setEnvOrUnsetLocal("NYLAS_DASHBOARD_GATEWAY_URL", origGatewayURL) }) + require.NoError(t, os.Setenv("NYLAS_DASHBOARD_GATEWAY_URL", server.URL)) + + client := NewGatewayClient(&mockDPoP{proof: "test-proof"}) + tt.run(t, client) + }) + } +} + +func TestGatewayClientHandlesGraphQLErrorsAndRedirects(t *testing.T) { + t.Run("GraphQL error is surfaced", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + "errors": []map[string]any{ + { + "message": "top-level", + "extensions": map[string]any{ + "message": "more specific", + }, + }, + }, + })) + })) + defer server.Close() + + origGatewayURL := os.Getenv("NYLAS_DASHBOARD_GATEWAY_URL") + t.Cleanup(func() { setEnvOrUnsetLocal("NYLAS_DASHBOARD_GATEWAY_URL", origGatewayURL) }) + require.NoError(t, os.Setenv("NYLAS_DASHBOARD_GATEWAY_URL", server.URL)) + + client := NewGatewayClient(&mockDPoP{proof: "test-proof"}) + apps, err := client.ListApplications(context.Background(), "org-1", "us", "user-token", "org-token") + require.Error(t, err) + assert.Nil(t, apps) + assert.Contains(t, err.Error(), "failed to list applications") + assert.Contains(t, err.Error(), "GraphQL error: more specific") + }) + + t.Run("GraphQL invalid session preserves structured auth error", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + "errors": []map[string]any{ + { + "message": "session expired", + "extensions": map[string]any{ + "code": "INVALID_SESSION", + "message": "Invalid or expired session", + }, + }, + }, + })) + })) + defer server.Close() + + origGatewayURL := os.Getenv("NYLAS_DASHBOARD_GATEWAY_URL") + t.Cleanup(func() { setEnvOrUnsetLocal("NYLAS_DASHBOARD_GATEWAY_URL", origGatewayURL) }) + require.NoError(t, os.Setenv("NYLAS_DASHBOARD_GATEWAY_URL", server.URL)) + + client := NewGatewayClient(&mockDPoP{proof: "test-proof"}) + apps, err := client.ListApplications(context.Background(), "org-1", "us", "user-token", "org-token") + require.Error(t, err) + assert.Nil(t, apps) + assert.ErrorIs(t, err, domain.ErrDashboardSessionExpired) + assert.Contains(t, err.Error(), "INVALID_SESSION") + }) + + t.Run("HTTP 401 GraphQL invalid session preserves structured auth error", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + "errors": []map[string]any{ + { + "message": "INVALID_SESSION", + "extensions": map[string]any{ + "code": "UNAUTHENTICATED", + }, + }, + }, + })) + })) + defer server.Close() + + origGatewayURL := os.Getenv("NYLAS_DASHBOARD_GATEWAY_URL") + t.Cleanup(func() { setEnvOrUnsetLocal("NYLAS_DASHBOARD_GATEWAY_URL", origGatewayURL) }) + require.NoError(t, os.Setenv("NYLAS_DASHBOARD_GATEWAY_URL", server.URL)) + + client := NewGatewayClient(&mockDPoP{proof: "test-proof"}) + apps, err := client.ListApplications(context.Background(), "org-1", "us", "user-token", "org-token") + require.Error(t, err) + assert.Nil(t, apps) + assert.ErrorIs(t, err, domain.ErrDashboardSessionExpired) + assert.Contains(t, err.Error(), "INVALID_SESSION") + assert.Contains(t, err.Error(), "Invalid or expired session") + }) + + t.Run("redirect is not followed", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Location", "https://redirected.example.com") + w.WriteHeader(http.StatusFound) + })) + defer server.Close() + + client := NewGatewayClient(&mockDPoP{proof: "test-proof"}) + _, err := client.doGraphQL(context.Background(), server.URL, "query { ok }", map[string]any{}, "user-token", "org-token") + require.Error(t, err) + assert.Contains(t, err.Error(), "server redirected to https://redirected.example.com") + }) +} + +func TestGatewayURLPrecedence(t *testing.T) { + origGlobal := os.Getenv("NYLAS_DASHBOARD_GATEWAY_URL") + origUS := os.Getenv("NYLAS_DASHBOARD_GATEWAY_US_URL") + origEU := os.Getenv("NYLAS_DASHBOARD_GATEWAY_EU_URL") + t.Cleanup(func() { + setEnvOrUnsetLocal("NYLAS_DASHBOARD_GATEWAY_URL", origGlobal) + setEnvOrUnsetLocal("NYLAS_DASHBOARD_GATEWAY_US_URL", origUS) + setEnvOrUnsetLocal("NYLAS_DASHBOARD_GATEWAY_EU_URL", origEU) + }) + + require.NoError(t, os.Setenv("NYLAS_DASHBOARD_GATEWAY_URL", "https://global.example.com/graphql")) + require.NoError(t, os.Setenv("NYLAS_DASHBOARD_GATEWAY_US_URL", "https://us.example.com/graphql")) + require.NoError(t, os.Setenv("NYLAS_DASHBOARD_GATEWAY_EU_URL", "https://eu.example.com/graphql")) + + assert.Equal(t, "https://us.example.com/graphql", gatewayURL("us")) + assert.Equal(t, "https://eu.example.com/graphql", gatewayURL("eu")) + + require.NoError(t, os.Unsetenv("NYLAS_DASHBOARD_GATEWAY_US_URL")) + require.NoError(t, os.Unsetenv("NYLAS_DASHBOARD_GATEWAY_EU_URL")) + + assert.Equal(t, "https://global.example.com/graphql", gatewayURL("us")) + assert.Equal(t, "https://global.example.com/graphql", gatewayURL("eu")) +} + +func TestFormatGraphQLError(t *testing.T) { + assert.Equal(t, "specific", formatGraphQLError(graphQLError{ + Message: "generic", + Extensions: &graphQLExtensions{ + Message: "specific", + }, + })) + assert.Equal(t, "generic", formatGraphQLError(graphQLError{Message: "generic"})) +} + +func setEnvOrUnsetLocal(key, value string) { + if value == "" { + _ = os.Unsetenv(key) + return + } + _ = os.Setenv(key, value) +} diff --git a/internal/adapters/dashboard/http.go b/internal/adapters/dashboard/http.go index 0db569e..1819967 100644 --- a/internal/adapters/dashboard/http.go +++ b/internal/adapters/dashboard/http.go @@ -7,6 +7,8 @@ import ( "fmt" "io" "net/http" + + "github.com/nylas/cli/internal/domain" ) const ( @@ -188,20 +190,6 @@ func (c *AccountClient) doGet(ctx context.Context, path string, extraHeaders map return nil } -// DashboardAPIError represents an error from the dashboard API. -// It carries the status code and server message for debugging. -type DashboardAPIError struct { - StatusCode int - ServerMsg string -} - -func (e *DashboardAPIError) Error() string { - if e.ServerMsg != "" { - return fmt.Sprintf("dashboard API error (HTTP %d): %s", e.StatusCode, e.ServerMsg) - } - return fmt.Sprintf("dashboard API error (HTTP %d)", e.StatusCode) -} - // parseErrorResponse extracts a user-friendly error from an HTTP error response. // The dashboard-account error envelope is: // @@ -213,18 +201,17 @@ func parseErrorResponse(statusCode int, body []byte) error { Message string `json:"message"` } `json:"error"` } + code := "" msg := "" - if json.Unmarshal(body, &errResp) == nil && errResp.Error.Message != "" { + if json.Unmarshal(body, &errResp) == nil { + code = errResp.Error.Code msg = errResp.Error.Message - if errResp.Error.Code != "" { - msg = errResp.Error.Code + ": " + msg - } } - if msg == "" { + if msg == "" && code == "" { msg = string(body) if len(msg) > 200 { msg = msg[:200] } } - return &DashboardAPIError{StatusCode: statusCode, ServerMsg: msg} + return domain.NewDashboardAPIError(statusCode, code, msg) } diff --git a/internal/adapters/dashboard/http_test.go b/internal/adapters/dashboard/http_test.go index c72f6e0..bb5b06c 100644 --- a/internal/adapters/dashboard/http_test.go +++ b/internal/adapters/dashboard/http_test.go @@ -3,10 +3,12 @@ package dashboard import ( "context" "encoding/json" + "errors" "net/http" "net/http/httptest" "testing" + "github.com/nylas/cli/internal/domain" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -84,6 +86,7 @@ func TestParseErrorResponse(t *testing.T) { statusCode int body string wantMsg string + wantErrIs error }{ { name: "parses error with code and message", @@ -109,14 +112,28 @@ func TestParseErrorResponse(t *testing.T) { body: string(make([]byte, 300)), wantMsg: "", // truncated to 200 chars }, + { + name: "classifies invalid session", + statusCode: 401, + body: `{"error":{"code":"INVALID_SESSION","message":"Invalid or expired session"}}`, + wantMsg: "INVALID_SESSION: Invalid or expired session", + wantErrIs: domain.ErrDashboardSessionExpired, + }, + { + name: "classifies code-only invalid session", + statusCode: 401, + body: `{"error":{"code":"INVALID_SESSION"}}`, + wantMsg: "INVALID_SESSION", + wantErrIs: domain.ErrDashboardSessionExpired, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := parseErrorResponse(tt.statusCode, []byte(tt.body)) - dashErr, ok := err.(*DashboardAPIError) + dashErr, ok := err.(*domain.DashboardAPIError) if !ok { - t.Fatalf("expected *DashboardAPIError, got %T", err) + t.Fatalf("expected *domain.DashboardAPIError, got %T", err) } if dashErr.StatusCode != tt.statusCode { t.Errorf("StatusCode = %d, want %d", dashErr.StatusCode, tt.statusCode) @@ -124,6 +141,9 @@ func TestParseErrorResponse(t *testing.T) { if tt.wantMsg != "" && dashErr.ServerMsg != tt.wantMsg { t.Errorf("ServerMsg = %q, want %q", dashErr.ServerMsg, tt.wantMsg) } + if tt.wantErrIs != nil && !errors.Is(err, tt.wantErrIs) { + t.Fatalf("expected errors.Is(%v), got %v", tt.wantErrIs, err) + } }) } } @@ -179,17 +199,17 @@ func TestUnwrapEnvelope(t *testing.T) { func TestDashboardAPIError_Error(t *testing.T) { tests := []struct { name string - err DashboardAPIError + err domain.DashboardAPIError wantStr string }{ { name: "with message", - err: DashboardAPIError{StatusCode: 400, ServerMsg: "bad request"}, + err: domain.DashboardAPIError{StatusCode: 400, ServerMsg: "bad request"}, wantStr: "dashboard API error (HTTP 400): bad request", }, { name: "without message", - err: DashboardAPIError{StatusCode: 500}, + err: domain.DashboardAPIError{StatusCode: 500}, wantStr: "dashboard API error (HTTP 500)", }, } diff --git a/internal/app/dashboard/app_service.go b/internal/app/dashboard/app_service.go index 59b198a..59491ac 100644 --- a/internal/app/dashboard/app_service.go +++ b/internal/app/dashboard/app_service.go @@ -37,8 +37,9 @@ func (s *AppService) ListApplications(ctx context.Context, orgPublicID, regionFi // Query both regions in parallel type result struct { - apps []domain.GatewayApplication - err error + region string + apps []domain.GatewayApplication + err error } var wg sync.WaitGroup @@ -50,27 +51,37 @@ func (s *AppService) ListApplications(ctx context.Context, orgPublicID, regionFi go func(idx int, r string) { defer wg.Done() apps, err := s.gateway.ListApplications(ctx, orgPublicID, r, userToken, orgToken) - results[idx] = result{apps: apps, err: err} + results[idx] = result{region: r, apps: apps, err: err} }(i, region) } wg.Wait() var allApps []domain.GatewayApplication - var errs []error + failures := make(map[string]error) for _, r := range results { if r.err != nil { - errs = append(errs, r.err) + failures[r.region] = r.err continue } allApps = append(allApps, r.apps...) } // If both failed, return the first error - if len(errs) == len(regions) { - return nil, fmt.Errorf("failed to list applications: %w", errs[0]) + if len(failures) == len(regions) { + for _, region := range regions { + if err := failures[region]; err != nil { + return nil, fmt.Errorf("failed to list applications: %w", err) + } + } } allApps = deduplicateApps(allApps) + if len(failures) > 0 { + return allApps, &domain.DashboardPartialResultError{ + Operation: "application list", + Failures: failures, + } + } return allApps, nil } diff --git a/internal/app/dashboard/app_service_test.go b/internal/app/dashboard/app_service_test.go new file mode 100644 index 0000000..5af88b1 --- /dev/null +++ b/internal/app/dashboard/app_service_test.go @@ -0,0 +1,201 @@ +package dashboard + +import ( + "context" + "errors" + "sync" + "testing" + + dashboardadapter "github.com/nylas/cli/internal/adapters/dashboard" + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAppServiceListApplications(t *testing.T) { + t.Parallel() + + t.Run("region filter forwards single request", func(t *testing.T) { + t.Parallel() + + store := newMemSecretStore() + seedTokens(store, "user-token", "org-token") + + mock := &dashboardadapter.MockGatewayClient{ + ListApplicationsFn: func(_ context.Context, orgPublicID, region, userToken, orgToken string) ([]domain.GatewayApplication, error) { + assert.Equal(t, "org-1", orgPublicID) + assert.Equal(t, "eu", region) + assert.Equal(t, "user-token", userToken) + assert.Equal(t, "org-token", orgToken) + return []domain.GatewayApplication{{ApplicationID: "app-eu", Region: "eu"}}, nil + }, + } + + svc := NewAppService(mock, store) + apps, err := svc.ListApplications(context.Background(), "org-1", "eu") + require.NoError(t, err) + require.Len(t, apps, 1) + assert.Equal(t, "app-eu", apps[0].ApplicationID) + }) + + t.Run("merges regions, tolerates one failure, and deduplicates", func(t *testing.T) { + t.Parallel() + + store := newMemSecretStore() + seedTokens(store, "user-token", "org-token") + + mock := &dashboardadapter.MockGatewayClient{ + ListApplicationsFn: func(_ context.Context, _ string, region, _, _ string) ([]domain.GatewayApplication, error) { + switch region { + case "us": + return []domain.GatewayApplication{ + {ApplicationID: "shared-app", Region: "us"}, + {ApplicationID: "us-only", Region: "us"}, + }, nil + case "eu": + return []domain.GatewayApplication{ + {ApplicationID: "shared-app", Region: "eu"}, + }, errors.New("eu unavailable") + default: + return nil, nil + } + }, + } + + svc := NewAppService(mock, store) + apps, err := svc.ListApplications(context.Background(), "org-1", "") + require.Error(t, err) + var partialErr *domain.DashboardPartialResultError + require.ErrorAs(t, err, &partialErr) + require.Len(t, apps, 2) + assert.ElementsMatch(t, []string{"shared-app", "us-only"}, []string{apps[0].ApplicationID, apps[1].ApplicationID}) + }) + + t.Run("returns first error when both regions fail", func(t *testing.T) { + t.Parallel() + + store := newMemSecretStore() + seedTokens(store, "user-token", "org-token") + + mock := &dashboardadapter.MockGatewayClient{ + ListApplicationsFn: func(_ context.Context, _ string, region, _, _ string) ([]domain.GatewayApplication, error) { + return nil, errors.New(region + " failed") + }, + } + + svc := NewAppService(mock, store) + apps, err := svc.ListApplications(context.Background(), "org-1", "") + require.Error(t, err) + assert.Nil(t, apps) + assert.Contains(t, err.Error(), "failed to list applications") + }) +} + +func TestAppServiceDeduplicateApps(t *testing.T) { + t.Parallel() + + input := []domain.GatewayApplication{ + {ApplicationID: "app-1", Region: "us"}, + {ApplicationID: "app-1", Region: "eu"}, + {Region: "us", Environment: "sandbox", Branding: &domain.GatewayApplicationBrand{Name: "No ID"}}, + {Region: "us", Environment: "sandbox", Branding: &domain.GatewayApplicationBrand{Name: "No ID"}}, + } + + got := deduplicateApps(input) + require.Len(t, got, 2) + assert.Equal(t, "app-1", got[0].ApplicationID) + assert.Equal(t, "No ID", got[1].Branding.Name) +} + +func TestAppServiceManagementCalls(t *testing.T) { + t.Parallel() + + store := newMemSecretStore() + seedTokens(store, "user-token", "org-token") + + var mu sync.Mutex + calls := make([]string, 0, 3) + mock := &dashboardadapter.MockGatewayClient{ + CreateApplicationFn: func(_ context.Context, orgPublicID, region, name, userToken, orgToken string) (*domain.GatewayCreatedApplication, error) { + mu.Lock() + calls = append(calls, "createApp:"+orgPublicID+":"+region+":"+name+":"+userToken+":"+orgToken) + mu.Unlock() + return &domain.GatewayCreatedApplication{ApplicationID: "app-1"}, nil + }, + ListAPIKeysFn: func(_ context.Context, appID, region, userToken, orgToken string) ([]domain.GatewayAPIKey, error) { + mu.Lock() + calls = append(calls, "listKeys:"+appID+":"+region+":"+userToken+":"+orgToken) + mu.Unlock() + return []domain.GatewayAPIKey{{ID: "key-1"}}, nil + }, + CreateAPIKeyFn: func(_ context.Context, appID, region, name string, expiresInDays int, userToken, orgToken string) (*domain.GatewayCreatedAPIKey, error) { + mu.Lock() + calls = append(calls, "createKey:"+appID+":"+region+":"+name+":"+userToken+":"+orgToken) + mu.Unlock() + return &domain.GatewayCreatedAPIKey{ID: "key-2"}, nil + }, + } + + svc := NewAppService(mock, store) + + app, err := svc.CreateApplication(context.Background(), "org-1", "us", "Primary") + require.NoError(t, err) + assert.Equal(t, "app-1", app.ApplicationID) + + keys, err := svc.ListAPIKeys(context.Background(), "app-1", "us") + require.NoError(t, err) + require.Len(t, keys, 1) + + key, err := svc.CreateAPIKey(context.Background(), "app-1", "us", "Nightly", 30) + require.NoError(t, err) + assert.Equal(t, "key-2", key.ID) + + assert.Contains(t, calls, "createApp:org-1:us:Primary:user-token:org-token") + assert.Contains(t, calls, "listKeys:app-1:us:user-token:org-token") + assert.Contains(t, calls, "createKey:app-1:us:Nightly:user-token:org-token") +} + +func TestAppServiceReturnsNotLoggedInWhenTokensMissing(t *testing.T) { + t.Parallel() + + store := newMemSecretStore() + svc := NewAppService(&dashboardadapter.MockGatewayClient{}, store) + + _, err := svc.CreateApplication(context.Background(), "org-1", "us", "Primary") + require.Error(t, err) + assert.ErrorIs(t, err, domain.ErrDashboardNotLoggedIn) +} + +func TestLoadDashboardTokensPropagatesSecretStoreFailures(t *testing.T) { + t.Parallel() + + t.Run("user token load failure is returned", func(t *testing.T) { + t.Parallel() + + store := &failingSecretStore{ + memSecretStore: newMemSecretStore(), + failGetKey: ports.KeyDashboardUserToken, + } + + _, _, err := loadDashboardTokens(store) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to load dashboard user token") + assert.ErrorIs(t, err, domain.ErrSecretStoreFailed) + }) + + t.Run("org token load failure is returned", func(t *testing.T) { + t.Parallel() + + store := &failingSecretStore{ + memSecretStore: newMemSecretStore(), + failGetKey: ports.KeyDashboardOrgToken, + } + seedTokens(store, "user-token", "") + + _, _, err := loadDashboardTokens(store) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to load dashboard organization token") + assert.ErrorIs(t, err, domain.ErrSecretStoreFailed) + }) +} diff --git a/internal/app/dashboard/auth_service.go b/internal/app/dashboard/auth_service.go index 597f1ee..cc2fa88 100644 --- a/internal/app/dashboard/auth_service.go +++ b/internal/app/dashboard/auth_service.go @@ -4,6 +4,7 @@ package dashboard import ( "context" + "errors" "fmt" "github.com/nylas/cli/internal/domain" @@ -16,6 +17,32 @@ type AuthService struct { secrets ports.SecretStore } +var ( + dashboardSessionStateKeys = []string{ + ports.KeyDashboardUserToken, + ports.KeyDashboardOrgToken, + ports.KeyDashboardUserPublicID, + ports.KeyDashboardOrgPublicID, + ports.KeyDashboardAppID, + ports.KeyDashboardAppRegion, + } + dashboardRefreshStateKeys = []string{ + ports.KeyDashboardUserToken, + ports.KeyDashboardOrgToken, + } + dashboardSwitchOrgStateKeys = []string{ + ports.KeyDashboardOrgToken, + ports.KeyDashboardOrgPublicID, + ports.KeyDashboardAppID, + ports.KeyDashboardAppRegion, + } +) + +type secretSnapshot struct { + value string + present bool +} + // NewAuthService creates a new dashboard auth service. func NewAuthService(account ports.DashboardAccountClient, secrets ports.SecretStore) *AuthService { return &AuthService{ @@ -85,21 +112,8 @@ func (s *AuthService) Refresh(ctx context.Context) error { return err } - resp, err := s.account.Refresh(ctx, userToken, orgToken) - if err != nil { - return err - } - - if err := s.secrets.Set(ports.KeyDashboardUserToken, resp.UserToken); err != nil { - return fmt.Errorf("failed to store refreshed user token: %w", err) - } - if resp.OrgToken != "" { - if err := s.secrets.Set(ports.KeyDashboardOrgToken, resp.OrgToken); err != nil { - return fmt.Errorf("failed to store refreshed org token: %w", err) - } - } - - return nil + _, _, err = s.refreshTokens(ctx, userToken, orgToken) + return err } // Logout invalidates the session and clears local tokens. @@ -112,8 +126,7 @@ func (s *AuthService) Logout(ctx context.Context) error { } // Always clear local state - s.clearTokens() - return nil + return s.clearTokens() } // SSOStart initiates an SSO device authorization flow. @@ -169,6 +182,17 @@ func (s *AuthService) GetCurrentSession(ctx context.Context) (*domain.DashboardS if err != nil { return nil, err } + + session, err := s.account.GetCurrentSession(ctx, userToken, orgToken) + if !errors.Is(err, domain.ErrDashboardSessionExpired) { + return session, err + } + + userToken, orgToken, err = s.refreshTokens(ctx, userToken, orgToken) + if err != nil { + return nil, err + } + return s.account.GetCurrentSession(ctx, userToken, orgToken) } @@ -180,25 +204,29 @@ func (s *AuthService) SwitchOrg(ctx context.Context, orgPublicID string) (*domai } resp, err := s.account.SwitchOrg(ctx, orgPublicID, userToken, orgToken) + if errors.Is(err, domain.ErrDashboardSessionExpired) { + userToken, orgToken, err = s.refreshTokens(ctx, userToken, orgToken) + if err != nil { + return nil, err + } + resp, err = s.account.SwitchOrg(ctx, orgPublicID, userToken, orgToken) + } if err != nil { return nil, err } - // Store the new org token and org ID - if resp.OrgToken != "" { - if err := s.secrets.Set(ports.KeyDashboardOrgToken, resp.OrgToken); err != nil { - return nil, fmt.Errorf("failed to store org token: %w", err) - } - } + nextOrgID := orgPublicID if resp.Org.PublicID != "" { - if err := s.secrets.Set(ports.KeyDashboardOrgPublicID, resp.Org.PublicID); err != nil { - return nil, fmt.Errorf("failed to store org ID: %w", err) - } + nextOrgID = resp.Org.PublicID + } + if err := s.replaceSecretValues(dashboardSwitchOrgStateKeys, map[string]*string{ + ports.KeyDashboardOrgToken: stringPtrOrNil(resp.OrgToken), + ports.KeyDashboardOrgPublicID: stringPtrOrNil(nextOrgID), + ports.KeyDashboardAppID: nil, + ports.KeyDashboardAppRegion: nil, + }); err != nil { + return nil, fmt.Errorf("failed to persist organization switch: %w", err) } - - // Clear active app since it belongs to the previous org - _ = s.secrets.Delete(ports.KeyDashboardAppID) - _ = s.secrets.Delete(ports.KeyDashboardAppRegion) return resp, nil } @@ -221,25 +249,19 @@ func (s *AuthService) SyncSessionOrg(ctx context.Context) error { // storeTokens persists auth tokens and user/org identifiers. func (s *AuthService) storeTokens(resp *domain.DashboardAuthResponse) error { - if err := s.secrets.Set(ports.KeyDashboardUserToken, resp.UserToken); err != nil { - return err - } - if resp.OrgToken != "" { - if err := s.secrets.Set(ports.KeyDashboardOrgToken, resp.OrgToken); err != nil { - return err - } - } - if resp.User.PublicID != "" { - if err := s.secrets.Set(ports.KeyDashboardUserPublicID, resp.User.PublicID); err != nil { - return err - } - } + orgPublicID := "" if len(resp.Organizations) == 1 { - if err := s.secrets.Set(ports.KeyDashboardOrgPublicID, resp.Organizations[0].PublicID); err != nil { - return err - } + orgPublicID = resp.Organizations[0].PublicID } - return nil + + return s.replaceSecretValues(dashboardSessionStateKeys, map[string]*string{ + ports.KeyDashboardUserToken: stringPtrOrNil(resp.UserToken), + ports.KeyDashboardOrgToken: stringPtrOrNil(resp.OrgToken), + ports.KeyDashboardUserPublicID: stringPtrOrNil(resp.User.PublicID), + ports.KeyDashboardOrgPublicID: stringPtrOrNil(orgPublicID), + ports.KeyDashboardAppID: nil, + ports.KeyDashboardAppRegion: nil, + }) } // SetActiveOrg updates the active organization. @@ -249,16 +271,113 @@ func (s *AuthService) SetActiveOrg(orgPublicID string) error { // clearTokens removes all dashboard auth data from the keyring, // including the active app selection to prevent stale state after re-login. -func (s *AuthService) clearTokens() { - _ = s.secrets.Delete(ports.KeyDashboardUserToken) - _ = s.secrets.Delete(ports.KeyDashboardOrgToken) - _ = s.secrets.Delete(ports.KeyDashboardUserPublicID) - _ = s.secrets.Delete(ports.KeyDashboardOrgPublicID) - _ = s.secrets.Delete(ports.KeyDashboardAppID) - _ = s.secrets.Delete(ports.KeyDashboardAppRegion) +func (s *AuthService) clearTokens() error { + var errs []error + for _, key := range dashboardSessionStateKeys { + if err := s.secrets.Delete(key); err != nil { + errs = append(errs, fmt.Errorf("failed to clear %s: %w", key, err)) + } + } + return errors.Join(errs...) } // loadTokens retrieves the stored tokens. func (s *AuthService) loadTokens() (userToken, orgToken string, err error) { return loadDashboardTokens(s.secrets) } + +func (s *AuthService) refreshTokens(ctx context.Context, userToken, orgToken string) (string, string, error) { + resp, err := s.account.Refresh(ctx, userToken, orgToken) + if err != nil { + return "", "", err + } + + updates := map[string]*string{ + ports.KeyDashboardUserToken: stringPtrOrNil(resp.UserToken), + } + if resp.OrgToken != "" { + updates[ports.KeyDashboardOrgToken] = stringPtrOrNil(resp.OrgToken) + } + if err := s.replaceSecretValues(dashboardRefreshStateKeys, updates); err != nil { + return "", "", fmt.Errorf("failed to store refreshed credentials: %w", err) + } + userToken = resp.UserToken + if resp.OrgToken != "" { + orgToken = resp.OrgToken + } + + return userToken, orgToken, nil +} + +func (s *AuthService) replaceSecretValues(keys []string, updates map[string]*string) error { + snapshot, err := s.snapshotSecretValues(keys) + if err != nil { + return err + } + if err := s.applySecretValues(keys, updates); err != nil { + if rollbackErr := s.restoreSecretValues(keys, snapshot); rollbackErr != nil { + return errors.Join(err, fmt.Errorf("failed to rollback dashboard session state: %w", rollbackErr)) + } + return err + } + return nil +} + +func (s *AuthService) snapshotSecretValues(keys []string) (map[string]secretSnapshot, error) { + snapshot := make(map[string]secretSnapshot, len(keys)) + for _, key := range keys { + value, err := s.secrets.Get(key) + switch { + case err == nil: + snapshot[key] = secretSnapshot{value: value, present: true} + case errors.Is(err, domain.ErrSecretNotFound): + snapshot[key] = secretSnapshot{} + default: + return nil, fmt.Errorf("failed to read %s: %w", key, err) + } + } + return snapshot, nil +} + +func (s *AuthService) applySecretValues(keys []string, updates map[string]*string) error { + for _, key := range keys { + value, ok := updates[key] + if !ok { + continue + } + if value == nil { + if err := s.secrets.Delete(key); err != nil { + return fmt.Errorf("failed to clear %s: %w", key, err) + } + continue + } + if err := s.secrets.Set(key, *value); err != nil { + return fmt.Errorf("failed to store %s: %w", key, err) + } + } + return nil +} + +func (s *AuthService) restoreSecretValues(keys []string, snapshot map[string]secretSnapshot) error { + var errs []error + for _, key := range keys { + prev := snapshot[key] + var err error + if prev.present { + err = s.secrets.Set(key, prev.value) + } else { + err = s.secrets.Delete(key) + } + if err != nil { + errs = append(errs, fmt.Errorf("%s: %w", key, err)) + } + } + return errors.Join(errs...) +} + +func stringPtrOrNil(value string) *string { + if value == "" { + return nil + } + return &value +} diff --git a/internal/app/dashboard/auth_service_extra_test.go b/internal/app/dashboard/auth_service_extra_test.go new file mode 100644 index 0000000..d083263 --- /dev/null +++ b/internal/app/dashboard/auth_service_extra_test.go @@ -0,0 +1,267 @@ +package dashboard + +import ( + "context" + "errors" + "testing" + + dashboardadapter "github.com/nylas/cli/internal/adapters/dashboard" + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAuthServiceLoginStoresTokensAndDefersMFA(t *testing.T) { + t.Parallel() + + t.Run("success stores tokens", func(t *testing.T) { + t.Parallel() + + store := newMemSecretStore() + require.NoError(t, store.Set(ports.KeyDashboardAppID, "stale-app")) + require.NoError(t, store.Set(ports.KeyDashboardAppRegion, "eu")) + require.NoError(t, store.Set(ports.KeyDashboardOrgPublicID, "stale-org")) + mock := &dashboardadapter.MockAccountClient{ + LoginFn: func(_ context.Context, email, password, orgPublicID string) (*domain.DashboardAuthResponse, *domain.DashboardMFARequired, error) { + assert.Equal(t, "user@example.com", email) + assert.Equal(t, "secret", password) + assert.Equal(t, "org-1", orgPublicID) + return &domain.DashboardAuthResponse{ + UserToken: "user-token", + OrgToken: "org-token", + User: domain.DashboardUser{PublicID: "user-1"}, + }, nil, nil + }, + } + + svc := NewAuthService(mock, store) + auth, mfa, err := svc.Login(context.Background(), "user@example.com", "secret", "org-1") + require.NoError(t, err) + assert.NotNil(t, auth) + assert.Nil(t, mfa) + + storedUserToken, _ := store.Get(ports.KeyDashboardUserToken) + assert.Equal(t, "user-token", storedUserToken) + storedUserID, _ := store.Get(ports.KeyDashboardUserPublicID) + assert.Equal(t, "user-1", storedUserID) + storedOrgID, _ := store.Get(ports.KeyDashboardOrgPublicID) + assert.Empty(t, storedOrgID) + appID, _ := store.Get(ports.KeyDashboardAppID) + assert.Empty(t, appID) + appRegion, _ := store.Get(ports.KeyDashboardAppRegion) + assert.Empty(t, appRegion) + }) + + t.Run("mfa response does not store tokens", func(t *testing.T) { + t.Parallel() + + store := newMemSecretStore() + mock := &dashboardadapter.MockAccountClient{ + LoginFn: func(_ context.Context, _, _, _ string) (*domain.DashboardAuthResponse, *domain.DashboardMFARequired, error) { + return nil, &domain.DashboardMFARequired{ + User: domain.DashboardUser{PublicID: "user-1"}, + }, nil + }, + } + + svc := NewAuthService(mock, store) + auth, mfa, err := svc.Login(context.Background(), "user@example.com", "secret", "") + require.NoError(t, err) + assert.Nil(t, auth) + assert.NotNil(t, mfa) + + storedUserToken, _ := store.Get(ports.KeyDashboardUserToken) + assert.Empty(t, storedUserToken) + }) +} + +func TestAuthServiceStoresTokensForVerificationAndMFA(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + run func(t *testing.T, svc *AuthService, store ports.SecretStore) + mock *dashboardadapter.MockAccountClient + }{ + { + name: "verify email code", + mock: &dashboardadapter.MockAccountClient{ + VerifyEmailCodeFn: func(_ context.Context, email, code, region string) (*domain.DashboardAuthResponse, error) { + assert.Equal(t, "user@example.com", email) + assert.Equal(t, "123456", code) + assert.Equal(t, "us", region) + return &domain.DashboardAuthResponse{ + UserToken: "user-token", + OrgToken: "org-token", + User: domain.DashboardUser{PublicID: "user-1"}, + }, nil + }, + }, + run: func(t *testing.T, svc *AuthService, store ports.SecretStore) { + resp, err := svc.VerifyEmailCode(context.Background(), "user@example.com", "123456", "us") + require.NoError(t, err) + assert.Equal(t, "user-token", resp.UserToken) + }, + }, + { + name: "complete mfa", + mock: &dashboardadapter.MockAccountClient{ + LoginMFAFn: func(_ context.Context, userPublicID, code, orgPublicID string) (*domain.DashboardAuthResponse, error) { + assert.Equal(t, "user-1", userPublicID) + assert.Equal(t, "654321", code) + assert.Equal(t, "org-1", orgPublicID) + return &domain.DashboardAuthResponse{ + UserToken: "user-token", + OrgToken: "org-token", + User: domain.DashboardUser{PublicID: "user-1"}, + }, nil + }, + }, + run: func(t *testing.T, svc *AuthService, store ports.SecretStore) { + resp, err := svc.CompleteMFA(context.Background(), "user-1", "654321", "org-1") + require.NoError(t, err) + assert.Equal(t, "org-token", resp.OrgToken) + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + store := newMemSecretStore() + svc := NewAuthService(tt.mock, store) + tt.run(t, svc, store) + + storedUserToken, _ := store.Get(ports.KeyDashboardUserToken) + assert.Equal(t, "user-token", storedUserToken) + storedOrgToken, _ := store.Get(ports.KeyDashboardOrgToken) + assert.Equal(t, "org-token", storedOrgToken) + }) + } +} + +func TestAuthServiceLogoutStatusAndSSOPoll(t *testing.T) { + t.Parallel() + + t.Run("logout clears local state even if server logout fails", func(t *testing.T) { + t.Parallel() + + store := newMemSecretStore() + seedTokens(store, "user-token", "org-token") + require.NoError(t, store.Set(ports.KeyDashboardUserPublicID, "user-1")) + require.NoError(t, store.Set(ports.KeyDashboardOrgPublicID, "org-1")) + require.NoError(t, store.Set(ports.KeyDashboardAppID, "app-1")) + require.NoError(t, store.Set(ports.KeyDashboardAppRegion, "us")) + + mock := &dashboardadapter.MockAccountClient{ + LogoutFn: func(_ context.Context, userToken, orgToken string) error { + assert.Equal(t, "user-token", userToken) + assert.Equal(t, "org-token", orgToken) + return errors.New("network down") + }, + } + + svc := NewAuthService(mock, store) + require.NoError(t, svc.Logout(context.Background())) + + status := svc.GetStatus() + assert.False(t, status.LoggedIn) + assert.False(t, status.HasOrgToken) + appID, _ := store.Get(ports.KeyDashboardAppID) + assert.Empty(t, appID) + }) + + t.Run("logout returns local deletion failures", func(t *testing.T) { + t.Parallel() + + store := &failingSecretStore{ + memSecretStore: newMemSecretStore(), + failDeleteKey: ports.KeyDashboardOrgToken, + } + seedTokens(store, "user-token", "org-token") + + svc := NewAuthService(&dashboardadapter.MockAccountClient{ + LogoutFn: func(_ context.Context, _, _ string) error { return nil }, + }, store) + err := svc.Logout(context.Background()) + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to clear "+ports.KeyDashboardOrgToken) + + userToken, getErr := store.Get(ports.KeyDashboardUserToken) + require.NoError(t, getErr) + assert.Empty(t, userToken) + orgToken, getErr := store.Get(ports.KeyDashboardOrgToken) + require.NoError(t, getErr) + assert.Equal(t, "org-token", orgToken) + }) + + t.Run("SSOPoll stores credentials only on complete", func(t *testing.T) { + t.Parallel() + + store := newMemSecretStore() + mock := &dashboardadapter.MockAccountClient{ + SSOPollFn: func(_ context.Context, flowID, orgPublicID string) (*domain.DashboardSSOPollResponse, error) { + assert.Equal(t, "flow-1", flowID) + assert.Equal(t, "org-1", orgPublicID) + return &domain.DashboardSSOPollResponse{ + Status: domain.SSOStatusComplete, + Auth: &domain.DashboardAuthResponse{ + UserToken: "user-token", + OrgToken: "org-token", + User: domain.DashboardUser{PublicID: "user-1"}, + }, + }, nil + }, + } + + svc := NewAuthService(mock, store) + resp, err := svc.SSOPoll(context.Background(), "flow-1", "org-1") + require.NoError(t, err) + require.NotNil(t, resp.Auth) + assert.True(t, svc.IsLoggedIn()) + status := svc.GetStatus() + assert.Equal(t, "user-1", status.UserID) + assert.True(t, status.HasOrgToken) + }) +} + +func TestAuthServiceStoreTokensRollsBackOnFailure(t *testing.T) { + t.Parallel() + + store := &failingSecretStore{ + memSecretStore: newMemSecretStore(), + failSetKey: ports.KeyDashboardUserPublicID, + } + require.NoError(t, store.Set(ports.KeyDashboardUserToken, "stale-user")) + require.NoError(t, store.Set(ports.KeyDashboardOrgToken, "stale-org-token")) + require.NoError(t, store.Set(ports.KeyDashboardOrgPublicID, "stale-org-id")) + require.NoError(t, store.Set(ports.KeyDashboardAppID, "stale-app")) + require.NoError(t, store.Set(ports.KeyDashboardAppRegion, "eu")) + + svc := NewAuthService(&dashboardadapter.MockAccountClient{}, store) + err := svc.storeTokens(&domain.DashboardAuthResponse{ + UserToken: "user-new", + OrgToken: "org-new", + User: domain.DashboardUser{PublicID: "user-new"}, + Organizations: []domain.DashboardOrganization{ + {PublicID: "org-1"}, + }, + }) + + require.Error(t, err) + + storedUserToken, _ := store.Get(ports.KeyDashboardUserToken) + assert.Equal(t, "stale-user", storedUserToken) + storedOrgToken, _ := store.Get(ports.KeyDashboardOrgToken) + assert.Equal(t, "stale-org-token", storedOrgToken) + storedOrgID, _ := store.Get(ports.KeyDashboardOrgPublicID) + assert.Equal(t, "stale-org-id", storedOrgID) + appID, _ := store.Get(ports.KeyDashboardAppID) + assert.Equal(t, "stale-app", appID) + appRegion, _ := store.Get(ports.KeyDashboardAppRegion) + assert.Equal(t, "eu", appRegion) +} diff --git a/internal/app/dashboard/auth_service_test.go b/internal/app/dashboard/auth_service_test.go index 076991e..366d859 100644 --- a/internal/app/dashboard/auth_service_test.go +++ b/internal/app/dashboard/auth_service_test.go @@ -3,6 +3,7 @@ package dashboard import ( "context" "errors" + "fmt" "testing" "github.com/nylas/cli/internal/domain" @@ -47,7 +48,9 @@ func (m *memSecretStore) Name() string { return "mem" } type failingSecretStore struct { *memSecretStore - failSetKey string + failSetKey string + failGetKey string + failDeleteKey string } func (f *failingSecretStore) Set(key, value string) error { @@ -57,6 +60,20 @@ func (f *failingSecretStore) Set(key, value string) error { return f.memSecretStore.Set(key, value) } +func (f *failingSecretStore) Get(key string) (string, error) { + if key == f.failGetKey { + return "", fmt.Errorf("%w: get failed", domain.ErrSecretStoreFailed) + } + return f.memSecretStore.Get(key) +} + +func (f *failingSecretStore) Delete(key string) error { + if key == f.failDeleteKey { + return errors.New("delete failed") + } + return f.memSecretStore.Delete(key) +} + // seedTokens pre-populates userToken (and optionally orgToken) so that // loadTokens() succeeds without going through a full Login flow. func seedTokens(s ports.SecretStore, userToken, orgToken string) { @@ -170,6 +187,72 @@ func TestAuthService_GetCurrentSession(t *testing.T) { } } +func TestAuthService_GetCurrentSessionRefreshesExpiredSession(t *testing.T) { + t.Parallel() + + store := newMemSecretStore() + seedTokens(store, "user-old", "org-old") + + sessionResp := &domain.DashboardSessionResponse{ + User: domain.DashboardUser{PublicID: "user-1"}, + CurrentOrg: "org-1", + } + + var calls []string + mock := &dashboardadapter.MockAccountClient{ + GetCurrentSessionFn: func(_ context.Context, userToken, orgToken string) (*domain.DashboardSessionResponse, error) { + calls = append(calls, userToken+"/"+orgToken) + if len(calls) == 1 { + return nil, domain.NewDashboardAPIError(401, "INVALID_SESSION", "Invalid or expired session") + } + return sessionResp, nil + }, + RefreshFn: func(_ context.Context, userToken, orgToken string) (*domain.DashboardRefreshResponse, error) { + assert.Equal(t, "user-old", userToken) + assert.Equal(t, "org-old", orgToken) + return &domain.DashboardRefreshResponse{ + UserToken: "user-new", + OrgToken: "org-new", + }, nil + }, + } + + svc := NewAuthService(mock, store) + got, err := svc.GetCurrentSession(context.Background()) + + require.NoError(t, err) + assert.Equal(t, sessionResp, got) + assert.Equal(t, []string{"user-old/org-old", "user-new/org-new"}, calls) + + storedUserToken, _ := store.Get(ports.KeyDashboardUserToken) + assert.Equal(t, "user-new", storedUserToken) + storedOrgToken, _ := store.Get(ports.KeyDashboardOrgToken) + assert.Equal(t, "org-new", storedOrgToken) +} + +func TestAuthService_GetCurrentSessionReturnsRefreshError(t *testing.T) { + t.Parallel() + + store := newMemSecretStore() + seedTokens(store, "user-old", "org-old") + + mock := &dashboardadapter.MockAccountClient{ + GetCurrentSessionFn: func(_ context.Context, _, _ string) (*domain.DashboardSessionResponse, error) { + return nil, domain.NewDashboardAPIError(401, "INVALID_SESSION", "Invalid or expired session") + }, + RefreshFn: func(_ context.Context, _, _ string) (*domain.DashboardRefreshResponse, error) { + return nil, errors.New("refresh backend unavailable") + }, + } + + svc := NewAuthService(mock, store) + got, err := svc.GetCurrentSession(context.Background()) + + require.Error(t, err) + assert.Nil(t, got) + assert.EqualError(t, err, "refresh backend unavailable") +} + // --------------------------------------------------------------------------- // TestAuthService_SwitchOrg // --------------------------------------------------------------------------- @@ -336,6 +419,111 @@ func TestAuthService_SwitchOrg(t *testing.T) { } } +func TestAuthService_SwitchOrgRefreshesExpiredSession(t *testing.T) { + t.Parallel() + + store := newMemSecretStore() + seedTokens(store, "user-old", "org-old") + + var switchCalls []string + mock := &dashboardadapter.MockAccountClient{ + SwitchOrgFn: func(_ context.Context, orgPublicID, userToken, orgToken string) (*domain.DashboardSwitchOrgResponse, error) { + switchCalls = append(switchCalls, orgPublicID+":"+userToken+"/"+orgToken) + if len(switchCalls) == 1 { + return nil, domain.NewDashboardAPIError(401, "INVALID_SESSION", "Invalid or expired session") + } + return &domain.DashboardSwitchOrgResponse{ + OrgToken: "org-fresh", + Org: domain.DashboardSwitchOrgOrg{PublicID: "org-target"}, + }, nil + }, + RefreshFn: func(_ context.Context, userToken, orgToken string) (*domain.DashboardRefreshResponse, error) { + assert.Equal(t, "user-old", userToken) + assert.Equal(t, "org-old", orgToken) + return &domain.DashboardRefreshResponse{ + UserToken: "user-fresh", + OrgToken: "org-fresh-refresh", + }, nil + }, + } + + svc := NewAuthService(mock, store) + resp, err := svc.SwitchOrg(context.Background(), "org-target") + + require.NoError(t, err) + assert.Equal(t, &domain.DashboardSwitchOrgResponse{ + OrgToken: "org-fresh", + Org: domain.DashboardSwitchOrgOrg{PublicID: "org-target"}, + }, resp) + assert.Equal(t, []string{ + "org-target:user-old/org-old", + "org-target:user-fresh/org-fresh-refresh", + }, switchCalls) + + storedUserToken, _ := store.Get(ports.KeyDashboardUserToken) + assert.Equal(t, "user-fresh", storedUserToken) + storedOrgToken, _ := store.Get(ports.KeyDashboardOrgToken) + assert.Equal(t, "org-fresh", storedOrgToken) +} + +func TestAuthService_RefreshStoresUpdatedTokens(t *testing.T) { + t.Parallel() + + store := newMemSecretStore() + seedTokens(store, "user-old", "org-old") + + mock := &dashboardadapter.MockAccountClient{ + RefreshFn: func(_ context.Context, userToken, orgToken string) (*domain.DashboardRefreshResponse, error) { + assert.Equal(t, "user-old", userToken) + assert.Equal(t, "org-old", orgToken) + return &domain.DashboardRefreshResponse{ + UserToken: "user-new", + OrgToken: "org-new", + }, nil + }, + } + + svc := NewAuthService(mock, store) + err := svc.Refresh(context.Background()) + + require.NoError(t, err) + + storedUserToken, _ := store.Get(ports.KeyDashboardUserToken) + assert.Equal(t, "user-new", storedUserToken) + storedOrgToken, _ := store.Get(ports.KeyDashboardOrgToken) + assert.Equal(t, "org-new", storedOrgToken) +} + +func TestAuthService_RefreshRollsBackOnStoreFailure(t *testing.T) { + t.Parallel() + + store := &failingSecretStore{ + memSecretStore: newMemSecretStore(), + } + seedTokens(store, "user-old", "org-old") + store.failSetKey = ports.KeyDashboardOrgToken + + mock := &dashboardadapter.MockAccountClient{ + RefreshFn: func(_ context.Context, _, _ string) (*domain.DashboardRefreshResponse, error) { + return &domain.DashboardRefreshResponse{ + UserToken: "user-new", + OrgToken: "org-new", + }, nil + }, + } + + svc := NewAuthService(mock, store) + err := svc.Refresh(context.Background()) + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to store refreshed credentials") + + storedUserToken, _ := store.Get(ports.KeyDashboardUserToken) + assert.Equal(t, "user-old", storedUserToken) + storedOrgToken, _ := store.Get(ports.KeyDashboardOrgToken) + assert.Equal(t, "org-old", storedOrgToken) +} + // --------------------------------------------------------------------------- // TestAuthService_SyncSessionOrg // --------------------------------------------------------------------------- diff --git a/internal/app/dashboard/session_store.go b/internal/app/dashboard/session_store.go index 08f1cc9..3448481 100644 --- a/internal/app/dashboard/session_store.go +++ b/internal/app/dashboard/session_store.go @@ -1,6 +1,7 @@ package dashboard import ( + "errors" "fmt" "github.com/nylas/cli/internal/domain" @@ -11,9 +12,22 @@ import ( // Returns ErrDashboardNotLoggedIn when no user token is present. func loadDashboardTokens(secrets ports.SecretStore) (userToken, orgToken string, err error) { userToken, err = secrets.Get(ports.KeyDashboardUserToken) - if err != nil || userToken == "" { + if err != nil { + if errors.Is(err, domain.ErrSecretNotFound) { + return "", "", fmt.Errorf("%w", domain.ErrDashboardNotLoggedIn) + } + return "", "", fmt.Errorf("failed to load dashboard user token: %w", err) + } + if userToken == "" { return "", "", fmt.Errorf("%w", domain.ErrDashboardNotLoggedIn) } - orgToken, _ = secrets.Get(ports.KeyDashboardOrgToken) + + orgToken, err = secrets.Get(ports.KeyDashboardOrgToken) + if err != nil { + if errors.Is(err, domain.ErrSecretNotFound) { + return userToken, "", nil + } + return "", "", fmt.Errorf("failed to load dashboard organization token: %w", err) + } return userToken, orgToken, nil } diff --git a/internal/cli/dashboard/apps.go b/internal/cli/dashboard/apps.go index 1da29ad..30b312c 100644 --- a/internal/cli/dashboard/apps.go +++ b/internal/cli/dashboard/apps.go @@ -1,6 +1,7 @@ package dashboard import ( + "errors" "fmt" "github.com/spf13/cobra" @@ -71,7 +72,11 @@ Use --region to filter to a specific region.`, apps, err := appSvc.ListApplications(ctx, orgPublicID, region) if err != nil { - return wrapDashboardError(err) + var partialErr *domain.DashboardPartialResultError + if !errors.As(err, &partialErr) || len(apps) == 0 { + return wrapDashboardError(err) + } + common.PrintWarning("%v", partialErr) } if len(apps) == 0 { @@ -92,8 +97,9 @@ Use --region to filter to a specific region.`, func newAppsCreateCmd() *cobra.Command { var ( - name string - region string + name string + region string + secretDelivery string ) cmd := &cobra.Command{ @@ -104,7 +110,10 @@ func newAppsCreateCmd() *cobra.Command { nylas dashboard apps create --name "My App" --region us # Create an EU application - nylas dashboard apps create --name "EU App" --region eu`, + nylas dashboard apps create --name "EU App" --region eu + + # Non-interactive secret delivery + nylas dashboard apps create --name "My App" --region us --secret-delivery file`, RunE: func(cmd *cobra.Command, args []string) error { if name == "" { return dashboardError("application name is required", "Use --name to specify the application name") @@ -115,6 +124,15 @@ func newAppsCreateCmd() *cobra.Command { if region != "us" && region != "eu" { return dashboardError("invalid region", "Use --region us or --region eu") } + if err := validateSecretDelivery(secretDelivery); err != nil { + return err + } + if !isInteractive() && secretDelivery == "" { + return dashboardError( + "client secret delivery requires an explicit choice in non-interactive runs", + "Pass --secret-delivery clipboard or --secret-delivery file", + ) + } appSvc, err := createAppService() if err != nil { @@ -143,7 +161,7 @@ func newAppsCreateCmd() *cobra.Command { if app.ClientSecret != "" { _, _ = common.Yellow.Println("\n Client Secret (available once — save it now):") - if err := handleSecretDelivery(app.ClientSecret, "Client Secret"); err != nil { + if err := handleSecretDelivery(app.ClientSecret, "Client Secret", secretDelivery); err != nil { return err } } @@ -157,6 +175,7 @@ func newAppsCreateCmd() *cobra.Command { cmd.Flags().StringVarP(&name, "name", "n", "", "Application name (required)") cmd.Flags().StringVarP(®ion, "region", "r", "", "Region (required: us or eu)") + cmd.Flags().StringVar(&secretDelivery, "secret-delivery", "", "Secret delivery method (clipboard or file)") return cmd } @@ -189,6 +208,12 @@ When called without arguments, lists your applications and lets you pick one int // If no app ID provided, show interactive selector if appID == "" { + if !isInteractive() { + return dashboardError( + "application selection requires an interactive terminal", + "Pass the application ID and --region in non-interactive runs", + ) + } selectedID, selectedRegion, err := selectApp(region) if err != nil { return wrapDashboardError(err) @@ -226,6 +251,13 @@ When called without arguments, lists your applications and lets you pick one int // selectApp fetches apps and presents an interactive selector. // Returns the selected app ID and region. func selectApp(regionFilter string) (appID, region string, err error) { + if !isInteractive() { + return "", "", dashboardError( + "application selection requires an interactive terminal", + "Pass the application ID and --region in non-interactive runs", + ) + } + appSvc, err := createAppService() if err != nil { return "", "", err @@ -280,7 +312,13 @@ func selectApp(regionFilter string) (appID, region string, err error) { // getActiveApp returns the active app ID and region from the keyring. // Flags take priority over the stored active app. func getActiveApp(appFlag, regionFlag string) (appID, region string, err error) { - if appFlag != "" && regionFlag != "" { + if appFlag != "" || regionFlag != "" { + if appFlag == "" || regionFlag == "" { + return "", "", dashboardError( + "both --app and --region must be provided together", + "Pass both flags, or use 'nylas dashboard apps use --region ' first", + ) + } return appFlag, regionFlag, nil } diff --git a/internal/cli/dashboard/dashboard_test.go b/internal/cli/dashboard/dashboard_test.go index c5b48e5..3beaf32 100644 --- a/internal/cli/dashboard/dashboard_test.go +++ b/internal/cli/dashboard/dashboard_test.go @@ -3,6 +3,7 @@ package dashboard import ( "context" "os" + "runtime" "testing" "github.com/stretchr/testify/assert" @@ -90,6 +91,16 @@ func TestResolveAuthMethod(t *testing.T) { action: "log in", wantErr: "only one auth method flag allowed", }, + { + name: "non-interactive login requires explicit auth method", + action: "log in", + wantErr: "auth method is required", + }, + { + name: "non-interactive register requires explicit auth method", + action: "register", + wantErr: "auth method is required", + }, } for _, tt := range tests { @@ -112,6 +123,25 @@ func TestResolveAuthMethod(t *testing.T) { } } +func TestAcceptPrivacyPolicy(t *testing.T) { + t.Parallel() + + t.Run("accepted flag skips prompt", func(t *testing.T) { + t.Parallel() + + require.NoError(t, acceptPrivacyPolicy(true)) + }) + + t.Run("non-interactive mode requires explicit acceptance", func(t *testing.T) { + t.Parallel() + + err := acceptPrivacyPolicy(false) + require.Error(t, err) + assert.Contains(t, err.Error(), "privacy policy must be accepted") + assert.Contains(t, err.Error(), "--accept-privacy-policy") + }) +} + func TestGetDashboardAccountBaseURL(t *testing.T) { t.Parallel() @@ -252,13 +282,182 @@ func TestPersistActiveOrgSwitchesServerSession(t *testing.T) { {PublicID: "org-1", Name: "Org One"}, {PublicID: "org-2", Name: "Org Two"}, }, - }, "") + }, "org-2") require.NoError(t, err) - assert.Equal(t, "org-1", switchedOrg, "non-interactive selection should fall back to the first org") + assert.Equal(t, "org-2", switchedOrg) storedOrgID, _ := store.Get(ports.KeyDashboardOrgPublicID) - assert.Equal(t, "org-1", storedOrgID) + assert.Equal(t, "org-2", storedOrgID) storedOrgToken, _ := store.Get(ports.KeyDashboardOrgToken) assert.Equal(t, "new-org-token", storedOrgToken) } + +func TestPersistActiveOrgRejectsNonInteractiveMultiOrgSelection(t *testing.T) { + t.Parallel() + + store := newMemSecretStore() + require.NoError(t, store.Set(ports.KeyDashboardUserToken, "user-token")) + require.NoError(t, store.Set(ports.KeyDashboardOrgToken, "org-token")) + + authSvc := dashboardapp.NewAuthService(&dashboardadapter.MockAccountClient{}, store) + + err := persistActiveOrg(authSvc, &domain.DashboardAuthResponse{ + Organizations: []domain.DashboardOrganization{ + {PublicID: "org-1", Name: "Org One"}, + {PublicID: "org-2", Name: "Org Two"}, + }, + }, "") + + require.Error(t, err) + assert.Contains(t, err.Error(), "multiple organizations available") +} + +func TestRollbackPostAuthFailureClearsStoredSession(t *testing.T) { + t.Parallel() + + store := newMemSecretStore() + require.NoError(t, store.Set(ports.KeyDashboardUserToken, "user-token")) + require.NoError(t, store.Set(ports.KeyDashboardOrgToken, "org-token")) + require.NoError(t, store.Set(ports.KeyDashboardUserPublicID, "user-1")) + require.NoError(t, store.Set(ports.KeyDashboardOrgPublicID, "org-1")) + require.NoError(t, store.Set(ports.KeyDashboardAppID, "app-1")) + require.NoError(t, store.Set(ports.KeyDashboardAppRegion, "us")) + + var logoutCalled bool + authSvc := dashboardapp.NewAuthService(&dashboardadapter.MockAccountClient{ + LogoutFn: func(_ context.Context, userToken, orgToken string) error { + logoutCalled = true + assert.Equal(t, "user-token", userToken) + assert.Equal(t, "org-token", orgToken) + return nil + }, + }, store) + + rollbackPostAuthFailure(authSvc) + + assert.True(t, logoutCalled) + for _, key := range []string{ + ports.KeyDashboardUserToken, + ports.KeyDashboardOrgToken, + ports.KeyDashboardUserPublicID, + ports.KeyDashboardOrgPublicID, + ports.KeyDashboardAppID, + ports.KeyDashboardAppRegion, + } { + _, ok := store.data[key] + assert.False(t, ok, "expected %s to be removed", key) + } +} + +func TestGetActiveAppRequiresPairedFlags(t *testing.T) { + t.Parallel() + + _, _, err := getActiveApp("app-1", "") + require.Error(t, err) + assert.Contains(t, err.Error(), "both --app and --region") + + _, _, err = getActiveApp("", "us") + require.Error(t, err) + assert.Contains(t, err.Error(), "both --app and --region") + + appID, region, err := getActiveApp("app-1", "us") + require.NoError(t, err) + assert.Equal(t, "app-1", appID) + assert.Equal(t, "us", region) +} + +func TestWriteSecretTempFileCreatesUniqueFiles(t *testing.T) { + t.Parallel() + + path1, err := writeSecretTempFile("secret-1", "nylas-api-key.txt") + require.NoError(t, err) + t.Cleanup(func() { _ = os.Remove(path1) }) + + path2, err := writeSecretTempFile("secret-2", "nylas-api-key.txt") + require.NoError(t, err) + t.Cleanup(func() { _ = os.Remove(path2) }) + + assert.NotEqual(t, path1, path2) + + data1, err := os.ReadFile(path1) + require.NoError(t, err) + assert.Equal(t, "secret-1\n", string(data1)) + + data2, err := os.ReadFile(path2) + require.NoError(t, err) + assert.Equal(t, "secret-2\n", string(data2)) + + if runtime.GOOS != "windows" { + info, err := os.Stat(path1) + require.NoError(t, err) + assert.Equal(t, os.FileMode(0o600), info.Mode().Perm()) + } +} + +func TestResolveSSOMFAOrg(t *testing.T) { + t.Parallel() + + t.Run("uses explicit org when provided", func(t *testing.T) { + t.Parallel() + + orgID, err := resolveSSOMFAOrg("org-2", []domain.DashboardOrganization{ + {PublicID: "org-1"}, + {PublicID: "org-2"}, + }) + require.NoError(t, err) + assert.Equal(t, "org-2", orgID) + }) + + t.Run("uses the only organization", func(t *testing.T) { + t.Parallel() + + orgID, err := resolveSSOMFAOrg("", []domain.DashboardOrganization{{PublicID: "org-1"}}) + require.NoError(t, err) + assert.Equal(t, "org-1", orgID) + }) + + t.Run("rejects multi-org MFA without explicit org in non-interactive mode", func(t *testing.T) { + t.Parallel() + + orgID, err := resolveSSOMFAOrg("", []domain.DashboardOrganization{ + {PublicID: "org-1"}, + {PublicID: "org-2"}, + }) + require.Error(t, err) + assert.Empty(t, orgID) + assert.Contains(t, err.Error(), "multiple organizations available for MFA") + }) +} + +func TestHandleAPIKeyDeliveryRejectsUnsafeNonInteractivePrompt(t *testing.T) { + t.Parallel() + + err := handleAPIKeyDelivery("secret", "app-1", "us", "") + require.Error(t, err) + assert.Contains(t, err.Error(), "API key delivery requires an explicit choice") +} + +func TestHandleSecretDeliveryRejectsUnsafeNonInteractivePrompt(t *testing.T) { + t.Parallel() + + err := handleSecretDelivery("secret", "Client Secret", "") + require.Error(t, err) + assert.Contains(t, err.Error(), "client secret delivery requires an explicit choice") +} + +func TestValidateDeliveryChoices(t *testing.T) { + t.Parallel() + + require.NoError(t, validateAPIKeyDelivery("")) + require.NoError(t, validateAPIKeyDelivery("activate")) + require.NoError(t, validateSecretDelivery("clipboard")) + + err := validateAPIKeyDelivery("print") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid API key delivery method") + + err = validateSecretDelivery("terminal") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid secret delivery method") +} diff --git a/internal/cli/dashboard/exports.go b/internal/cli/dashboard/exports.go index 6ca84b1..68427a1 100644 --- a/internal/cli/dashboard/exports.go +++ b/internal/cli/dashboard/exports.go @@ -23,7 +23,7 @@ func RunSSO(provider, mode string, privacyAccepted bool) error { // AcceptPrivacyPolicy prompts for privacy policy acceptance (exported for setup wizard). func AcceptPrivacyPolicy() error { - return acceptPrivacyPolicy() + return acceptPrivacyPolicy(false) } // ActivateAPIKey stores an API key in the keyring and configures the CLI (exported for setup wizard). diff --git a/internal/cli/dashboard/helpers.go b/internal/cli/dashboard/helpers.go index a319e0d..9f56926 100644 --- a/internal/cli/dashboard/helpers.go +++ b/internal/cli/dashboard/helpers.go @@ -13,6 +13,7 @@ import ( "github.com/nylas/cli/internal/cli/common" "github.com/nylas/cli/internal/domain" "github.com/nylas/cli/internal/ports" + "golang.org/x/term" ) // createDPoPService creates a DPoP service backed by the keyring. @@ -141,6 +142,17 @@ func resolveAuthMethod(google, microsoft, github, email bool, action string) (st // chooseAuthMethod presents an interactive menu. SSO first. // Email/password registration is temporarily disabled. func chooseAuthMethod(action string) (string, error) { + if !isInteractive() { + hint := "Use --google, --microsoft, --github, or --email" + if action == "register" { + hint = "Use --google, --microsoft, or --github" + } + return "", dashboardError( + fmt.Sprintf("an auth method is required to %s in non-interactive runs", action), + hint, + ) + } + opts := []common.SelectOption[string]{ {Label: "Google (recommended)", Value: methodGoogle}, {Label: "Microsoft", Value: methodMicrosoft}, @@ -153,13 +165,24 @@ func chooseAuthMethod(action string) (string, error) { return common.Select(fmt.Sprintf("How would you like to %s?", action), opts) } +func isInteractive() bool { + return term.IsTerminal(int(os.Stdin.Fd())) +} + // selectOrg prompts the user to select an organization if multiple are available. -func selectOrg(orgs []domain.DashboardOrganization) string { +func selectOrg(orgs []domain.DashboardOrganization) (string, error) { if len(orgs) <= 1 { if len(orgs) == 1 { - return orgs[0].PublicID + return orgs[0].PublicID, nil } - return "" + return "", nil + } + + if !isInteractive() { + return "", dashboardError( + "multiple organizations available", + "Pass --org to choose the organization in non-interactive runs", + ) } opts := make([]common.SelectOption[string], len(orgs)) @@ -173,15 +196,19 @@ func selectOrg(orgs []domain.DashboardOrganization) string { selected, err := common.Select("Select organization", opts) if err != nil { - return orgs[0].PublicID + return "", err } - return selected + return selected, nil } func persistActiveOrg(authSvc *dashboardapp.AuthService, auth *domain.DashboardAuthResponse, orgPublicID string) error { selectedOrgID := orgPublicID if selectedOrgID == "" && len(auth.Organizations) > 1 { - selectedOrgID = selectOrg(auth.Organizations) + var err error + selectedOrgID, err = selectOrg(auth.Organizations) + if err != nil { + return err + } } if selectedOrgID == "" { return nil @@ -198,6 +225,16 @@ func persistActiveOrg(authSvc *dashboardapp.AuthService, auth *domain.DashboardA return nil } +func rollbackPostAuthFailure(authSvc *dashboardapp.AuthService) { + if authSvc == nil { + return + } + + ctx, cancel := common.CreateContext() + defer cancel() + _ = authSvc.Logout(ctx) +} + // printAuthSuccess prints the standard post-login success message. // It reads the stored active org from the keyring (set by SyncSessionOrg) // so it reflects the server's actual current org. @@ -240,7 +277,17 @@ func syncSessionOrgWithWarning(authSvc *dashboardapp.AuthService) { } // acceptPrivacyPolicy prompts for or validates privacy policy acceptance. -func acceptPrivacyPolicy() error { +func acceptPrivacyPolicy(acceptedFlag bool) error { + if acceptedFlag { + return nil + } + if !isInteractive() { + return dashboardError( + "privacy policy must be accepted to continue", + "Pass --accept-privacy-policy to confirm acceptance in non-interactive runs", + ) + } + accepted, err := common.ConfirmPrompt("Accept Nylas Privacy Policy?", true) if err != nil { return err diff --git a/internal/cli/dashboard/keys.go b/internal/cli/dashboard/keys.go index b692f1a..3a65948 100644 --- a/internal/cli/dashboard/keys.go +++ b/internal/cli/dashboard/keys.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "path/filepath" + "strings" "time" "github.com/spf13/cobra" @@ -113,6 +114,7 @@ func newAPIKeysCreateCmd() *cobra.Command { region string name string expiresIn int + delivery string ) cmd := &cobra.Command{ @@ -123,7 +125,7 @@ func newAPIKeysCreateCmd() *cobra.Command { After creation, you choose what to do with the key: 1. Activate it — store in CLI keyring as the active API key (recommended) 2. Copy to clipboard — for use in other tools - 3. Print to terminal — for piping or scripts + 3. Save to file — for handoff or scripts Set an active app with: nylas dashboard apps use --region `, Example: ` # Using active app (simplest) @@ -135,6 +137,9 @@ Set an active app with: nylas dashboard apps use --region `, # Explicit app nylas dashboard apps apikeys create --app --region us + # Non-interactive delivery + nylas dashboard apps apikeys create --delivery activate + # Create with custom expiration (days) nylas dashboard apps apikeys create --expires 30`, RunE: func(cmd *cobra.Command, args []string) error { @@ -148,6 +153,15 @@ Set an active app with: nylas dashboard apps use --region `, if name == "" { name = "CLI-" + time.Now().Format("20060102-150405") } + if err := validateAPIKeyDelivery(delivery); err != nil { + return err + } + if !isInteractive() && delivery == "" { + return dashboardError( + "API key delivery requires an explicit choice in non-interactive runs", + "Pass --delivery activate, --delivery clipboard, or --delivery file", + ) + } appSvc, err := createAppService() if err != nil { @@ -170,7 +184,7 @@ Set an active app with: nylas dashboard apps use --region `, fmt.Printf(" ID: %s\n", key.ID) fmt.Printf(" Name: %s\n", key.Name) - return handleAPIKeyDelivery(key.APIKey, appID, region) + return handleAPIKeyDelivery(key.APIKey, appID, region, delivery) }, } @@ -178,31 +192,35 @@ Set an active app with: nylas dashboard apps use --region `, cmd.Flags().StringVarP(®ion, "region", "r", "", "Region (overrides active app)") cmd.Flags().StringVarP(&name, "name", "n", "", "API key name (default: CLI-)") cmd.Flags().IntVar(&expiresIn, "expires", 0, "Expiration in days (default: no expiration)") + cmd.Flags().StringVar(&delivery, "delivery", "", "API key delivery method (activate, clipboard, or file)") return cmd } // handleAPIKeyDelivery prompts the user to choose how to handle the newly created key. // The API key is never printed to stdout to prevent leaking it in terminal history or logs. -func handleAPIKeyDelivery(apiKey, appID, region string) error { - type deliveryChoice string - const ( - choiceActivate deliveryChoice = "activate" - choiceClipboard deliveryChoice = "clipboard" - choiceFile deliveryChoice = "file" - ) - - choice, err := common.Select("What would you like to do with this API key?", []common.SelectOption[deliveryChoice]{ - {Label: "Activate for this CLI (recommended)", Value: choiceActivate}, - {Label: "Copy to clipboard", Value: choiceClipboard}, - {Label: "Save to file", Value: choiceFile}, - }) - if err != nil { - return wrapDashboardError(err) +func handleAPIKeyDelivery(apiKey, appID, region, delivery string) error { + choice := delivery + if choice == "" { + if !isInteractive() { + return dashboardError( + "API key delivery requires an explicit choice in non-interactive runs", + "Pass --delivery activate, --delivery clipboard, or --delivery file", + ) + } + selected, err := common.Select("What would you like to do with this API key?", []common.SelectOption[string]{ + {Label: "Activate for this CLI (recommended)", Value: "activate"}, + {Label: "Copy to clipboard", Value: "clipboard"}, + {Label: "Save to file", Value: "file"}, + }) + if err != nil { + return wrapDashboardError(err) + } + choice = selected } switch choice { - case choiceActivate: + case "activate": if err := activateAPIKey(apiKey, appID, region); err != nil { _, _ = common.Yellow.Printf(" Could not activate: %v\n", err) return nil @@ -210,7 +228,7 @@ func handleAPIKeyDelivery(apiKey, appID, region string) error { _, _ = common.Green.Println("✓ API key activated — CLI is ready to use") _, _ = common.Dim.Println(" Try: nylas auth status") - case choiceClipboard: + case "clipboard": if err := common.CopyToClipboard(apiKey); err != nil { _, _ = common.Yellow.Printf(" Clipboard unavailable: %v\n", err) _, _ = common.Dim.Println(" Falling back to file save") @@ -218,8 +236,14 @@ func handleAPIKeyDelivery(apiKey, appID, region string) error { } _, _ = common.Green.Println("✓ API key copied to clipboard") - case choiceFile: + case "file": return saveSecretToFile(apiKey, "nylas-api-key.txt", "API key") + + default: + return dashboardError( + "invalid API key delivery method", + "Use --delivery activate, --delivery clipboard, or --delivery file", + ) } return nil @@ -227,23 +251,27 @@ func handleAPIKeyDelivery(apiKey, appID, region string) error { // handleSecretDelivery prompts the user to choose how to receive a secret. // Secrets are never printed to stdout to prevent leaking in terminal history or logs. -func handleSecretDelivery(secret, label string) error { - type deliveryChoice string - const ( - choiceClipboard deliveryChoice = "clipboard" - choiceFile deliveryChoice = "file" - ) - - choice, err := common.Select(fmt.Sprintf("How would you like to receive the %s?", label), []common.SelectOption[deliveryChoice]{ - {Label: "Copy to clipboard (recommended)", Value: choiceClipboard}, - {Label: "Save to file", Value: choiceFile}, - }) - if err != nil { - return wrapDashboardError(err) +func handleSecretDelivery(secret, label, delivery string) error { + choice := delivery + if choice == "" { + if !isInteractive() { + return dashboardError( + "client secret delivery requires an explicit choice in non-interactive runs", + "Pass --secret-delivery clipboard or --secret-delivery file", + ) + } + selected, err := common.Select(fmt.Sprintf("How would you like to receive the %s?", label), []common.SelectOption[string]{ + {Label: "Copy to clipboard (recommended)", Value: "clipboard"}, + {Label: "Save to file", Value: "file"}, + }) + if err != nil { + return wrapDashboardError(err) + } + choice = selected } switch choice { - case choiceClipboard: + case "clipboard": if err := common.CopyToClipboard(secret); err != nil { _, _ = common.Yellow.Printf(" Clipboard unavailable: %v\n", err) _, _ = common.Dim.Println(" Falling back to file save") @@ -251,17 +279,47 @@ func handleSecretDelivery(secret, label string) error { } _, _ = common.Green.Printf("✓ %s copied to clipboard\n", label) - case choiceFile: + case "file": return saveSecretToFile(secret, "nylas-client-secret.txt", label) + + default: + return dashboardError( + "invalid secret delivery method", + "Use --secret-delivery clipboard or --secret-delivery file", + ) } return nil } +func validateAPIKeyDelivery(delivery string) error { + switch delivery { + case "", "activate", "clipboard", "file": + return nil + default: + return dashboardError( + "invalid API key delivery method", + "Use --delivery activate, --delivery clipboard, or --delivery file", + ) + } +} + +func validateSecretDelivery(delivery string) error { + switch delivery { + case "", "clipboard", "file": + return nil + default: + return dashboardError( + "invalid secret delivery method", + "Use --secret-delivery clipboard or --secret-delivery file", + ) + } +} + // saveSecretToFile writes a secret to a temp file with restrictive permissions. func saveSecretToFile(secret, filename, label string) error { - keyFile := filepath.Join(os.TempDir(), filename) - if err := os.WriteFile(keyFile, []byte(secret+"\n"), 0o600); err != nil { // #nosec G306 + keyFile, err := writeSecretTempFile(secret, filename) + if err != nil { return wrapDashboardError(fmt.Errorf("failed to write file: %w", err)) } _, _ = common.Green.Printf("✓ %s saved to: %s\n", label, keyFile) @@ -269,6 +327,34 @@ func saveSecretToFile(secret, filename, label string) error { return nil } +func writeSecretTempFile(secret, filename string) (string, error) { + pattern := tempSecretPattern(filename) + file, err := os.CreateTemp("", pattern) + if err != nil { + return "", err + } + defer func() { _ = file.Close() }() + + if err := file.Chmod(0o600); err != nil { + return "", err + } + if _, err := file.WriteString(secret + "\n"); err != nil { + return "", err + } + + return file.Name(), nil +} + +func tempSecretPattern(filename string) string { + base := filepath.Base(filename) + ext := filepath.Ext(base) + name := strings.TrimSuffix(base, ext) + if name == "" { + name = "nylas-secret" + } + return name + "-*" + ext +} + // activateAPIKey stores the API key and configures the CLI to use it. func activateAPIKey(apiKey, clientID, region string) error { configStore := config.NewDefaultFileStore() diff --git a/internal/cli/dashboard/login.go b/internal/cli/dashboard/login.go index 35658ff..2813fb8 100644 --- a/internal/cli/dashboard/login.go +++ b/internal/cli/dashboard/login.go @@ -117,7 +117,10 @@ func runEmailLogin(userFlag, passFlag, orgPublicID string) error { mfaOrg := orgPublicID if mfaOrg == "" && len(mfa.Organizations) > 0 { if len(mfa.Organizations) > 1 { - mfaOrg = selectOrg(mfa.Organizations) + mfaOrg, err = selectOrg(mfa.Organizations) + if err != nil { + return wrapDashboardError(err) + } } else { mfaOrg = mfa.Organizations[0].PublicID } @@ -136,6 +139,7 @@ func runEmailLogin(userFlag, passFlag, orgPublicID string) error { } if err := persistActiveOrg(authSvc, auth, orgPublicID); err != nil { + rollbackPostAuthFailure(authSvc) return wrapDashboardError(err) } diff --git a/internal/cli/dashboard/refresh.go b/internal/cli/dashboard/refresh.go index d423060..2232dac 100644 --- a/internal/cli/dashboard/refresh.go +++ b/internal/cli/dashboard/refresh.go @@ -1,11 +1,13 @@ package dashboard import ( + "errors" "fmt" "github.com/spf13/cobra" "github.com/nylas/cli/internal/cli/common" + "github.com/nylas/cli/internal/domain" ) func newRefreshCmd() *cobra.Command { @@ -25,8 +27,10 @@ func newRefreshCmd() *cobra.Command { return authSvc.Refresh(ctx) }) if err != nil { - fmt.Println("Session expired. Please log in again:") - fmt.Println(" nylas dashboard login") + if errors.Is(err, domain.ErrDashboardSessionExpired) { + fmt.Println("Session expired. Please log in again:") + fmt.Println(" nylas dashboard login") + } return wrapDashboardError(err) } diff --git a/internal/cli/dashboard/register.go b/internal/cli/dashboard/register.go index 11280e1..d584a32 100644 --- a/internal/cli/dashboard/register.go +++ b/internal/cli/dashboard/register.go @@ -6,9 +6,10 @@ import ( func newRegisterCmd() *cobra.Command { var ( - google bool - microsoft bool - github bool + google bool + microsoft bool + github bool + acceptPrivacyPolicy bool ) cmd := &cobra.Command{ @@ -27,7 +28,10 @@ Email/password registration is temporarily disabled. Use SSO instead.`, nylas dashboard register --microsoft # GitHub SSO - nylas dashboard register --github`, + nylas dashboard register --github + + # Non-interactive registration + nylas dashboard register --google --accept-privacy-policy`, RunE: func(cmd *cobra.Command, args []string) error { method, err := resolveAuthMethod(google, microsoft, github, false, "register") if err != nil { @@ -36,7 +40,7 @@ Email/password registration is temporarily disabled. Use SSO instead.`, switch method { case methodGoogle, methodMicrosoft, methodGitHub: - return runSSORegister(method) + return runSSORegister(method, acceptPrivacyPolicy) default: return dashboardError("invalid selection", "Choose a valid SSO provider") } @@ -46,12 +50,13 @@ Email/password registration is temporarily disabled. Use SSO instead.`, cmd.Flags().BoolVar(&google, "google", false, "Register with Google SSO") cmd.Flags().BoolVar(µsoft, "microsoft", false, "Register with Microsoft SSO") cmd.Flags().BoolVar(&github, "github", false, "Register with GitHub SSO") + cmd.Flags().BoolVar(&acceptPrivacyPolicy, "accept-privacy-policy", false, "Confirm that you accept the Nylas Privacy Policy") return cmd } -func runSSORegister(provider string) error { - if err := acceptPrivacyPolicy(); err != nil { +func runSSORegister(provider string, acceptedPrivacyPolicy bool) error { + if err := acceptPrivacyPolicy(acceptedPrivacyPolicy); err != nil { return err } return runSSO(provider, "register", true) diff --git a/internal/cli/dashboard/sso.go b/internal/cli/dashboard/sso.go index ab8d47e..8808867 100644 --- a/internal/cli/dashboard/sso.go +++ b/internal/cli/dashboard/sso.go @@ -3,10 +3,12 @@ package dashboard import ( "context" "fmt" + "os" "strings" "time" "github.com/spf13/cobra" + "golang.org/x/term" "github.com/nylas/cli/internal/adapters/browser" dashboardapp "github.com/nylas/cli/internal/app/dashboard" @@ -47,14 +49,18 @@ func newSSOLoginCmd() *cobra.Command { } func newSSORegisterCmd() *cobra.Command { - var provider string + var ( + provider string + acceptPrivacyPolicyFlag bool + ) cmd := &cobra.Command{ - Use: "register", - Short: "Register via SSO", - Example: ` nylas dashboard sso register --provider google`, + Use: "register", + Short: "Register via SSO", + Example: ` nylas dashboard sso register --provider google + nylas dashboard sso register --provider google --accept-privacy-policy`, RunE: func(cmd *cobra.Command, args []string) error { - if err := acceptPrivacyPolicy(); err != nil { + if err := acceptPrivacyPolicy(acceptPrivacyPolicyFlag); err != nil { return err } return runSSO(provider, "register", true) @@ -62,6 +68,7 @@ func newSSORegisterCmd() *cobra.Command { } cmd.Flags().StringVarP(&provider, "provider", "p", "google", "SSO provider (google, microsoft, github)") + cmd.Flags().BoolVar(&acceptPrivacyPolicyFlag, "accept-privacy-policy", false, "Confirm that you accept the Nylas Privacy Policy") return cmd } @@ -126,6 +133,7 @@ func runSSO(provider, mode string, privacyPolicyAccepted bool, orgPublicIDs ...s } if err := persistActiveOrg(authSvc, auth, orgPublicID); err != nil { + rollbackPostAuthFailure(authSvc) return wrapDashboardError(err) } @@ -167,6 +175,11 @@ func pollSSO(ctx context.Context, authSvc *dashboardapp.AuthService, flowID, org if resp.MFA == nil { return nil, fmt.Errorf("unexpected empty MFA response") } + + mfaOrg, resolveErr := resolveSSOMFAOrg(orgPublicID, resp.MFA.Organizations) + if resolveErr != nil { + return nil, resolveErr + } code, readErr := common.PasswordPrompt("MFA code") if readErr != nil { return nil, readErr @@ -174,10 +187,6 @@ func pollSSO(ctx context.Context, authSvc *dashboardapp.AuthService, flowID, org ctx2, cancel := common.CreateContext() var auth *domain.DashboardAuthResponse - mfaOrg := orgPublicID - if mfaOrg == "" && len(resp.MFA.Organizations) > 0 { - mfaOrg = resp.MFA.Organizations[0].PublicID - } mfaErr := common.RunWithSpinner("Verifying MFA...", func() error { auth, err = authSvc.CompleteMFA(ctx2, resp.MFA.User.PublicID, code, mfaOrg) return err @@ -204,6 +213,36 @@ func pollSSO(ctx context.Context, authSvc *dashboardapp.AuthService, flowID, org } } +func resolveSSOMFAOrg(orgPublicID string, orgs []domain.DashboardOrganization) (string, error) { + if orgPublicID != "" || len(orgs) == 0 { + return orgPublicID, nil + } + if len(orgs) == 1 { + return orgs[0].PublicID, nil + } + if !term.IsTerminal(int(os.Stdin.Fd())) { + return "", dashboardError( + "multiple organizations available for MFA", + "Pass --org to choose the organization", + ) + } + + opts := make([]common.SelectOption[string], len(orgs)) + for i, org := range orgs { + label := org.Name + if label == "" { + label = org.PublicID + } + opts[i] = common.SelectOption[string]{Label: label, Value: org.PublicID} + } + + selected, err := common.Select("Select organization", opts) + if err != nil { + return "", err + } + return selected, nil +} + // mapProvider maps a user-friendly provider name to the server login type. func mapProvider(provider string) (string, error) { switch strings.ToLower(provider) { diff --git a/internal/cli/dashboard/status.go b/internal/cli/dashboard/status.go index 5b6c8d6..9e17d47 100644 --- a/internal/cli/dashboard/status.go +++ b/internal/cli/dashboard/status.go @@ -27,34 +27,37 @@ func newStatusCmd() *cobra.Command { return nil } - _, _ = common.Green.Println("✓ Logged in") - if status.UserID != "" { - fmt.Printf(" User: %s\n", status.UserID) + ctx, cancel := common.CreateContext() + defer cancel() + session, err := authSvc.GetCurrentSession(ctx) + if err != nil { + return wrapDashboardError(err) } - // Try to get org details from server for richer display orgLabel := status.OrgID - orgCount := 0 - ctx, cancel := common.CreateContext() - defer cancel() - if session, sErr := authSvc.GetCurrentSession(ctx); sErr == nil { - if session.CurrentOrg != "" { - orgLabel = session.CurrentOrg - // Find the org name from relations - for _, rel := range session.Relations { - if rel.OrgPublicID == session.CurrentOrg && rel.OrgName != "" { - orgLabel = formatOrgLabel(session.CurrentOrg, rel.OrgName) - break - } + if session.CurrentOrg != "" { + orgLabel = session.CurrentOrg + for _, rel := range session.Relations { + if rel.OrgPublicID == session.CurrentOrg && rel.OrgName != "" { + orgLabel = formatOrgLabel(session.CurrentOrg, rel.OrgName) + break } } - orgCount = len(session.Relations) + } + + _, _ = common.Green.Println("✓ Logged in") + userID := status.UserID + if userID == "" { + userID = session.User.PublicID + } + if userID != "" { + fmt.Printf(" User: %s\n", userID) } if orgLabel != "" { fmt.Printf(" Organization: %s\n", orgLabel) } - if orgCount > 1 { - fmt.Printf(" Total orgs: %d (switch with: nylas dashboard orgs switch)\n", orgCount) + if len(session.Relations) > 1 { + fmt.Printf(" Total orgs: %d (switch with: nylas dashboard orgs switch)\n", len(session.Relations)) } fmt.Printf(" Org token: %s\n", presentAbsent(status.HasOrgToken)) diff --git a/internal/cli/dashboard/switch_org.go b/internal/cli/dashboard/switch_org.go index 3bffd32..3ffa92b 100644 --- a/internal/cli/dashboard/switch_org.go +++ b/internal/cli/dashboard/switch_org.go @@ -30,6 +30,12 @@ or pass --org to switch directly.`, if err != nil { return wrapDashboardError(err) } + if orgFlag == "" && !isInteractive() { + return dashboardError( + "organization selection requires an interactive terminal", + "Pass --org to choose the organization in non-interactive runs", + ) + } ctx, cancel := common.CreateContext() defer cancel() diff --git a/internal/cli/integration/dashboard_test.go b/internal/cli/integration/dashboard_test.go new file mode 100644 index 0000000..0f2dc75 --- /dev/null +++ b/internal/cli/integration/dashboard_test.go @@ -0,0 +1,946 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/nylas/cli/internal/adapters/config" + "github.com/nylas/cli/internal/adapters/keyring" + "github.com/nylas/cli/internal/ports" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const dashboardTestPassphrase = "integration-test-file-store-passphrase" + +func TestCLI_DashboardLoginEmailPasswordPersistsSession(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + configHome, tempHome, secretStore := newDashboardTestSecretStore(t) + require.NoError(t, secretStore.Set(ports.KeyDashboardAppID, "stale-app")) + require.NoError(t, secretStore.Set(ports.KeyDashboardAppRegion, "eu")) + + accountServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/auth/cli/login": + var body map[string]any + require.NoError(t, json.NewDecoder(r.Body).Decode(&body)) + assert.Equal(t, "user@example.com", body["email"]) + assert.Equal(t, "secret", body["password"]) + assert.Equal(t, "org-2", body["orgPublicId"]) + assert.NotEmpty(t, r.Header.Get("DPoP")) + writeDashboardResponse(t, w, map[string]any{ + "userToken": "user-token", + "orgToken": "org-token-initial", + "user": map[string]any{ + "publicId": "user-1", + }, + "organizations": []map[string]any{ + {"publicId": "org-1", "name": "Org One"}, + {"publicId": "org-2", "name": "Org Two"}, + }, + }) + case "/sessions/switch-org": + var body map[string]any + require.NoError(t, json.NewDecoder(r.Body).Decode(&body)) + assert.Equal(t, "org-2", body["orgPublicId"]) + assert.Equal(t, "Bearer user-token", r.Header.Get("Authorization")) + writeDashboardResponse(t, w, map[string]any{ + "orgToken": "org-token-switched", + "org": map[string]any{ + "publicId": "org-2", + "name": "Org Two", + }, + }) + case "/sessions/current": + assert.Equal(t, "Bearer user-token", r.Header.Get("Authorization")) + assert.Equal(t, "org-token-switched", r.Header.Get("X-Nylas-Org")) + writeDashboardResponse(t, w, map[string]any{ + "user": map[string]any{ + "publicId": "user-1", + }, + "currentOrg": "org-2", + "relations": []map[string]any{ + {"orgPublicId": "org-1", "orgName": "Org One"}, + {"orgPublicId": "org-2", "orgName": "Org Two"}, + }, + }) + default: + http.NotFound(w, r) + } + })) + defer accountServer.Close() + + stdout, stderr, err := runCLIWithOverrides(30*time.Second, dashboardEnvOverrides(configHome, tempHome, map[string]string{ + "NYLAS_DASHBOARD_ACCOUNT_URL": accountServer.URL, + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + }), "dashboard", "login", "--email", "--user", "user@example.com", "--password", "secret", "--org", "org-2") + if err != nil { + t.Fatalf("dashboard login failed: %v\nstderr: %s", err, stderr) + } + + assert.Contains(t, stdout, "Authenticated as user-1") + assert.Contains(t, stdout, "Organization: Org Two (org-2)") + + userToken, err := secretStore.Get(ports.KeyDashboardUserToken) + require.NoError(t, err) + assert.Equal(t, "user-token", userToken) + orgToken, err := secretStore.Get(ports.KeyDashboardOrgToken) + require.NoError(t, err) + assert.Equal(t, "org-token-switched", orgToken) + orgID, err := secretStore.Get(ports.KeyDashboardOrgPublicID) + require.NoError(t, err) + assert.Equal(t, "org-2", orgID) + appID, err := secretStore.Get(ports.KeyDashboardAppID) + require.Error(t, err) + assert.Empty(t, appID) +} + +func TestCLI_DashboardLoginRequiresOrgForMultiOrgNonInteractive(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + configHome, tempHome, secretStore := newDashboardTestSecretStore(t) + + accountServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/auth/cli/login": + writeDashboardResponse(t, w, map[string]any{ + "userToken": "user-token", + "orgToken": "org-token", + "user": map[string]any{ + "publicId": "user-1", + }, + "organizations": []map[string]any{ + {"publicId": "org-1", "name": "Org One"}, + {"publicId": "org-2", "name": "Org Two"}, + }, + }) + default: + http.NotFound(w, r) + } + })) + defer accountServer.Close() + + stdout, stderr, err := runCLIWithOverrides(30*time.Second, dashboardEnvOverrides(configHome, tempHome, map[string]string{ + "NYLAS_DASHBOARD_ACCOUNT_URL": accountServer.URL, + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + }), "dashboard", "login", "--email", "--user", "user@example.com", "--password", "secret") + if err == nil { + t.Fatalf("expected dashboard login without --org to fail for multi-org account\nstdout: %s\nstderr: %s", stdout, stderr) + } + + assert.Contains(t, strings.ToLower(stderr), "multiple organizations available") + assert.Contains(t, stderr, "--org") + + userToken, tokenErr := secretStore.Get(ports.KeyDashboardUserToken) + require.Error(t, tokenErr) + assert.Empty(t, userToken) +} + +func TestCLI_DashboardLoginRequiresExplicitAuthMethodInNonInteractiveMode(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + configHome, tempHome, _ := newDashboardTestSecretStore(t) + + stdout, stderr, err := runCLIWithOverrides(30*time.Second, dashboardEnvOverrides(configHome, tempHome, map[string]string{ + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + }), "dashboard", "login") + if err == nil { + t.Fatalf("expected dashboard login without auth method to fail in non-interactive mode\nstdout: %s\nstderr: %s", stdout, stderr) + } + + assert.Contains(t, strings.ToLower(stderr), "auth method is required") + assert.Contains(t, stderr, "--google") +} + +func TestCLI_DashboardLoginRollsBackSessionWhenOrgSwitchFails(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + configHome, tempHome, secretStore := newDashboardTestSecretStore(t) + require.NoError(t, secretStore.Set(ports.KeyDashboardAppID, "stale-app")) + require.NoError(t, secretStore.Set(ports.KeyDashboardAppRegion, "eu")) + + accountServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/auth/cli/login": + writeDashboardResponse(t, w, map[string]any{ + "userToken": "user-token", + "orgToken": "org-token-initial", + "user": map[string]any{ + "publicId": "user-1", + }, + "organizations": []map[string]any{ + {"publicId": "org-1", "name": "Org One"}, + {"publicId": "org-2", "name": "Org Two"}, + }, + }) + case "/sessions/switch-org": + writeDashboardErrorResponse(t, w, http.StatusBadGateway, "UPSTREAM_UNAVAILABLE", "dashboard unavailable") + case "/auth/cli/logout": + writeDashboardResponse(t, w, map[string]any{}) + default: + http.NotFound(w, r) + } + })) + defer accountServer.Close() + + stdout, stderr, err := runCLIWithOverrides(30*time.Second, dashboardEnvOverrides(configHome, tempHome, map[string]string{ + "NYLAS_DASHBOARD_ACCOUNT_URL": accountServer.URL, + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + }), "dashboard", "login", "--email", "--user", "user@example.com", "--password", "secret", "--org", "org-2") + if err == nil { + t.Fatalf("expected dashboard login to fail when org switch fails\nstdout: %s\nstderr: %s", stdout, stderr) + } + + assert.Contains(t, strings.ToLower(stderr), "failed to switch organization") + for _, key := range []string{ + ports.KeyDashboardUserToken, + ports.KeyDashboardOrgToken, + ports.KeyDashboardUserPublicID, + ports.KeyDashboardOrgPublicID, + ports.KeyDashboardAppID, + ports.KeyDashboardAppRegion, + } { + value, getErr := secretStore.Get(key) + require.Error(t, getErr, "expected %s to be removed", key) + assert.Empty(t, value) + } +} + +func TestCLI_DashboardRefreshUpdatesStoredTokens(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + configHome, tempHome, secretStore := newDashboardTestSecretStore(t) + require.NoError(t, secretStore.Set(ports.KeyDashboardUserToken, "user-token-old")) + require.NoError(t, secretStore.Set(ports.KeyDashboardOrgToken, "org-token-old")) + + accountServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/auth/cli/refresh", r.URL.Path) + assert.Equal(t, "Bearer user-token-old", r.Header.Get("Authorization")) + assert.Equal(t, "org-token-old", r.Header.Get("X-Nylas-Org")) + writeDashboardResponse(t, w, map[string]any{ + "userToken": "user-token-new", + "orgToken": "org-token-new", + }) + })) + defer accountServer.Close() + + stdout, stderr, err := runCLIWithOverrides(30*time.Second, dashboardEnvOverrides(configHome, tempHome, map[string]string{ + "NYLAS_DASHBOARD_ACCOUNT_URL": accountServer.URL, + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + }), "dashboard", "refresh") + if err != nil { + t.Fatalf("dashboard refresh failed: %v\nstderr: %s", err, stderr) + } + + assert.Contains(t, stdout, "Session refreshed") + userToken, err := secretStore.Get(ports.KeyDashboardUserToken) + require.NoError(t, err) + assert.Equal(t, "user-token-new", userToken) + orgToken, err := secretStore.Get(ports.KeyDashboardOrgToken) + require.NoError(t, err) + assert.Equal(t, "org-token-new", orgToken) +} + +func TestCLI_DashboardRegisterRequiresExplicitAuthMethodInNonInteractiveMode(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + configHome, tempHome, _ := newDashboardTestSecretStore(t) + + stdout, stderr, err := runCLIWithOverrides(30*time.Second, dashboardEnvOverrides(configHome, tempHome, map[string]string{ + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + }), "dashboard", "register") + if err == nil { + t.Fatalf("expected dashboard register without auth method to fail in non-interactive mode\nstdout: %s\nstderr: %s", stdout, stderr) + } + + assert.Contains(t, strings.ToLower(stderr), "auth method is required") + assert.Contains(t, stderr, "--google") +} + +func TestCLI_DashboardRegisterRequiresExplicitPrivacyAcceptanceInNonInteractiveMode(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + configHome, tempHome, _ := newDashboardTestSecretStore(t) + + stdout, stderr, err := runCLIWithOverrides(30*time.Second, dashboardEnvOverrides(configHome, tempHome, map[string]string{ + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + }), "dashboard", "register", "--google") + if err == nil { + t.Fatalf("expected dashboard register without privacy acceptance to fail in non-interactive mode\nstdout: %s\nstderr: %s", stdout, stderr) + } + + assert.Contains(t, strings.ToLower(stderr), "privacy policy must be accepted") + assert.Contains(t, stderr, "--accept-privacy-policy") +} + +func TestCLI_DashboardLogoutClearsSession(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + configHome, tempHome, secretStore := newDashboardTestSecretStore(t) + require.NoError(t, secretStore.Set(ports.KeyDashboardUserToken, "user-token")) + require.NoError(t, secretStore.Set(ports.KeyDashboardOrgToken, "org-token")) + require.NoError(t, secretStore.Set(ports.KeyDashboardAppID, "app-1")) + require.NoError(t, secretStore.Set(ports.KeyDashboardAppRegion, "us")) + + accountServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/auth/cli/logout", r.URL.Path) + assert.Equal(t, "Bearer user-token", r.Header.Get("Authorization")) + writeDashboardResponse(t, w, map[string]any{}) + })) + defer accountServer.Close() + + stdout, stderr, err := runCLIWithOverrides(30*time.Second, dashboardEnvOverrides(configHome, tempHome, map[string]string{ + "NYLAS_DASHBOARD_ACCOUNT_URL": accountServer.URL, + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + }), "dashboard", "logout") + if err != nil { + t.Fatalf("dashboard logout failed: %v\nstderr: %s", err, stderr) + } + + assert.Contains(t, stdout, "Logged out") + userToken, err := secretStore.Get(ports.KeyDashboardUserToken) + require.Error(t, err) + assert.Empty(t, userToken) + appID, err := secretStore.Get(ports.KeyDashboardAppID) + require.Error(t, err) + assert.Empty(t, appID) +} + +func TestCLI_DashboardOrgsListUsesCurrentSession(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + configHome, tempHome, secretStore := newDashboardTestSecretStore(t) + require.NoError(t, secretStore.Set(ports.KeyDashboardUserToken, "user-token")) + require.NoError(t, secretStore.Set(ports.KeyDashboardOrgToken, "org-token")) + + accountServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/sessions/current", r.URL.Path) + assert.Equal(t, "Bearer user-token", r.Header.Get("Authorization")) + writeDashboardResponse(t, w, map[string]any{ + "user": map[string]any{ + "publicId": "user-1", + }, + "currentOrg": "org-1", + "relations": []map[string]any{ + {"orgPublicId": "org-1", "orgName": "Acme", "role": "admin"}, + {"orgPublicId": "org-2", "orgName": "Beta", "role": "member"}, + }, + }) + })) + defer accountServer.Close() + + stdout, stderr, err := runCLIWithOverrides(30*time.Second, dashboardEnvOverrides(configHome, tempHome, map[string]string{ + "NYLAS_DASHBOARD_ACCOUNT_URL": accountServer.URL, + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + }), "dashboard", "orgs", "list") + if err != nil { + t.Fatalf("dashboard orgs list failed: %v\nstderr: %s", err, stderr) + } + + assert.Contains(t, stdout, "Acme") + assert.Contains(t, stdout, "Beta") +} + +func TestCLI_DashboardOrgsListRefreshesExpiredSession(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + configHome, tempHome, secretStore := newDashboardTestSecretStore(t) + require.NoError(t, secretStore.Set(ports.KeyDashboardUserToken, "user-token-old")) + require.NoError(t, secretStore.Set(ports.KeyDashboardOrgToken, "org-token-old")) + + currentCalls := 0 + accountServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/sessions/current": + currentCalls++ + if currentCalls == 1 { + assert.Equal(t, "Bearer user-token-old", r.Header.Get("Authorization")) + assert.Equal(t, "org-token-old", r.Header.Get("X-Nylas-Org")) + writeDashboardErrorResponse(t, w, http.StatusUnauthorized, "INVALID_SESSION", "") + return + } + assert.Equal(t, "Bearer user-token-new", r.Header.Get("Authorization")) + assert.Equal(t, "org-token-new", r.Header.Get("X-Nylas-Org")) + writeDashboardResponse(t, w, map[string]any{ + "user": map[string]any{ + "publicId": "user-1", + }, + "currentOrg": "org-1", + "relations": []map[string]any{ + {"orgPublicId": "org-1", "orgName": "Acme"}, + }, + }) + case "/auth/cli/refresh": + assert.Equal(t, "Bearer user-token-old", r.Header.Get("Authorization")) + assert.Equal(t, "org-token-old", r.Header.Get("X-Nylas-Org")) + writeDashboardResponse(t, w, map[string]any{ + "userToken": "user-token-new", + "orgToken": "org-token-new", + }) + default: + http.NotFound(w, r) + } + })) + defer accountServer.Close() + + stdout, stderr, err := runCLIWithOverrides(30*time.Second, dashboardEnvOverrides(configHome, tempHome, map[string]string{ + "NYLAS_DASHBOARD_ACCOUNT_URL": accountServer.URL, + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + }), "dashboard", "orgs", "list") + if err != nil { + t.Fatalf("dashboard orgs list with refresh failed: %v\nstderr: %s", err, stderr) + } + + assert.Contains(t, stdout, "Acme") + userToken, err := secretStore.Get(ports.KeyDashboardUserToken) + require.NoError(t, err) + assert.Equal(t, "user-token-new", userToken) + orgToken, err := secretStore.Get(ports.KeyDashboardOrgToken) + require.NoError(t, err) + assert.Equal(t, "org-token-new", orgToken) +} + +func TestCLI_DashboardSwitchOrgUpdatesStoredSession(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + configHome, tempHome, secretStore := newDashboardTestSecretStore(t) + require.NoError(t, secretStore.Set(ports.KeyDashboardUserToken, "user-token")) + require.NoError(t, secretStore.Set(ports.KeyDashboardOrgToken, "org-token-old")) + require.NoError(t, secretStore.Set(ports.KeyDashboardOrgPublicID, "org-1")) + + accountServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/sessions/current": + assert.Equal(t, "Bearer user-token", r.Header.Get("Authorization")) + assert.Equal(t, "org-token-old", r.Header.Get("X-Nylas-Org")) + writeDashboardResponse(t, w, map[string]any{ + "user": map[string]any{ + "publicId": "user-1", + }, + "currentOrg": "org-1", + "relations": []map[string]any{ + {"orgPublicId": "org-1", "orgName": "Acme"}, + {"orgPublicId": "org-2", "orgName": "Beta"}, + }, + }) + case "/sessions/switch-org": + var body map[string]any + require.NoError(t, json.NewDecoder(r.Body).Decode(&body)) + assert.Equal(t, "org-2", body["orgPublicId"]) + assert.Equal(t, "Bearer user-token", r.Header.Get("Authorization")) + assert.Equal(t, "org-token-old", r.Header.Get("X-Nylas-Org")) + writeDashboardResponse(t, w, map[string]any{ + "orgToken": "org-token-new", + "org": map[string]any{ + "publicId": "org-2", + "name": "Beta", + }, + }) + default: + http.NotFound(w, r) + } + })) + defer accountServer.Close() + + stdout, stderr, err := runCLIWithOverrides(30*time.Second, dashboardEnvOverrides(configHome, tempHome, map[string]string{ + "NYLAS_DASHBOARD_ACCOUNT_URL": accountServer.URL, + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + }), "dashboard", "orgs", "switch", "--org", "org-2") + if err != nil { + t.Fatalf("dashboard orgs switch failed: %v\nstderr: %s", err, stderr) + } + + assert.Contains(t, stdout, "Switched to organization: Beta (org-2)") + + orgID, err := secretStore.Get(ports.KeyDashboardOrgPublicID) + require.NoError(t, err) + assert.Equal(t, "org-2", orgID) + orgToken, err := secretStore.Get(ports.KeyDashboardOrgToken) + require.NoError(t, err) + assert.Equal(t, "org-token-new", orgToken) +} + +func TestCLI_DashboardSwitchOrgRequiresOrgInNonInteractiveMode(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + configHome, tempHome, _ := newDashboardTestSecretStore(t) + + stdout, stderr, err := runCLIWithOverrides(30*time.Second, dashboardEnvOverrides(configHome, tempHome, map[string]string{ + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + }), "dashboard", "orgs", "switch") + if err == nil { + t.Fatalf("expected dashboard orgs switch without --org to fail in non-interactive mode\nstdout: %s\nstderr: %s", stdout, stderr) + } + + assert.Contains(t, strings.ToLower(stderr), "interactive terminal") + assert.Contains(t, stderr, "--org") +} + +func TestCLI_DashboardStatusShowsCurrentSession(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + configHome, tempHome, secretStore := newDashboardTestSecretStore(t) + require.NoError(t, secretStore.Set(ports.KeyDashboardUserToken, "user-token")) + require.NoError(t, secretStore.Set(ports.KeyDashboardOrgToken, "org-token")) + require.NoError(t, secretStore.Set(ports.KeyDashboardOrgPublicID, "org-1")) + require.NoError(t, secretStore.Set(ports.KeyDashboardUserPublicID, "user-1")) + require.NoError(t, secretStore.Set(ports.KeyDashboardAppID, "app-1")) + require.NoError(t, secretStore.Set(ports.KeyDashboardAppRegion, "us")) + + accountServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/sessions/current", r.URL.Path) + assert.Equal(t, "Bearer user-token", r.Header.Get("Authorization")) + assert.Equal(t, "org-token", r.Header.Get("X-Nylas-Org")) + writeDashboardResponse(t, w, map[string]any{ + "user": map[string]any{ + "publicId": "user-1", + }, + "currentOrg": "org-1", + "relations": []map[string]any{ + {"orgPublicId": "org-1", "orgName": "Acme"}, + {"orgPublicId": "org-2", "orgName": "Beta"}, + }, + }) + })) + defer accountServer.Close() + + stdout, stderr, err := runCLIWithOverrides(30*time.Second, dashboardEnvOverrides(configHome, tempHome, map[string]string{ + "NYLAS_DASHBOARD_ACCOUNT_URL": accountServer.URL, + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + }), "dashboard", "status") + if err != nil { + t.Fatalf("dashboard status failed: %v\nstderr: %s", err, stderr) + } + + assert.Contains(t, stdout, "Logged in") + assert.Contains(t, stdout, "User: user-1") + assert.Contains(t, stdout, "Organization: Acme (org-1)") + assert.Contains(t, stdout, "Total orgs: 2") + assert.Contains(t, stdout, "Org token: present") + assert.Contains(t, stdout, "Active app: app-1 (us)") + assert.Contains(t, stdout, "DPoP key:") +} + +func TestCLI_DashboardStatusFailsWhenSessionCannotBeValidated(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + configHome, tempHome, secretStore := newDashboardTestSecretStore(t) + require.NoError(t, secretStore.Set(ports.KeyDashboardUserToken, "user-token")) + require.NoError(t, secretStore.Set(ports.KeyDashboardOrgToken, "org-token")) + require.NoError(t, secretStore.Set(ports.KeyDashboardUserPublicID, "user-1")) + + accountServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/sessions/current": + writeDashboardErrorResponse(t, w, http.StatusUnauthorized, "INVALID_SESSION", "") + case "/auth/cli/refresh": + writeDashboardErrorResponse(t, w, http.StatusUnauthorized, "INVALID_SESSION", "") + default: + http.NotFound(w, r) + } + })) + defer accountServer.Close() + + stdout, stderr, err := runCLIWithOverrides(30*time.Second, dashboardEnvOverrides(configHome, tempHome, map[string]string{ + "NYLAS_DASHBOARD_ACCOUNT_URL": accountServer.URL, + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + }), "dashboard", "status") + if err == nil { + t.Fatalf("expected dashboard status to fail when session validation fails\nstdout: %s\nstderr: %s", stdout, stderr) + } + + assert.NotContains(t, stdout, "Logged in") + assert.Contains(t, strings.ToLower(stderr), "invalid_session") +} + +func TestCLI_DashboardAppsAndAPIKeysList(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + configHome, tempHome, secretStore := newDashboardTestSecretStore(t) + require.NoError(t, secretStore.Set(ports.KeyDashboardUserToken, "user-token")) + require.NoError(t, secretStore.Set(ports.KeyDashboardOrgToken, "org-token")) + require.NoError(t, secretStore.Set(ports.KeyDashboardOrgPublicID, "org-1")) + require.NoError(t, secretStore.Set(ports.KeyDashboardAppID, "app-1")) + require.NoError(t, secretStore.Set(ports.KeyDashboardAppRegion, "us")) + + gatewayServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, err := io.ReadAll(r.Body) + require.NoError(t, err) + + var body map[string]any + require.NoError(t, json.Unmarshal(raw, &body)) + + query := body["query"].(string) + variables := body["variables"].(map[string]any) + + assert.Equal(t, "Bearer user-token", r.Header.Get("Authorization")) + assert.Equal(t, "org-token", r.Header.Get("X-Nylas-Org")) + assert.NotEmpty(t, r.Header.Get("DPoP")) + + switch { + case strings.Contains(query, "applications("): + filter := variables["filter"].(map[string]any) + assert.Equal(t, "org-1", filter["orgPublicId"]) + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + "data": map[string]any{ + "applications": map[string]any{ + "applications": []map[string]any{ + { + "applicationId": "app-1", + "organizationId": "org-1", + "region": "us", + "environment": "sandbox", + "branding": map[string]any{ + "name": "Primary", + }, + }, + }, + }, + }, + })) + case strings.Contains(query, "apiKeys("): + assert.Equal(t, "app-1", variables["appId"]) + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + "data": map[string]any{ + "apiKeys": []map[string]any{ + { + "id": "key-1", + "name": "CI", + "status": "active", + "permissions": []string{"send"}, + "expiresAt": 0.0, + "createdAt": 1710000000.0, + }, + }, + }, + })) + default: + t.Fatalf("unexpected GraphQL query: %s", query) + } + })) + defer gatewayServer.Close() + + env := dashboardEnvOverrides(configHome, tempHome, map[string]string{ + "NYLAS_DASHBOARD_GATEWAY_URL": gatewayServer.URL, + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + }) + + stdout, stderr, err := runCLIWithOverrides(30*time.Second, env, "dashboard", "apps", "list", "--region", "us") + if err != nil { + t.Fatalf("dashboard apps list failed: %v\nstderr: %s", err, stderr) + } + assert.Contains(t, stdout, "app-1") + assert.Contains(t, stdout, "Primary") + + stdout, stderr, err = runCLIWithOverrides(30*time.Second, env, "dashboard", "apps", "apikeys", "list") + if err != nil { + t.Fatalf("dashboard apps apikeys list failed: %v\nstderr: %s", err, stderr) + } + assert.Contains(t, stdout, "key-1") + assert.Contains(t, stdout, "CI") +} + +func TestCLI_DashboardAppsListSurfacesGraphQLInvalidSession(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + configHome, tempHome, secretStore := newDashboardTestSecretStore(t) + require.NoError(t, secretStore.Set(ports.KeyDashboardUserToken, "user-token")) + require.NoError(t, secretStore.Set(ports.KeyDashboardOrgToken, "org-token")) + require.NoError(t, secretStore.Set(ports.KeyDashboardOrgPublicID, "org-1")) + + gatewayServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "Bearer user-token", r.Header.Get("Authorization")) + assert.Equal(t, "org-token", r.Header.Get("X-Nylas-Org")) + assert.NotEmpty(t, r.Header.Get("DPoP")) + + w.WriteHeader(http.StatusUnauthorized) + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + "errors": []map[string]any{ + { + "message": "INVALID_SESSION", + "extensions": map[string]any{ + "code": "UNAUTHENTICATED", + }, + }, + }, + })) + })) + defer gatewayServer.Close() + + stdout, stderr, err := runCLIWithOverrides(30*time.Second, dashboardEnvOverrides(configHome, tempHome, map[string]string{ + "NYLAS_DASHBOARD_GATEWAY_URL": gatewayServer.URL, + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + }), "dashboard", "apps", "list", "--region", "us") + if err == nil { + t.Fatalf("expected dashboard apps list to fail on GraphQL INVALID_SESSION\nstdout: %s\nstderr: %s", stdout, stderr) + } + + assert.Contains(t, strings.ToLower(stderr), "invalid_session") + assert.Contains(t, strings.ToLower(stderr), "invalid or expired session") +} + +func TestCLI_DashboardAppsListWarnsOnPartialRegionFailure(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + configHome, tempHome, secretStore := newDashboardTestSecretStore(t) + require.NoError(t, secretStore.Set(ports.KeyDashboardUserToken, "user-token")) + require.NoError(t, secretStore.Set(ports.KeyDashboardOrgToken, "org-token")) + require.NoError(t, secretStore.Set(ports.KeyDashboardOrgPublicID, "org-1")) + + usServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + "data": map[string]any{ + "applications": map[string]any{ + "applications": []map[string]any{ + { + "applicationId": "app-1", + "organizationId": "org-1", + "region": "us", + "environment": "sandbox", + "branding": map[string]any{ + "name": "Primary", + }, + }, + }, + }, + }, + })) + })) + defer usServer.Close() + + euServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "EU unavailable", http.StatusBadGateway) + })) + defer euServer.Close() + + stdout, stderr, err := runCLIWithOverrides(30*time.Second, dashboardEnvOverrides(configHome, tempHome, map[string]string{ + "NYLAS_DASHBOARD_GATEWAY_US_URL": usServer.URL, + "NYLAS_DASHBOARD_GATEWAY_EU_URL": euServer.URL, + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + }), "dashboard", "apps", "list") + if err != nil { + t.Fatalf("dashboard apps list with partial failure failed: %v\nstderr: %s", err, stderr) + } + + assert.Contains(t, stdout, "app-1") + assert.Contains(t, stdout, "partial results") +} + +func TestCLI_DashboardAPIKeysRequirePairedOverrides(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + configHome, tempHome, secretStore := newDashboardTestSecretStore(t) + require.NoError(t, secretStore.Set(ports.KeyDashboardAppID, "app-1")) + require.NoError(t, secretStore.Set(ports.KeyDashboardAppRegion, "us")) + + stdout, stderr, err := runCLIWithOverrides(30*time.Second, dashboardEnvOverrides(configHome, tempHome, map[string]string{ + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + }), "dashboard", "apps", "apikeys", "list", "--app", "app-2") + if err == nil { + t.Fatalf("expected --app without --region to fail\nstdout: %s\nstderr: %s", stdout, stderr) + } + assert.Contains(t, strings.ToLower(stderr), "both --app and --region") + + stdout, stderr, err = runCLIWithOverrides(30*time.Second, dashboardEnvOverrides(configHome, tempHome, map[string]string{ + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + }), "dashboard", "apps", "apikeys", "list", "--region", "eu") + if err == nil { + t.Fatalf("expected --region without --app to fail\nstdout: %s\nstderr: %s", stdout, stderr) + } + assert.Contains(t, strings.ToLower(stderr), "both --app and --region") +} + +func TestCLI_DashboardAppsUseRequiresExplicitAppInNonInteractiveMode(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + configHome, tempHome, _ := newDashboardTestSecretStore(t) + + stdout, stderr, err := runCLIWithOverrides(30*time.Second, dashboardEnvOverrides(configHome, tempHome, map[string]string{ + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + }), "dashboard", "apps", "use") + if err == nil { + t.Fatalf("expected dashboard apps use without app id to fail in non-interactive mode\nstdout: %s\nstderr: %s", stdout, stderr) + } + + assert.Contains(t, strings.ToLower(stderr), "interactive terminal") + assert.Contains(t, stderr, "--region") +} + +func TestCLI_DashboardAppCreateRequiresSecretDeliveryInNonInteractiveMode(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + configHome, tempHome, secretStore := newDashboardTestSecretStore(t) + require.NoError(t, secretStore.Set(ports.KeyDashboardOrgPublicID, "org-1")) + + stdout, stderr, err := runCLIWithOverrides(30*time.Second, dashboardEnvOverrides(configHome, tempHome, map[string]string{ + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + }), "dashboard", "apps", "create", "--name", "My App", "--region", "us") + if err == nil { + t.Fatalf("expected dashboard apps create without secret delivery to fail in non-interactive mode\nstdout: %s\nstderr: %s", stdout, stderr) + } + + assert.Contains(t, strings.ToLower(stderr), "client secret delivery requires an explicit choice") + assert.Contains(t, stderr, "--secret-delivery") +} + +func TestCLI_DashboardAPIKeyCreateRequiresDeliveryInNonInteractiveMode(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + configHome, tempHome, secretStore := newDashboardTestSecretStore(t) + require.NoError(t, secretStore.Set(ports.KeyDashboardAppID, "app-1")) + require.NoError(t, secretStore.Set(ports.KeyDashboardAppRegion, "us")) + + stdout, stderr, err := runCLIWithOverrides(30*time.Second, dashboardEnvOverrides(configHome, tempHome, map[string]string{ + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + }), "dashboard", "apps", "apikeys", "create") + if err == nil { + t.Fatalf("expected dashboard apps apikeys create without delivery to fail in non-interactive mode\nstdout: %s\nstderr: %s", stdout, stderr) + } + + assert.Contains(t, strings.ToLower(stderr), "api key delivery requires an explicit choice") + assert.Contains(t, stderr, "--delivery") +} + +func newDashboardTestSecretStore(t *testing.T) (configHome string, tempHome string, store ports.SecretStore) { + t.Helper() + + origPassphrase := os.Getenv("NYLAS_FILE_STORE_PASSPHRASE") + origDisableKeyring := os.Getenv("NYLAS_DISABLE_KEYRING") + origXDGConfigHome := os.Getenv("XDG_CONFIG_HOME") + origHome := os.Getenv("HOME") + t.Cleanup(func() { + setEnvOrUnset("NYLAS_FILE_STORE_PASSPHRASE", origPassphrase) + setEnvOrUnset("NYLAS_DISABLE_KEYRING", origDisableKeyring) + setEnvOrUnset("XDG_CONFIG_HOME", origXDGConfigHome) + setEnvOrUnset("HOME", origHome) + }) + + tempHome = t.TempDir() + configHome = filepath.Join(tempHome, "xdg") + require.NoError(t, os.Setenv("XDG_CONFIG_HOME", configHome)) + require.NoError(t, os.Setenv("HOME", tempHome)) + require.NoError(t, os.Setenv("NYLAS_DISABLE_KEYRING", "true")) + require.NoError(t, os.Setenv("NYLAS_FILE_STORE_PASSPHRASE", dashboardTestPassphrase)) + + secretStore, err := keyring.NewEncryptedFileStore(config.DefaultConfigDir()) + require.NoError(t, err) + return configHome, tempHome, secretStore +} + +func dashboardEnvOverrides(configHome, tempHome string, extra map[string]string) map[string]string { + overrides := map[string]string{ + "XDG_CONFIG_HOME": configHome, + "HOME": tempHome, + "NYLAS_DISABLE_KEYRING": "true", + "NYLAS_FILE_STORE_PASSPHRASE": dashboardTestPassphrase, + } + for k, v := range extra { + overrides[k] = v + } + return overrides +} + +func writeDashboardResponse(t *testing.T, w http.ResponseWriter, data any) { + t.Helper() + + w.Header().Set("Content-Type", "application/json") + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + "request_id": "req-1", + "success": true, + "data": data, + })) +} + +func writeDashboardErrorResponse(t *testing.T, w http.ResponseWriter, status int, code, message string) { + t.Helper() + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + + errBody := map[string]any{ + "request_id": "req-1", + "success": false, + "error": map[string]any{ + "code": code, + }, + } + if message != "" { + errBody["error"].(map[string]any)["message"] = message + } + + require.NoError(t, json.NewEncoder(w).Encode(errBody)) +} diff --git a/internal/domain/dashboard_errors.go b/internal/domain/dashboard_errors.go new file mode 100644 index 0000000..934a795 --- /dev/null +++ b/internal/domain/dashboard_errors.go @@ -0,0 +1,102 @@ +package domain + +import ( + "errors" + "fmt" + "sort" + "strings" +) + +// DashboardAPIError preserves structured dashboard API failures while allowing +// callers to detect specific auth states with errors.Is. +type DashboardAPIError struct { + StatusCode int + Code string + ServerMsg string + sentinel error +} + +// NewDashboardAPIError creates a dashboard API error and classifies it when a +// domain-level sentinel applies. +func NewDashboardAPIError(statusCode int, code, message string) *DashboardAPIError { + serverMsg := message + if code != "" { + if message != "" { + serverMsg = code + ": " + message + } else { + serverMsg = code + } + } + + return &DashboardAPIError{ + StatusCode: statusCode, + Code: code, + ServerMsg: serverMsg, + sentinel: classifyDashboardAPIError(statusCode, code), + } +} + +func (e *DashboardAPIError) Error() string { + if e == nil { + return "dashboard API error" + } + if e.ServerMsg != "" { + return fmt.Sprintf("dashboard API error (HTTP %d): %s", e.StatusCode, e.ServerMsg) + } + return fmt.Sprintf("dashboard API error (HTTP %d)", e.StatusCode) +} + +func (e *DashboardAPIError) Unwrap() error { + if e == nil { + return nil + } + return e.sentinel +} + +func classifyDashboardAPIError(statusCode int, code string) error { + if statusCode == 401 && code == "INVALID_SESSION" { + return ErrDashboardSessionExpired + } + return nil +} + +// DashboardPartialResultError indicates that a dashboard operation returned a +// usable partial result set while one or more backends failed. +type DashboardPartialResultError struct { + Operation string + Failures map[string]error +} + +func (e *DashboardPartialResultError) Error() string { + if e == nil { + return "dashboard operation returned partial results" + } + + operation := e.Operation + if operation == "" { + operation = "dashboard operation" + } + + parts := make([]string, 0, len(e.Failures)) + regions := make([]string, 0, len(e.Failures)) + for region := range e.Failures { + regions = append(regions, region) + } + sort.Strings(regions) + for _, region := range regions { + parts = append(parts, fmt.Sprintf("%s: %v", region, e.Failures[region])) + } + return fmt.Sprintf("%s returned partial results (%s)", operation, strings.Join(parts, "; ")) +} + +func (e *DashboardPartialResultError) Unwrap() error { + if e == nil || len(e.Failures) == 0 { + return nil + } + + errs := make([]error, 0, len(e.Failures)) + for _, err := range e.Failures { + errs = append(errs, err) + } + return errors.Join(errs...) +}