Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 58 additions & 6 deletions internal/clients/nightfall/nightfall.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"regexp"
"strings"
"sync"
Expand Down Expand Up @@ -266,6 +268,28 @@ func (n *Client) buildScanRequest(items []string) *nf.ScanTextRequest {
}
}

// isAuthError reports whether err represents a Nightfall API authentication
// failure (HTTP 401 or 403). The Nightfall Go SDK returns *nf.Error for any
// non-2xx response, with Code set to the HTTP status.
func isAuthError(err error) bool {
var apiErr *nf.Error
if errors.As(err, &apiErr) {
return apiErr.Code == http.StatusUnauthorized || apiErr.Code == http.StatusForbidden
}
return false
}

// describeScanError returns a user-facing error message for a failed scan
// request. Authentication failures are called out explicitly so the action
// fails loudly when NIGHTFALL_API_KEY is missing/expired/revoked, instead of
// silently passing the PR through (see INTM-4909).
func describeScanError(err error) error {
if isAuthError(err) {
return fmt.Errorf("nightfall API authentication failed (check NIGHTFALL_API_KEY is valid and not expired): %w", err)
}
return err
}

func (n *Client) scanFileContent(
ctx context.Context,
cts []*fileToScan,
Expand All @@ -282,7 +306,7 @@ func (n *Client) scanFileContent(
resp, err := n.Scan(ctx, items)
if err != nil {
logger.Debug(fmt.Sprintf("Error sending request number %d with %d items: %v", requestNum, len(items), err))
return nil, err
return nil, describeScanError(err)
}

// Determine findings from response and create comments
Expand All @@ -296,8 +320,10 @@ func (n *Client) scanAllFiles(
logger logger.Logger,
cts []*fileToScan,
commentCh chan<- []*diffreviewer.Comment,
errCh chan<- error,
) {
defer close(commentCh)
defer close(errCh)
blockingCh := make(chan struct{}, n.MaxNumberRoutines)
var wg sync.WaitGroup
requestBatches := make([][]*fileToScan, 0)
Expand Down Expand Up @@ -338,17 +364,24 @@ func (n *Client) scanAllFiles(
blockingCh <- struct{}{}
go func(loopCount int, cts []*fileToScan) {
defer wg.Done()
defer func() { <-blockingCh }()
if ctx.Err() != nil {
return
}

c, err := n.scanFileContent(ctx, cts, loopCount+1, logger)
if err != nil {
logger.Error(fmt.Sprintf("Unable to scan %d content items", len(cts)))
} else {
commentCh <- c
logger.Error(fmt.Sprintf("Unable to scan %d content items: %v", len(cts), err))
// Propagate error so the action exits non-zero. We do a
// non-blocking send: the consumer only needs the first
// failure to fail the run (see INTM-4909).
select {
case errCh <- err:
default:
}
return
}
<-blockingCh
commentCh <- c
}(i, requestBatches[i])
}
wg.Wait()
Expand Down Expand Up @@ -379,19 +412,38 @@ func (n *Client) ReviewDiff(ctx context.Context, logger logger.Logger, fileDiffs
}

commentCh := make(chan []*diffreviewer.Comment)
// Buffered so a worker can record an error without blocking; we only
// surface the first error we observe (see INTM-4909).
errCh := make(chan error, 1)
newCtx, cancel := context.WithDeadline(ctx, time.Now().Add(defaultTimeout))
defer cancel()

go n.scanAllFiles(newCtx, logger, fileToScanList, commentCh)
go n.scanAllFiles(newCtx, logger, fileToScanList, commentCh, errCh)

comments := make([]*diffreviewer.Comment, 0)
var scanErr error
for {
select {
case c, chOpen := <-commentCh:
if !chOpen {
// commentCh is closed once scanAllFiles is done. Drain any
// pending error before returning so we don't lose a failure
// that came in alongside the close.
if scanErr == nil {
if err, ok := <-errCh; ok {
scanErr = err
}
}
if scanErr != nil {
return nil, scanErr
}
return comments, nil
}
comments = append(comments, c...)
case err, ok := <-errCh:
if ok && scanErr == nil {
scanErr = err
}
case <-newCtx.Done():
return nil, newCtx.Err()
}
Expand Down
59 changes: 59 additions & 0 deletions internal/clients/nightfall/nightfall_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,65 @@ func TestReviewDiffHasFindingMetadata(t *testing.T) {
assert.Equal(t, expectedComments, comments, "Received incorrect response from ReviewDiff")
}

// TestReviewDiffPropagatesScanError verifies that a scan-time failure (such
// as a bad NIGHTFALL_API_KEY producing a 401) bubbles up from ReviewDiff
// rather than being silently swallowed. See INTM-4909: previously the
// worker logged the error and the action exited 0, letting unscanned PRs
// merge.
func TestReviewDiffPropagatesScanError(t *testing.T) {
mockAPIClient := &mockNightfall{
scanFn: func(ctx context.Context, request *nf.ScanTextRequest) (*nf.ScanTextResponse, error) {
return nil, &nf.Error{Code: 401, Message: "Unauthorized"}
},
}
client := Client{
APIClient: mockAPIClient,
DetectionRules: testDetectionRules,
MaxNumberRoutines: 1,
}

input := []*diffreviewer.FileDiff{
{
PathNew: "test/data",
Hunks: []*diffreviewer.Hunk{
{
Lines: []*diffreviewer.Line{
{LnumNew: 1, Content: "some content"},
},
},
},
},
}

comments, err := client.ReviewDiff(context.Background(), githublogger.NewDefaultGithubLogger(), input)
assert.Error(t, err, "ReviewDiff must surface scan-time errors so the action exits non-zero")
assert.Nil(t, comments)
assert.Contains(t, err.Error(), "authentication failed",
"401 errors should be wrapped with a clear NIGHTFALL_API_KEY hint")
}

// TestIsAuthError covers the 401/403 detection used to produce a clear
// error message when NIGHTFALL_API_KEY is invalid (INTM-4909).
func TestIsAuthError(t *testing.T) {
tests := []struct {
name string
err error
want bool
}{
{name: "nil", err: nil, want: false},
{name: "generic error", err: errors.New("boom"), want: false},
{name: "500 from API", err: &nf.Error{Code: 500}, want: false},
{name: "401 from API", err: &nf.Error{Code: 401}, want: true},
{name: "403 from API", err: &nf.Error{Code: 403}, want: true},
{name: "wrapped 401", err: fmt.Errorf("outer: %w", &nf.Error{Code: 401}), want: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, isAuthError(tt.err))
})
}
}

func TestScanPaths(t *testing.T) {
client := Client{
DetectionRules: testDetectionRules,
Expand Down