Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions pacts/replicated-cli-vendor-api.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@
"app": {
"channels": [
{
},
},
{
},
},
{
}
}
],
"created": "2000-02-01T12:30:00Z",
"description": "",
Expand Down Expand Up @@ -83,7 +83,7 @@
{
"channels": [
{
}
}
],
"id": "replicated-cli-list-apps-app",
"name": "Replicated CLI List Apps App",
Expand Down Expand Up @@ -223,7 +223,8 @@
"id": "replicated-cli-get-channel-unstable",
"name": "Unstable",
"releases": [
]

]
}
},
"matchingRules": {
Expand Down Expand Up @@ -773,7 +774,7 @@
{
"eventType": "release.promoted",
"filters": {
}
}
}
],
"id": "notif-sub-1",
Expand Down Expand Up @@ -817,7 +818,7 @@
{
"eventType": "release.promoted",
"filters": {
}
}
}
],
"isEnabled": true,
Expand All @@ -834,7 +835,7 @@
{
"eventType": "release.promoted",
"filters": {
}
}
}
],
"id": "notif-sub-1",
Expand Down Expand Up @@ -877,7 +878,7 @@
{
"eventType": "release.promoted",
"filters": {
}
}
}
],
"id": "notif-sub-1",
Expand Down Expand Up @@ -923,7 +924,7 @@
{
"eventType": "release.promoted",
"filters": {
}
}
}
],
"id": "notif-sub-1",
Expand Down Expand Up @@ -983,7 +984,8 @@
"events": [
{
"attempts": [
],

],
"eventData": {
"releaseId": "rel-1"
},
Expand Down Expand Up @@ -1720,7 +1722,8 @@
},
"body": {
"releases": [
]

]
}
}
},
Expand Down
2 changes: 1 addition & 1 deletion pkg/ociclient/manifest.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func uploadManifest(ctx context.Context, blobs []*Blob, configBlob *Blob, repoUR
return err
}

req, err := http.NewRequest("PUT", fmt.Sprintf("%s/manifests/%s", repoURL, tag), bytes.NewReader(manifestBytes))
req, err := http.NewRequestWithContext(ctx, "PUT", fmt.Sprintf("%s/manifests/%s", repoURL, tag), bytes.NewReader(manifestBytes))
if err != nil {
return err
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/ociclient/upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func uploadBlob(ctx context.Context, filePath, repoURL, jwtToken string, showPro
}
defer file.Close()

req, err := http.NewRequest("POST", fmt.Sprintf("%s/blobs/uploads/", repoURL), nil)
req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("%s/blobs/uploads/", repoURL), nil)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -99,7 +99,7 @@ func uploadBlob(ctx context.Context, filePath, repoURL, jwtToken string, showPro
chunk := bytes.NewReader(buf[:n])
contentRange := fmt.Sprintf("bytes %d-%d/%d", totalSize-int64(n), totalSize-1, totalSize)

req, err := http.NewRequest("PATCH", uploadURL, chunk)
req, err := http.NewRequestWithContext(ctx, "PATCH", uploadURL, chunk)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -135,7 +135,7 @@ func uploadBlob(ctx context.Context, filePath, repoURL, jwtToken string, showPro

digest := fmt.Sprintf("sha256:%s", hex.EncodeToString(hasher.Sum(nil)))

req, err = http.NewRequest("PUT", uploadURL+"?digest="+digest, nil)
req, err = http.NewRequestWithContext(ctx, "PUT", uploadURL+"?digest="+digest, nil)
if err != nil {
return nil, err
}
Expand Down
124 changes: 124 additions & 0 deletions pkg/ociclient/upload_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package ociclient

import (
"context"
"io"
"net/http"
"os"
"testing"
)

type roundTripFunc func(*http.Request) (*http.Response, error)

func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}

func TestUploadBlobUsesCallerContext(t *testing.T) {
tempFile, err := os.CreateTemp("", "ociclient-upload-*.bin")
if err != nil {
t.Fatal(err)
}
defer os.Remove(tempFile.Name())

if _, err := tempFile.Write([]byte("content")); err != nil {
t.Fatal(err)
}
if err := tempFile.Close(); err != nil {
t.Fatal(err)
}

type contextKey string
ctx := context.WithValue(context.Background(), contextKey("request"), "upload")

oldClient := http.DefaultClient
defer func() {
http.DefaultClient = oldClient
}()

requests := 0
http.DefaultClient = &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
if got := req.Context().Value(contextKey("request")); got != "upload" {
t.Fatalf("request context value = %v, want upload", got)
}

requests++
resp := &http.Response{
StatusCode: http.StatusAccepted,
Header: make(http.Header),
Body: io.NopCloser(http.NoBody),
}

switch requests {
case 1:
if req.Method != http.MethodPost {
t.Fatalf("request %d method = %s, want POST", requests, req.Method)
}
resp.Header.Set("Location", "https://registry.example/upload/1")
case 2:
if req.Method != http.MethodPatch {
t.Fatalf("request %d method = %s, want PATCH", requests, req.Method)
}
resp.Header.Set("Location", "https://registry.example/upload/1")
case 3:
if req.Method != http.MethodPut {
t.Fatalf("request %d method = %s, want PUT", requests, req.Method)
}
resp.StatusCode = http.StatusCreated
default:
t.Fatalf("unexpected request %d", requests)
}

return resp, nil
}),
}

