Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions cmd/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,17 +220,19 @@ Token resolution order:
"token": client.MaskToken(token),
}

// Show stored region/URL and account ID.
if creds, err := client.ReadCredentials(); err == nil {
if creds.Region != "" {
statusResult["region"] = creds.Region
}
if creds.AccountID != "" {
statusResult["account_id"] = creds.AccountID
// Stored metadata is only authoritative when the stored token is active.
if tokenSource == "config_file" {
if creds, err := client.ReadCredentials(); err == nil {
if creds.Region != "" {
statusResult["region"] = creds.Region
}
if creds.AccountID != "" {
statusResult["account_id"] = creds.AccountID
}
}
}

// Verify by exchanging the token.
// Verify the active token and derive account metadata from that session.
c := clientFromCmd(cmd)
if c != nil {
statusResult["base_url"] = c.BaseURL()
Expand All @@ -241,6 +243,16 @@ Token resolution order:
statusResult["verify_error"] = err.Error()
} else {
statusResult["verified"] = true
if info, err := c.CurrentAccountInfo(cmd.Context()); err == nil {
if info.Region != "" {
statusResult["region"] = info.Region
}
if info.AccountID != "" {
statusResult["account_id"] = info.AccountID
}
} else if tokenSource != "config_file" {
statusResult["account_info_error"] = err.Error()
}
}
}

Expand Down
70 changes: 70 additions & 0 deletions cmd/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"path/filepath"
"strings"
"testing"
"time"

"github.com/customerio/cli/internal/client"
)
Expand Down Expand Up @@ -551,6 +552,75 @@ func TestAuthStatus_WithEnvToken(t *testing.T) {
}
}

func TestAuthStatus_WithEnvTokenReportsEnvAccountID(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("HOME", tmpDir)
t.Setenv("CIO_TOKEN", "sa_live_envtoken")
t.Setenv("CIO_ACCESS_TOKEN", "")

if err := client.WriteCredentials(&client.Credentials{
ServiceAccountToken: "sa_live_filetoken",
AccountID: "1",
Region: "us",
AccessToken: "jwt-file",
AccessTokenExpiresAt: time.Now().Add(time.Hour),
}); err != nil {
t.Fatalf("seed credentials: %v", err)
}

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/v1/service_accounts/oauth/token":
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
switch r.PostFormValue("client_secret") {
case "sa_live_envtoken":
_, _ = w.Write([]byte(`{"access_token":"jwt-env","token_type":"Bearer","expires_in":3600}`))
case "sa_live_filetoken":
_, _ = w.Write([]byte(`{"access_token":"jwt-file","token_type":"Bearer","expires_in":3600}`))
default:
w.WriteHeader(http.StatusUnauthorized)
}
case "/v1/accounts/current":
switch r.Header.Get("Authorization") {
case "Bearer jwt-env":
_, _ = w.Write([]byte(`{"account":{"id":2,"name":"Env Account","data_center":"eu"}}`))
case "Bearer jwt-file":
_, _ = w.Write([]byte(`{"account":{"id":1,"name":"File Account","data_center":"us"}}`))
default:
w.WriteHeader(http.StatusUnauthorized)
}
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer server.Close()

stdout, _, err := executeCommand("auth", "status", "--api-url", server.URL)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

var result map[string]any
if err := json.Unmarshal([]byte(stdout), &result); err != nil {
t.Fatalf("invalid JSON: %v\nstdout: %s", err, stdout)
}
if result["token_source"] != "environment" {
t.Errorf("expected token_source 'environment', got %v", result["token_source"])
}
if result["verified"] != true {
t.Errorf("expected verified=true, got %v (error: %v)", result["verified"], result["verify_error"])
}
if result["account_id"] != "2" {
t.Errorf("expected account_id from environment token, got %v", result["account_id"])
}
if result["region"] != "eu" {
t.Errorf("expected region from environment token, got %v", result["region"])
}
}

