From 0d7e317cb891f1dea392bbc3538199ad09020950 Mon Sep 17 00:00:00 2001 From: Marc Campbell Date: Sun, 17 May 2026 10:02:53 -0600 Subject: [PATCH 1/2] use context in request --- pacts/replicated-cli-vendor-api.json | 27 ++++++++++-------- pkg/platformclient/client.go | 5 +++- pkg/platformclient/client_test.go | 42 ++++++++++++++++++++++++++++ 3 files changed, 61 insertions(+), 13 deletions(-) create mode 100644 pkg/platformclient/client_test.go diff --git a/pacts/replicated-cli-vendor-api.json b/pacts/replicated-cli-vendor-api.json index 08ac5d29f..0c7fb29c1 100644 --- a/pacts/replicated-cli-vendor-api.json +++ b/pacts/replicated-cli-vendor-api.json @@ -28,11 +28,11 @@ "app": { "channels": [ { - }, + }, { - }, + }, { - } + } ], "created": "2000-02-01T12:30:00Z", "description": "", @@ -83,7 +83,7 @@ { "channels": [ { - } + } ], "id": "replicated-cli-list-apps-app", "name": "Replicated CLI List Apps App", @@ -223,7 +223,8 @@ "id": "replicated-cli-get-channel-unstable", "name": "Unstable", "releases": [ - ] + + ] } }, "matchingRules": { @@ -773,7 +774,7 @@ { "eventType": "release.promoted", "filters": { - } + } } ], "id": "notif-sub-1", @@ -817,7 +818,7 @@ { "eventType": "release.promoted", "filters": { - } + } } ], "isEnabled": true, @@ -834,7 +835,7 @@ { "eventType": "release.promoted", "filters": { - } + } } ], "id": "notif-sub-1", @@ -877,7 +878,7 @@ { "eventType": "release.promoted", "filters": { - } + } } ], "id": "notif-sub-1", @@ -923,7 +924,7 @@ { "eventType": "release.promoted", "filters": { - } + } } ], "id": "notif-sub-1", @@ -983,7 +984,8 @@ "events": [ { "attempts": [ - ], + + ], "eventData": { "releaseId": "rel-1" }, @@ -1720,7 +1722,8 @@ }, "body": { "releases": [ - ] + + ] } } }, diff --git a/pkg/platformclient/client.go b/pkg/platformclient/client.go index 14824bd61..dbbc9f20a 100644 --- a/pkg/platformclient/client.go +++ b/pkg/platformclient/client.go @@ -176,7 +176,10 @@ func (c *HTTPClient) DoJSON(ctx context.Context, method string, path string, suc return err } } - req, err := http.NewRequest(method, endpoint, &buf) + if ctx == nil { + ctx = context.Background() + } + req, err := http.NewRequestWithContext(ctx, method, endpoint, &buf) if err != nil { return err } diff --git a/pkg/platformclient/client_test.go b/pkg/platformclient/client_test.go new file mode 100644 index 000000000..b23738ac4 --- /dev/null +++ b/pkg/platformclient/client_test.go @@ -0,0 +1,42 @@ +package platformclient + +import ( + "context" + "errors" + "net/http" + "testing" + "time" +) + +func TestGetFeaturesRespectsCanceledContext(t *testing.T) { + originalHTTPClient := httpClient + defer func() { + httpClient = originalHTTPClient + }() + + httpClient = &http.Client{ + Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + select { + case <-r.Context().Done(): + return nil, r.Context().Err() + case <-time.After(20 * time.Millisecond): + return nil, errors.New("request context was not canceled") + } + }), + } + + client := NewHTTPClient("https://example.test", "test-api-key") + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := client.GetFeatures(ctx) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled error, got %v", err) + } +} + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return f(r) +} From 16406102bf66a998bd6733dbad36965e67889619 Mon Sep 17 00:00:00 2001 From: Marc Campbell Date: Sun, 17 May 2026 10:04:53 -0600 Subject: [PATCH 2/2] attach client context, not new --- pkg/ociclient/manifest.go | 2 +- pkg/ociclient/upload.go | 6 +- pkg/ociclient/upload_test.go | 124 +++++++++++++++++++++++++++++++++++ 3 files changed, 128 insertions(+), 4 deletions(-) create mode 100644 pkg/ociclient/upload_test.go diff --git a/pkg/ociclient/manifest.go b/pkg/ociclient/manifest.go index a03fe12bd..65dbf390d 100644 --- a/pkg/ociclient/manifest.go +++ b/pkg/ociclient/manifest.go @@ -49,7 +49,7 @@ func uploadManifest(ctx context.Context, blobs []*Blob, configBlob *Blob, repoUR return err } - req, err := http.NewRequest("PUT", fmt.Sprintf("%s/manifests/%s", repoURL, tag), bytes.NewReader(manifestBytes)) + req, err := http.NewRequestWithContext(ctx, "PUT", fmt.Sprintf("%s/manifests/%s", repoURL, tag), bytes.NewReader(manifestBytes)) if err != nil { return err } diff --git a/pkg/ociclient/upload.go b/pkg/ociclient/upload.go index 69875cc98..4a006d864 100644 --- a/pkg/ociclient/upload.go +++ b/pkg/ociclient/upload.go @@ -29,7 +29,7 @@ func uploadBlob(ctx context.Context, filePath, repoURL, jwtToken string, showPro } defer file.Close() - req, err := http.NewRequest("POST", fmt.Sprintf("%s/blobs/uploads/", repoURL), nil) + req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("%s/blobs/uploads/", repoURL), nil) if err != nil { return nil, err } @@ -99,7 +99,7 @@ func uploadBlob(ctx context.Context, filePath, repoURL, jwtToken string, showPro chunk := bytes.NewReader(buf[:n]) contentRange := fmt.Sprintf("bytes %d-%d/%d", totalSize-int64(n), totalSize-1, totalSize) - req, err := http.NewRequest("PATCH", uploadURL, chunk) + req, err := http.NewRequestWithContext(ctx, "PATCH", uploadURL, chunk) if err != nil { return nil, err } @@ -135,7 +135,7 @@ func uploadBlob(ctx context.Context, filePath, repoURL, jwtToken string, showPro digest := fmt.Sprintf("sha256:%s", hex.EncodeToString(hasher.Sum(nil))) - req, err = http.NewRequest("PUT", uploadURL+"?digest="+digest, nil) + req, err = http.NewRequestWithContext(ctx, "PUT", uploadURL+"?digest="+digest, nil) if err != nil { return nil, err } diff --git a/pkg/ociclient/upload_test.go b/pkg/ociclient/upload_test.go new file mode 100644 index 000000000..4a5f15cac --- /dev/null +++ b/pkg/ociclient/upload_test.go @@ -0,0 +1,124 @@ +package ociclient + +import ( + "context" + "io" + "net/http" + "os" + "testing" +) + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func TestUploadBlobUsesCallerContext(t *testing.T) { + tempFile, err := os.CreateTemp("", "ociclient-upload-*.bin") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tempFile.Name()) + + if _, err := tempFile.Write([]byte("content")); err != nil { + t.Fatal(err) + } + if err := tempFile.Close(); err != nil { + t.Fatal(err) + } + + type contextKey string + ctx := context.WithValue(context.Background(), contextKey("request"), "upload") + + oldClient := http.DefaultClient + defer func() { + http.DefaultClient = oldClient + }() + + requests := 0 + http.DefaultClient = &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + if got := req.Context().Value(contextKey("request")); got != "upload" { + t.Fatalf("request context value = %v, want upload", got) + } + + requests++ + resp := &http.Response{ + StatusCode: http.StatusAccepted, + Header: make(http.Header), + Body: io.NopCloser(http.NoBody), + } + + switch requests { + case 1: + if req.Method != http.MethodPost { + t.Fatalf("request %d method = %s, want POST", requests, req.Method) + } + resp.Header.Set("Location", "https://registry.example/upload/1") + case 2: + if req.Method != http.MethodPatch { + t.Fatalf("request %d method = %s, want PATCH", requests, req.Method) + } + resp.Header.Set("Location", "https://registry.example/upload/1") + case 3: + if req.Method != http.MethodPut { + t.Fatalf("request %d method = %s, want PUT", requests, req.Method) + } + resp.StatusCode = http.StatusCreated + default: + t.Fatalf("unexpected request %d", requests) + } + + return resp, nil + }), + } + + if _, err := uploadBlob(ctx, tempFile.Name(), "https://registry.example/v2/repo", "token", false, ""); err != nil { + t.Fatal(err) + } + + if requests != 3 { + t.Fatalf("requests = %d, want 3", requests) + } +} + +func TestUploadManifestUsesCallerContext(t *testing.T) { + type contextKey string + ctx := context.WithValue(context.Background(), contextKey("request"), "manifest") + + oldClient := http.DefaultClient + defer func() { + http.DefaultClient = oldClient + }() + + http.DefaultClient = &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + if got := req.Context().Value(contextKey("request")); got != "manifest" { + t.Fatalf("request context value = %v, want manifest", got) + } + if req.Method != http.MethodPut { + t.Fatalf("request method = %s, want PUT", req.Method) + } + + return &http.Response{ + StatusCode: http.StatusCreated, + Body: io.NopCloser(http.NoBody), + }, nil + }), + } + + blob := &Blob{ + Digest: "sha256:1111111111111111111111111111111111111111111111111111111111111111", + Size: 7, + RelativePath: "model.bin", + } + configBlob := &Blob{ + Digest: "sha256:2222222222222222222222222222222222222222222222222222222222222222", + Size: 2, + } + + if err := uploadManifest(ctx, []*Blob{blob}, configBlob, "https://registry.example/v2/repo", "token", "tag", "."); err != nil { + t.Fatal(err) + } +}