Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ var errNonNilContext = errors.New("context must be non-nil")

// A Client manages communication with the GitHub API.
type Client struct {
clientMu sync.Mutex // clientMu protects the client during calls that modify the CheckRedirect func.
clientMu sync.Mutex // clientMu protects the client fields during copy and Client calls.
client *http.Client // HTTP client used to communicate with the API.
clientIgnoreRedirects *http.Client // HTTP client used to communicate with the API on endpoints where we don't want to follow redirects.

Expand Down
25 changes: 6 additions & 19 deletions github/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ import (
"context"
"errors"
"fmt"
"net/http"
"strings"
)

// MigrationService provides access to the migration related functions
Expand Down Expand Up @@ -186,29 +184,18 @@ func (s *MigrationService) MigrationArchiveURL(ctx context.Context, org string,
if err != nil {
return "", err
}

req.Header.Set("Accept", mediaTypeMigrationsPreview)
Comment thread
stevehipwell marked this conversation as resolved.

s.client.clientMu.Lock()
defer s.client.clientMu.Unlock()

// Disable the redirect mechanism because AWS fails if the GitHub auth token is provided.
var loc string
saveRedirect := s.client.client.CheckRedirect
s.client.client.CheckRedirect = func(req *http.Request, _ []*http.Request) error {
loc = req.URL.String()
return errors.New("disable redirect")
loc, _, err := s.client.bareDoUntilFound(ctx, req, 10)
if err != nil {
return "", err
}
defer func() { s.client.client.CheckRedirect = saveRedirect }()

_, err = s.client.Do(ctx, req, nil) // expect error from disable redirect
if err == nil {
if loc == nil {
return "", errors.New("expected redirect, none provided")
}
if !strings.Contains(err.Error(), "disable redirect") {
return "", err
}
return loc, nil

return loc.String(), nil
}

// DeleteMigration deletes a previous migration archive.
Expand Down
24 changes: 23 additions & 1 deletion github/migrations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func TestMigrationService_MigrationStatus(t *testing.T) {
})
}

func TestMigrationService_MigrationArchiveURL(t *testing.T) {
func TestMigrationService_MigrationArchiveURL_Redirect(t *testing.T) {
t.Parallel()
client, mux, _ := setup(t)

Expand Down Expand Up @@ -161,6 +161,28 @@ func TestMigrationService_MigrationArchiveURL(t *testing.T) {
})
}

func TestMigrationService_MigrationArchiveURL_NoRedirect(t *testing.T) {
t.Parallel()
client, mux, _ := setup(t)

mux.HandleFunc("/orgs/o/migrations/1/archive", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "GET")
testHeader(t, r, "Accept", mediaTypeMigrationsPreview)

w.WriteHeader(http.StatusOK)
assertWrite(t, w, []byte("0123456789abcdef"))
})

ctx := t.Context()
got, err := client.Migrations.MigrationArchiveURL(ctx, "o", 1)
if err == nil {
t.Error("Migrations.MigrationArchiveURL did not return expected error")
}
if got != "" {
t.Errorf("MigrationArchiveURL = %+v, want %+v", got, "")
}
}

func TestMigrationService_DeleteMigration(t *testing.T) {
t.Parallel()
client, mux, _ := setup(t)
Expand Down
39 changes: 15 additions & 24 deletions github/repos_releases.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,36 +365,27 @@ func (s *RepositoriesService) DownloadReleaseAsset(ctx context.Context, owner, r
}
req.Header.Set("Accept", defaultMediaType)

s.client.clientMu.Lock()
defer s.client.clientMu.Unlock()

var loc string
saveRedirect := s.client.client.CheckRedirect
s.client.client.CheckRedirect = func(req *http.Request, _ []*http.Request) error {
loc = req.URL.String()
return errors.New("disable redirect")
loc, resp, err := s.client.bareDoUntilFound(ctx, req, 10)
if err != nil {
return nil, "", err
}
defer func() { s.client.client.CheckRedirect = saveRedirect }()

req = withContext(ctx, req)
resp, err := s.client.client.Do(req)
if err != nil {
if !strings.Contains(err.Error(), "disable redirect") {
return nil, "", err
}
if followRedirectsClient != nil {
rc, err := s.downloadReleaseAssetFromURL(ctx, followRedirectsClient, loc)
return rc, "", err
}
return nil, loc, nil // Intentionally return no error with valid redirect URL.
// No redirect, stream the response body directly.
if loc == nil {
return resp.Body, "", nil
}

if err := CheckResponse(resp); err != nil {
_ = resp.Body.Close()
return nil, "", err
// Close body as it's not needed when following redirects or returning the redirect URL.
_ = resp.Body.Close()

// Got a redirect URL.
redirectStr := loc.String()
if followRedirectsClient != nil {
rc, err := s.downloadReleaseAssetFromURL(ctx, followRedirectsClient, redirectStr)
return rc, "", err
}

return resp.Body, "", nil
return nil, redirectStr, nil
}

func (s *RepositoriesService) downloadReleaseAssetFromURL(ctx context.Context, followRedirectsClient *http.Client, url string) (rc io.ReadCloser, err error) {
Expand Down
Loading