func TestAuthStatus_InvalidToken(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("HOME", tmpDir)
Expand Down
2 changes: 2 additions & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/customerio/cli/internal/client"
"github.com/customerio/cli/internal/output"
"github.com/customerio/cli/internal/useragent"
"github.com/customerio/cli/internal/validate"
"github.com/spf13/cobra"
)
Expand Down Expand Up @@ -219,6 +220,7 @@ func GetJSONBody(cmd *cobra.Command) ([]byte, error) {
// SetVersion sets the CLI version string (called from main with ldflags value).
func SetVersion(v string) {
if v != "" {
useragent.SetVersion(v)
rootCmd.Version = v
}
}
Expand Down
25 changes: 25 additions & 0 deletions cmd/root_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package cmd

import (
"testing"

"github.com/customerio/cli/internal/useragent"
)

func TestSetVersionIgnoresEmptyVersion(t *testing.T) {
oldRootVersion := rootCmd.Version
t.Cleanup(func() {
rootCmd.Version = oldRootVersion
useragent.SetVersion("dev")
})

SetVersion("v1.2.3")
SetVersion("")

if got := rootCmd.Version; got != "v1.2.3" {
t.Fatalf("rootCmd.Version = %q, want %q", got, "v1.2.3")
}
if got, want := useragent.Get(), "Customer.io-CLI/v1.2.3 (+https://github.com/customerio/cli)"; got != want {
t.Fatalf("useragent.Get() = %q, want %q", got, want)
}
}
33 changes: 33 additions & 0 deletions internal/client/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,26 @@ func ResolveRegion(apiURL string, apiURLChanged bool) string {
// accidentally using a restricted token for a full-access session).
// The scopes parameter must match the cached scopes exactly for reuse.
func CachedAccessToken(readOnly bool, scopes []string) string {
return cachedAccessToken("", readOnly, scopes)
}

// CachedAccessTokenForServiceAccount returns a cached JWT only when it belongs
// to the same service-account token that is active for this invocation.
func CachedAccessTokenForServiceAccount(serviceAccountToken string, readOnly bool, scopes []string) string {
if serviceAccountToken == "" {
return ""
}
return cachedAccessToken(serviceAccountToken, readOnly, scopes)
}

func cachedAccessToken(serviceAccountToken string, readOnly bool, scopes []string) string {
creds, err := ReadCredentials()
if err != nil {
return ""
}
if serviceAccountToken != "" && creds.ServiceAccountToken != serviceAccountToken {
return ""
}
if creds.AccessToken == "" {
return ""
}
Expand Down Expand Up @@ -190,6 +206,19 @@ func stringsEqual(a, b []string) bool {
// Holds an exclusive lock across the read-modify-write sequence so two
// concurrent invocations don't lose each other's update.
func CacheAccessToken(accessToken string, expiresIn int, readOnly bool, scopes []string) error {
return cacheAccessToken("", accessToken, expiresIn, readOnly, scopes)
}

// CacheAccessTokenForServiceAccount stores a JWT only when the stored config
// still belongs to the same service-account token that minted the JWT.
func CacheAccessTokenForServiceAccount(serviceAccountToken, accessToken string, expiresIn int, readOnly bool, scopes []string) error {
if serviceAccountToken == "" {
return nil
}
return cacheAccessToken(serviceAccountToken, accessToken, expiresIn, readOnly, scopes)
}

func cacheAccessToken(serviceAccountToken, accessToken string, expiresIn int, readOnly bool, scopes []string) error {
unlock, err := lockConfigDir()
if err != nil {
// Can't cache without a lock — don't fail the caller's request.
Expand All @@ -202,6 +231,10 @@ func CacheAccessToken(accessToken string, expiresIn int, readOnly bool, scopes [
// No existing config — can't cache without stored credentials.
return nil
}
if serviceAccountToken != "" && creds.ServiceAccountToken != serviceAccountToken {
// Env/flag overrides should not rewrite the cache for the stored token.
return nil
}
creds.AccessToken = accessToken
creds.AccessTokenExpiresAt = time.Now().Add(time.Duration(expiresIn) * time.Second)
creds.ReadOnly = readOnly
Expand Down
45 changes: 45 additions & 0 deletions internal/client/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,51 @@ func TestCacheAccessToken_ConcurrentWritesProduceValidConfig(t *testing.T) {
}
}

func TestCachedAccessTokenRequiresMatchingServiceAccountToken(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("HOME", tmpDir)

if err := WriteCredentials(&Credentials{
ServiceAccountToken: "sa_live_file",
AccessToken: "jwt-file",
AccessTokenExpiresAt: time.Now().Add(time.Hour),
}); err != nil {
t.Fatalf("seed write: %v", err)
}

if got := CachedAccessTokenForServiceAccount("sa_live_env", false, nil); got != "" {
t.Fatalf("expected mismatched token cache miss, got %q", got)
}
if got := CachedAccessTokenForServiceAccount("sa_live_file", false, nil); got != "jwt-file" {
t.Fatalf("expected matching token cache hit, got %q", got)
}
}

func TestCacheAccessTokenSkipsMismatchedServiceAccountToken(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("HOME", tmpDir)

if err := WriteCredentials(&Credentials{
ServiceAccountToken: "sa_live_file",
AccessToken: "jwt-file",
AccessTokenExpiresAt: time.Now().Add(time.Hour),
}); err != nil {
t.Fatalf("seed write: %v", err)
}

if err := CacheAccessTokenForServiceAccount("sa_live_env", "jwt-env", 3600, false, nil); err != nil {
t.Fatalf("cache write: %v", err)
}

got, err := ReadCredentials()
if err != nil {
t.Fatalf("read credentials: %v", err)
}
if got.AccessToken != "jwt-file" {
t.Fatalf("expected cached access token to stay unchanged, got %q", got.AccessToken)
}
}

func TestWriteCredentials_AtomicNoPartialFile(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("HOME", tmpDir)
Expand Down
32 changes: 30 additions & 2 deletions internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@ import (
"os"
"strings"
"time"

"github.com/customerio/cli/internal/useragent"
)

// setStandardHeaders stamps headers that every outgoing CLI request should
// carry:
//
// - User-Agent — identifies this CLI and its release version in API logs.
// - X-Validate: strict — opts into strict server-side JSON validation so
// that unknown/typo'd body fields produce a 400 instead of a silent 200.
// Harmless on GETs and form-encoded bodies (the server only consults it
Expand All @@ -25,6 +28,7 @@ import (
// sandbox that runs the CLI on behalf of an AI agent sets this so
// downstream metrics can attribute traffic to the agent.
func setStandardHeaders(req *http.Request) {
req.Header.Set("User-Agent", useragent.Get())
req.Header.Set("X-Validate", "strict")
if os.Getenv("CIO_AGENT") == "1" {
req.Header.Set("X-CIO-Agent", "1")
Expand Down Expand Up @@ -155,7 +159,7 @@ func (c *Client) EnsureAccessToken(ctx context.Context) (string, error) {
}

// Check the file cache.
if cached := CachedAccessToken(c.readOnly, c.scopes); cached != "" {
if cached := CachedAccessTokenForServiceAccount(c.serviceAccountToken, c.readOnly, c.scopes); cached != "" {
c.accessToken = cached
// File cache already applies 60s buffer, so set a conservative in-memory expiry.
c.accessTokenExpiresAt = time.Now().Add(55 * time.Minute)
Expand All @@ -180,7 +184,7 @@ func (c *Client) EnsureAccessToken(ctx context.Context) (string, error) {
}

// Cache for future invocations.
_ = CacheAccessToken(token, expiresIn, c.readOnly, c.scopes)
_ = CacheAccessTokenForServiceAccount(c.serviceAccountToken, token, expiresIn, c.readOnly, c.scopes)

return token, nil
}
Expand Down Expand Up @@ -280,6 +284,30 @@ type DiscoverRegionResult struct {
AccountID string
}

// AccountInfo describes the account represented by the active access token.
type AccountInfo struct {
Region string
AccountID string
}

// CurrentAccountInfo returns account metadata for the active token.
func (c *Client) CurrentAccountInfo(ctx context.Context) (*AccountInfo, error) {
accessToken, err := c.EnsureAccessToken(ctx)
if err != nil {
return nil, err
}

region, accountID, err := fetchAccountInfo(ctx, c.httpClient, c.baseURL, accessToken)
if err != nil {
return nil, err
}

return &AccountInfo{
Region: region,
AccountID: accountID,
}, nil
}

// DiscoverRegion exchanges the sa_live_ token against the default US endpoint,
// then calls GET /v1/accounts/current to read the account's data_center field.
//
Expand Down
Loading