Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cli/cmd/lab_reset.go
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ func purgeUnmanaged(ctx context.Context, cfg *config.Config, opts purgeOptions)
if err != nil {
return err
}
client, err := daws.NewClient(ctx, region)
client, err := daws.NewClient(ctx, region, "")
if err != nil {
return err
}
Expand Down
6 changes: 4 additions & 2 deletions cli/cmd/scoreboard.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ func init() {
scoreboardRunCmd.Flags().Duration("interval", 3*time.Second, "Poll interval (e.g. 3s, 1500ms)")
scoreboardRunCmd.Flags().Bool("restart", false, "Delete the existing report file on the target before starting (no-op for --transport=ares)")
scoreboardRunCmd.Flags().Bool("once", false, "Fetch + verify once, print the static board, exit (no TUI)")
scoreboardRunCmd.Flags().String("profile", "", "AWS named profile for SSM/ares transports")
}

func runScoreboardGenerateKey(cmd *cobra.Command, _ []string) error {
Expand Down Expand Up @@ -153,6 +154,7 @@ func buildTransport(ctx context.Context, cmd *cobra.Command, cfg *config.Config)
instanceID, _ := cmd.Flags().GetString("instance-id")
ssmRegion, _ := cmd.Flags().GetString("ssm-region")
aresBinary, _ := cmd.Flags().GetString("ares-binary")
profile, _ := cmd.Flags().GetString("profile")

switch transport {
case "local":
Expand All @@ -165,7 +167,7 @@ func buildTransport(ctx context.Context, cmd *cobra.Command, cfg *config.Config)
if region == "" {
region = cfg.Region
}
st, err := scoreboard.NewSSMTransport(ctx, instanceID, reportPath, region)
st, err := scoreboard.NewSSMTransport(ctx, instanceID, reportPath, region, profile)
if err != nil {
return nil, "", err
}
Expand All @@ -178,7 +180,7 @@ func buildTransport(ctx context.Context, cmd *cobra.Command, cfg *config.Config)
if region == "" {
region = cfg.Region
}
at, err := scoreboard.NewAresTransport(ctx, instanceID, aresBinary, region)
at, err := scoreboard.NewAresTransport(ctx, instanceID, aresBinary, region, profile)
if err != nil {
return nil, "", err
}
Expand Down
21 changes: 15 additions & 6 deletions cli/internal/aws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,27 @@ var (
mu sync.Mutex
)

// NewClient creates or returns a cached AWS client for the given region.
func NewClient(ctx context.Context, region string) (*Client, error) {
// NewClient creates or returns a cached AWS client for the given region and
// optional profile. Pass an empty profile to use the SDK default chain.
func NewClient(ctx context.Context, region, profile string) (*Client, error) {
mu.Lock()
defer mu.Unlock()

if c, ok := clients[region]; ok {
key := region + "\x00" + profile
if c, ok := clients[key]; ok {
return c, nil
}

cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(region))
opts := []func(*awsconfig.LoadOptions) error{awsconfig.WithRegion(region)}
if profile != "" {
opts = append(opts, awsconfig.WithSharedConfigProfile(profile))
}
cfg, err := awsconfig.LoadDefaultConfig(ctx, opts...)
if err != nil {
return nil, fmt.Errorf("load AWS config for %s: %w", region, err)
if profile != "" {
return nil, fmt.Errorf("load AWS config for region=%s profile=%s: %w", region, profile, err)
}
return nil, fmt.Errorf("load AWS config for region=%s: %w", region, err)
}

c := &Client{
Expand All @@ -45,7 +54,7 @@ func NewClient(ctx context.Context, region string) (*Client, error) {
STS: sts.NewFromConfig(cfg),
Region: region,
}
clients[region] = c
clients[key] = c
return c, nil
}

Expand Down
2 changes: 1 addition & 1 deletion cli/internal/aws/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func init() {
if opts.Region == "" {
return nil, fmt.Errorf("AWS region is required")
}
client, err := NewClient(ctx, opts.Region)
client, err := NewClient(ctx, opts.Region, "")
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions cli/internal/scoreboard/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ type SSMTransport struct {

// NewSSMTransport builds an SSM transport. Region defaults to the SDK's
// default if empty.
func NewSSMTransport(ctx context.Context, instanceID, reportPath, region string) (*SSMTransport, error) {
func NewSSMTransport(ctx context.Context, instanceID, reportPath, region, profile string) (*SSMTransport, error) {
if instanceID == "" {
return nil, fmt.Errorf("instance ID is required")
}
c, err := awsclient.NewClient(ctx, region)
c, err := awsclient.NewClient(ctx, region, profile)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions cli/internal/scoreboard/transport_ares.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ type AresTransport struct {

// NewAresTransport constructs an AresTransport. binaryPath defaults to
// /usr/local/bin/ares when empty.
func NewAresTransport(ctx context.Context, instanceID, binaryPath, region string) (*AresTransport, error) {
func NewAresTransport(ctx context.Context, instanceID, binaryPath, region, profile string) (*AresTransport, error) {
if instanceID == "" {
return nil, fmt.Errorf("instance ID is required")
}
c, err := awsclient.NewClient(ctx, region)
c, err := awsclient.NewClient(ctx, region, profile)
if err != nil {
return nil, err
}
Expand Down
Loading