if _, err := uploadBlob(ctx, tempFile.Name(), "https://registry.example/v2/repo", "token", false, ""); err != nil {
t.Fatal(err)
}

if requests != 3 {
t.Fatalf("requests = %d, want 3", requests)
}
}

func TestUploadManifestUsesCallerContext(t *testing.T) {
type contextKey string
ctx := context.WithValue(context.Background(), contextKey("request"), "manifest")

oldClient := http.DefaultClient
defer func() {
http.DefaultClient = oldClient
}()

http.DefaultClient = &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
if got := req.Context().Value(contextKey("request")); got != "manifest" {
t.Fatalf("request context value = %v, want manifest", got)
}
if req.Method != http.MethodPut {
t.Fatalf("request method = %s, want PUT", req.Method)
}

return &http.Response{
StatusCode: http.StatusCreated,
Body: io.NopCloser(http.NoBody),
}, nil
}),
}

blob := &Blob{
Digest: "sha256:1111111111111111111111111111111111111111111111111111111111111111",
Size: 7,
RelativePath: "model.bin",
}
configBlob := &Blob{
Digest: "sha256:2222222222222222222222222222222222222222222222222222222222222222",
Size: 2,
}

if err := uploadManifest(ctx, []*Blob{blob}, configBlob, "https://registry.example/v2/repo", "token", "tag", "."); err != nil {
t.Fatal(err)
}
}
5 changes: 4 additions & 1 deletion pkg/platformclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,10 @@ func (c *HTTPClient) DoJSON(ctx context.Context, method string, path string, suc
return err
}
}
req, err := http.NewRequest(method, endpoint, &buf)
if ctx == nil {
ctx = context.Background()
}
req, err := http.NewRequestWithContext(ctx, method, endpoint, &buf)
if err != nil {
return err
}
Expand Down
42 changes: 42 additions & 0 deletions pkg/platformclient/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package platformclient

import (
"context"
"errors"
"net/http"
"testing"
"time"
)

func TestGetFeaturesRespectsCanceledContext(t *testing.T) {
originalHTTPClient := httpClient
defer func() {
httpClient = originalHTTPClient
}()

httpClient = &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
select {
case <-r.Context().Done():
return nil, r.Context().Err()
case <-time.After(20 * time.Millisecond):
return nil, errors.New("request context was not canceled")
}
}),
}

client := NewHTTPClient("https://example.test", "test-api-key")
ctx, cancel := context.WithCancel(context.Background())
cancel()

_, err := client.GetFeatures(ctx)
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context.Canceled error, got %v", err)
}
}

type roundTripFunc func(*http.Request) (*http.Response, error)

func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return f(r)
}
Loading