diff --git a/github/github.go b/github/github.go index f0c0d5abf91..5566fc16691 100644 --- a/github/github.go +++ b/github/github.go @@ -159,7 +159,7 @@ var errNonNilContext = errors.New("context must be non-nil") // A Client manages communication with the GitHub API. type Client struct { - clientMu sync.Mutex // clientMu protects the client during calls that modify the CheckRedirect func. + clientMu sync.Mutex // clientMu protects the client fields during copy and Client calls. client *http.Client // HTTP client used to communicate with the API. clientIgnoreRedirects *http.Client // HTTP client used to communicate with the API on endpoints where we don't want to follow redirects. diff --git a/github/migrations.go b/github/migrations.go index e53a9ca9398..3e9f92b1b10 100644 --- a/github/migrations.go +++ b/github/migrations.go @@ -9,8 +9,6 @@ import ( "context" "errors" "fmt" - "net/http" - "strings" ) // MigrationService provides access to the migration related functions @@ -186,29 +184,18 @@ func (s *MigrationService) MigrationArchiveURL(ctx context.Context, org string, if err != nil { return "", err } - req.Header.Set("Accept", mediaTypeMigrationsPreview) - s.client.clientMu.Lock() - defer s.client.clientMu.Unlock() - - // Disable the redirect mechanism because AWS fails if the GitHub auth token is provided. - var loc string - saveRedirect := s.client.client.CheckRedirect - s.client.client.CheckRedirect = func(req *http.Request, _ []*http.Request) error { - loc = req.URL.String() - return errors.New("disable redirect") + loc, _, err := s.client.bareDoUntilFound(ctx, req, 10) + if err != nil { + return "", err } - defer func() { s.client.client.CheckRedirect = saveRedirect }() - _, err = s.client.Do(ctx, req, nil) // expect error from disable redirect - if err == nil { + if loc == nil { return "", errors.New("expected redirect, none provided") } - if !strings.Contains(err.Error(), "disable redirect") { - return "", err - } - return loc, nil + + return loc.String(), nil } // DeleteMigration deletes a previous migration archive. diff --git a/github/migrations_test.go b/github/migrations_test.go index 68ef325ff8b..002acdc48f6 100644 --- a/github/migrations_test.go +++ b/github/migrations_test.go @@ -128,7 +128,7 @@ func TestMigrationService_MigrationStatus(t *testing.T) { }) } -func TestMigrationService_MigrationArchiveURL(t *testing.T) { +func TestMigrationService_MigrationArchiveURL_Redirect(t *testing.T) { t.Parallel() client, mux, _ := setup(t) @@ -161,6 +161,28 @@ func TestMigrationService_MigrationArchiveURL(t *testing.T) { }) } +func TestMigrationService_MigrationArchiveURL_NoRedirect(t *testing.T) { + t.Parallel() + client, mux, _ := setup(t) + + mux.HandleFunc("/orgs/o/migrations/1/archive", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + testHeader(t, r, "Accept", mediaTypeMigrationsPreview) + + w.WriteHeader(http.StatusOK) + assertWrite(t, w, []byte("0123456789abcdef")) + }) + + ctx := t.Context() + got, err := client.Migrations.MigrationArchiveURL(ctx, "o", 1) + if err == nil { + t.Error("Migrations.MigrationArchiveURL did not return expected error") + } + if got != "" { + t.Errorf("MigrationArchiveURL = %+v, want %+v", got, "") + } +} + func TestMigrationService_DeleteMigration(t *testing.T) { t.Parallel() client, mux, _ := setup(t) diff --git a/github/repos_releases.go b/github/repos_releases.go index 18bbe885a53..a0876cf38ce 100644 --- a/github/repos_releases.go +++ b/github/repos_releases.go @@ -365,36 +365,27 @@ func (s *RepositoriesService) DownloadReleaseAsset(ctx context.Context, owner, r } req.Header.Set("Accept", defaultMediaType) - s.client.clientMu.Lock() - defer s.client.clientMu.Unlock() - - var loc string - saveRedirect := s.client.client.CheckRedirect - s.client.client.CheckRedirect = func(req *http.Request, _ []*http.Request) error { - loc = req.URL.String() - return errors.New("disable redirect") + loc, resp, err := s.client.bareDoUntilFound(ctx, req, 10) + if err != nil { + return nil, "", err } - defer func() { s.client.client.CheckRedirect = saveRedirect }() - req = withContext(ctx, req) - resp, err := s.client.client.Do(req) - if err != nil { - if !strings.Contains(err.Error(), "disable redirect") { - return nil, "", err - } - if followRedirectsClient != nil { - rc, err := s.downloadReleaseAssetFromURL(ctx, followRedirectsClient, loc) - return rc, "", err - } - return nil, loc, nil // Intentionally return no error with valid redirect URL. + // No redirect, stream the response body directly. + if loc == nil { + return resp.Body, "", nil } - if err := CheckResponse(resp); err != nil { - _ = resp.Body.Close() - return nil, "", err + // Close body as it's not needed when following redirects or returning the redirect URL. + _ = resp.Body.Close() + + // Got a redirect URL. + redirectStr := loc.String() + if followRedirectsClient != nil { + rc, err := s.downloadReleaseAssetFromURL(ctx, followRedirectsClient, redirectStr) + return rc, "", err } - return resp.Body, "", nil + return nil, redirectStr, nil } func (s *RepositoriesService) downloadReleaseAssetFromURL(ctx context.Context, followRedirectsClient *http.Client, url string) (rc io.ReadCloser, err error) {