diff --git a/cli/cmd/lab_reset.go b/cli/cmd/lab_reset.go index f92b0c51..fdad9080 100644 --- a/cli/cmd/lab_reset.go +++ b/cli/cmd/lab_reset.go @@ -398,7 +398,7 @@ func purgeUnmanaged(ctx context.Context, cfg *config.Config, opts purgeOptions) if err != nil { return err } - client, err := daws.NewClient(ctx, region) + client, err := daws.NewClient(ctx, region, "") if err != nil { return err } diff --git a/cli/cmd/scoreboard.go b/cli/cmd/scoreboard.go index 188dac7a..8f8de009 100644 --- a/cli/cmd/scoreboard.go +++ b/cli/cmd/scoreboard.go @@ -64,6 +64,7 @@ func init() { scoreboardRunCmd.Flags().Duration("interval", 3*time.Second, "Poll interval (e.g. 3s, 1500ms)") scoreboardRunCmd.Flags().Bool("restart", false, "Delete the existing report file on the target before starting (no-op for --transport=ares)") scoreboardRunCmd.Flags().Bool("once", false, "Fetch + verify once, print the static board, exit (no TUI)") + scoreboardRunCmd.Flags().String("profile", "", "AWS named profile for SSM/ares transports") } func runScoreboardGenerateKey(cmd *cobra.Command, _ []string) error { @@ -153,6 +154,7 @@ func buildTransport(ctx context.Context, cmd *cobra.Command, cfg *config.Config) instanceID, _ := cmd.Flags().GetString("instance-id") ssmRegion, _ := cmd.Flags().GetString("ssm-region") aresBinary, _ := cmd.Flags().GetString("ares-binary") + profile, _ := cmd.Flags().GetString("profile") switch transport { case "local": @@ -165,7 +167,7 @@ func buildTransport(ctx context.Context, cmd *cobra.Command, cfg *config.Config) if region == "" { region = cfg.Region } - st, err := scoreboard.NewSSMTransport(ctx, instanceID, reportPath, region) + st, err := scoreboard.NewSSMTransport(ctx, instanceID, reportPath, region, profile) if err != nil { return nil, "", err } @@ -178,7 +180,7 @@ func buildTransport(ctx context.Context, cmd *cobra.Command, cfg *config.Config) if region == "" { region = cfg.Region } - at, err := scoreboard.NewAresTransport(ctx, instanceID, aresBinary, region) + at, err := scoreboard.NewAresTransport(ctx, instanceID, aresBinary, region, profile) if err != nil { return nil, "", err } diff --git a/cli/internal/aws/client.go b/cli/internal/aws/client.go index e48d2732..5a232024 100644 --- a/cli/internal/aws/client.go +++ b/cli/internal/aws/client.go @@ -25,18 +25,27 @@ var ( mu sync.Mutex ) -// NewClient creates or returns a cached AWS client for the given region. -func NewClient(ctx context.Context, region string) (*Client, error) { +// NewClient creates or returns a cached AWS client for the given region and +// optional profile. Pass an empty profile to use the SDK default chain. +func NewClient(ctx context.Context, region, profile string) (*Client, error) { mu.Lock() defer mu.Unlock() - if c, ok := clients[region]; ok { + key := region + "\x00" + profile + if c, ok := clients[key]; ok { return c, nil } - cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(region)) + opts := []func(*awsconfig.LoadOptions) error{awsconfig.WithRegion(region)} + if profile != "" { + opts = append(opts, awsconfig.WithSharedConfigProfile(profile)) + } + cfg, err := awsconfig.LoadDefaultConfig(ctx, opts...) if err != nil { - return nil, fmt.Errorf("load AWS config for %s: %w", region, err) + if profile != "" { + return nil, fmt.Errorf("load AWS config for region=%s profile=%s: %w", region, profile, err) + } + return nil, fmt.Errorf("load AWS config for region=%s: %w", region, err) } c := &Client{ @@ -45,7 +54,7 @@ func NewClient(ctx context.Context, region string) (*Client, error) { STS: sts.NewFromConfig(cfg), Region: region, } - clients[region] = c + clients[key] = c return c, nil } diff --git a/cli/internal/aws/provider.go b/cli/internal/aws/provider.go index b4ea058d..33d4acff 100644 --- a/cli/internal/aws/provider.go +++ b/cli/internal/aws/provider.go @@ -18,7 +18,7 @@ func init() { if opts.Region == "" { return nil, fmt.Errorf("AWS region is required") } - client, err := NewClient(ctx, opts.Region) + client, err := NewClient(ctx, opts.Region, "") if err != nil { return nil, err } diff --git a/cli/internal/scoreboard/transport.go b/cli/internal/scoreboard/transport.go index 1b6bb071..102930f2 100644 --- a/cli/internal/scoreboard/transport.go +++ b/cli/internal/scoreboard/transport.go @@ -66,11 +66,11 @@ type SSMTransport struct { // NewSSMTransport builds an SSM transport. Region defaults to the SDK's // default if empty. -func NewSSMTransport(ctx context.Context, instanceID, reportPath, region string) (*SSMTransport, error) { +func NewSSMTransport(ctx context.Context, instanceID, reportPath, region, profile string) (*SSMTransport, error) { if instanceID == "" { return nil, fmt.Errorf("instance ID is required") } - c, err := awsclient.NewClient(ctx, region) + c, err := awsclient.NewClient(ctx, region, profile) if err != nil { return nil, err } diff --git a/cli/internal/scoreboard/transport_ares.go b/cli/internal/scoreboard/transport_ares.go index 627d5df3..b10b87cc 100644 --- a/cli/internal/scoreboard/transport_ares.go +++ b/cli/internal/scoreboard/transport_ares.go @@ -28,11 +28,11 @@ type AresTransport struct { // NewAresTransport constructs an AresTransport. binaryPath defaults to // /usr/local/bin/ares when empty. -func NewAresTransport(ctx context.Context, instanceID, binaryPath, region string) (*AresTransport, error) { +func NewAresTransport(ctx context.Context, instanceID, binaryPath, region, profile string) (*AresTransport, error) { if instanceID == "" { return nil, fmt.Errorf("instance ID is required") } - c, err := awsclient.NewClient(ctx, region) + c, err := awsclient.NewClient(ctx, region, profile) if err != nil { return nil, err }