From 942fcc98a36f2cf86d75141e1a7a0dd43a2773b2 Mon Sep 17 00:00:00 2001 From: Davanum Srinivas Date: Fri, 19 Jun 2026 09:01:05 -0400 Subject: [PATCH] fix(atelet): stream sandbox asset downloads instead of buffering in memory Previously, fetchAsset retrieved each asset through FetchFromGCS, which reads the entire object into a byte slice before hashing and writing it. Peak memory therefore scaled with the asset size. This was acceptable for the roughly 50 MB runsc binary, but the micro-VM runtime's kernel and root filesystem images range from hundreds of megabytes to several gigabytes and would exhaust the memory of atelet, which is shared across all actors on a node. The download now streams directly to the temporary file, computing the SHA-256 digest in the same pass via io.MultiWriter and bounding the transfer with an io.LimitReader. Peak memory is now the size of the io.Copy buffer, independent of the asset size. The LimitReader also provides a secondary safeguard for disk usage against a misconfigured or malicious URL that returns an unbounded stream. The limit is 8 GiB and is declared as a variable so that tests may lower it. The digest is verified only after the copy completes. A size or hash failure therefore leaves the data at the temporary path, which is subsequently removed; it is never renamed to the content-addressed cache path, so a failed download cannot corrupt the cache. This change also introduces ategcs.Open, a streaming reader, along with the accompanying TestFetchAssetStreaming. --- cmd/atelet/internal/ategcs/objects.go | 14 +++++ cmd/atelet/main_test.go | 73 +++++++++++++++++++++++++++ cmd/atelet/sandbox_assets.go | 58 ++++++++++++--------- 3 files changed, 121 insertions(+), 24 deletions(-) diff --git a/cmd/atelet/internal/ategcs/objects.go b/cmd/atelet/internal/ategcs/objects.go index 0a9c40f51..3fa9f0e61 100644 --- a/cmd/atelet/internal/ategcs/objects.go +++ b/cmd/atelet/internal/ategcs/objects.go @@ -58,6 +58,20 @@ func FetchFromGCS(ctx context.Context, client ObjectStorage, gsURL string) ([]by return content, nil } +// Open streams the object at gsURL; the caller must Close the returned reader. +// Unlike FetchFromGCS it does not buffer the whole object in memory. +func Open(ctx context.Context, client ObjectStorage, gsURL string) (io.ReadCloser, error) { + bucket, object, err := parseGCSURL(gsURL) + if err != nil { + return nil, fmt.Errorf("while parsing url: %w", err) + } + rc, err := client.GetObject(ctx, bucket, object) + if err != nil { + return nil, fmt.Errorf("while getting object bucket=%q object=%q: %w", bucket, object, err) + } + return rc, nil +} + // SendBytesToGCS uploads the given bytes (uncompressed) to gsURL. Intended for // small objects such as the snapshot manifest. func SendBytesToGCS(ctx context.Context, client ObjectStorage, gsURL string, content []byte) error { diff --git a/cmd/atelet/main_test.go b/cmd/atelet/main_test.go index e76953636..6be78fa51 100644 --- a/cmd/atelet/main_test.go +++ b/cmd/atelet/main_test.go @@ -15,9 +15,15 @@ package main import ( + "bytes" "context" + "crypto/sha256" + "errors" + "fmt" + "io" "os" "path/filepath" + "strings" "testing" "github.com/agent-substrate/substrate/internal/ateompath" @@ -259,6 +265,73 @@ func TestFetchAssetRejectsBadHash(t *testing.T) { } } +// fakeObjectStorage serves fixed bytes for GetObject so fetchAsset can be tested. +type fakeObjectStorage struct { + data []byte + err error +} + +func (f fakeObjectStorage) GetObject(_ context.Context, _, _ string) (io.ReadCloser, error) { + if f.err != nil { + return nil, f.err + } + return io.NopCloser(bytes.NewReader(f.data)), nil +} + +func (fakeObjectStorage) PutObject(_ context.Context, _, _ string, _ io.Reader) error { return nil } + +// TestFetchAssetStreaming covers the streamed download: good asset cached, +// over-cap rejected, hash mismatch rejected (failures leave no cache file). +func TestFetchAssetStreaming(t *testing.T) { + origDir, origCap := ateompath.StaticFilesDir, maxAssetBytes + t.Cleanup(func() { ateompath.StaticFilesDir, maxAssetBytes = origDir, origCap }) + + content := []byte("micro-vm kernel bytes") + goodHash := fmt.Sprintf("%x", sha256.Sum256(content)) + const url = "gs://test-bucket/asset" + + t.Run("good asset is cached", func(t *testing.T) { + ateompath.StaticFilesDir = t.TempDir() + s := &AteomHerder{anonGCSClient: fakeObjectStorage{data: content}} + path, err := s.fetchAsset(context.Background(), assetEntry{URL: url, SHA256: goodHash}) + if err != nil { + t.Fatalf("fetchAsset: %v", err) + } + got, err := os.ReadFile(path) + if err != nil { + t.Fatalf("reading cached asset: %v", err) + } + if !bytes.Equal(got, content) { + t.Errorf("cached bytes = %q, want %q", got, content) + } + }) + + t.Run("over-cap asset rejected, cache not written", func(t *testing.T) { + ateompath.StaticFilesDir = t.TempDir() + maxAssetBytes = 4 // content is longer than this + s := &AteomHerder{anonGCSClient: fakeObjectStorage{data: content}} + if _, err := s.fetchAsset(context.Background(), assetEntry{URL: url, SHA256: goodHash}); err == nil { + t.Fatal("fetchAsset accepted an over-cap asset") + } + if _, err := os.Stat(ateompath.RunSCBinaryPath(goodHash)); !errors.Is(err, os.ErrNotExist) { + t.Errorf("over-cap download left a file at the cache path (stat err = %v)", err) + } + }) + + t.Run("hash mismatch rejected, cache not written", func(t *testing.T) { + ateompath.StaticFilesDir = t.TempDir() + maxAssetBytes = origCap + wrongHash := strings.Repeat("a", 64) // valid 64-hex format, wrong value + s := &AteomHerder{anonGCSClient: fakeObjectStorage{data: content}} + if _, err := s.fetchAsset(context.Background(), assetEntry{URL: url, SHA256: wrongHash}); err == nil { + t.Fatal("fetchAsset accepted a hash mismatch") + } + if _, err := os.Stat(ateompath.RunSCBinaryPath(wrongHash)); !errors.Is(err, os.ErrNotExist) { + t.Errorf("mismatched download left a file at the cache path (stat err = %v)", err) + } + }) +} + // TestRPCBoundariesReject confirms each of the three RPCs validates path inputs // before touching its (here nil) dependencies. A traversal value must be // rejected as InvalidArgument rather than panicking or surfacing as diff --git a/cmd/atelet/sandbox_assets.go b/cmd/atelet/sandbox_assets.go index 9b06d2a7d..876a4b9f0 100644 --- a/cmd/atelet/sandbox_assets.go +++ b/cmd/atelet/sandbox_assets.go @@ -22,6 +22,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "os" "path/filepath" "runtime" @@ -40,6 +41,10 @@ import ( // so a Restore — possibly on another node — is self-describing. const sandboxManifestName = "manifest.json" +// maxAssetBytes guards disk against an unbounded download URL; a var so tests can lower it. +// ponytail: 8GiB ceiling, make it a flag if a rootfs ever needs more. +var maxAssetBytes int64 = 8 << 30 + // assetEntry is one content-addressed sandbox asset (url + sha256). type assetEntry struct { URL string `json:"url"` @@ -112,41 +117,46 @@ func (s *AteomHerder) fetchAsset(ctx context.Context, entry assetEntry) (string, // gVisor's runsc lives in the public gs://gvisor bucket, so the anonymous // client suffices. TODO: drive authenticated asset fetches from atelet // configuration for assets in private buckets. - content, err := ategcs.FetchFromGCS(ctx, s.anonGCSClient, entry.URL) + rc, err := ategcs.Open(ctx, s.anonGCSClient, entry.URL) if err != nil { return "", fmt.Errorf("while fetching %v: %w", entry.URL, err) } + defer rc.Close() - sum := sha256.Sum256(content) wantSum, err := hex.DecodeString(entry.SHA256) if err != nil { return "", fmt.Errorf("while parsing sha256 hash: %w", err) } - if !bytes.Equal(sum[:], wantSum) { - return "", fmt.Errorf("sha256 mismatch; got=%s want=%s", hex.EncodeToString(sum[:]), entry.SHA256) - } - - tmpFileName, err := func() (string, error) { - localDir := filepath.Dir(localPath) - tmpFile, err := os.CreateTemp(localDir, filepath.Base(localPath)+"-download-") - if err != nil { - return "", fmt.Errorf("while temp file: %w", err) - } - defer tmpFile.Close() - - if _, err := tmpFile.Write(content); err != nil { - return "", fmt.Errorf("while writing content to temp file: %w", err) - } - if err := tmpFile.Chmod(0o755); err != nil { - return "", fmt.Errorf("while setting file mode: %w", err) - } - return tmpFile.Name(), nil - }() + + tmpFile, err := os.CreateTemp(filepath.Dir(localPath), filepath.Base(localPath)+"-download-") if err != nil { - return "", fmt.Errorf("while populating temp file: %w", err) + return "", fmt.Errorf("while creating temp file: %w", err) } + tmpName := tmpFile.Name() + defer os.Remove(tmpName) // partial-download cleanup; no-op after rename + defer tmpFile.Close() - if err := os.Rename(tmpFileName, localPath); err != nil { + // Stream to disk, hashing as we go; +1 lets an over-cap asset trip n > cap. + // Verify-after-copy keeps a bad download at the temp path, never the cache. + hasher := sha256.New() + n, err := io.Copy(io.MultiWriter(tmpFile, hasher), io.LimitReader(rc, maxAssetBytes+1)) + if err != nil { + return "", fmt.Errorf("while downloading %v: %w", entry.URL, err) + } + if n > maxAssetBytes { + return "", fmt.Errorf("asset %v exceeds %d-byte cap", entry.URL, maxAssetBytes) + } + if got := hasher.Sum(nil); !bytes.Equal(got, wantSum) { + return "", fmt.Errorf("sha256 mismatch; got=%x want=%s", got, entry.SHA256) + } + + if err := tmpFile.Chmod(0o755); err != nil { + return "", fmt.Errorf("while setting file mode: %w", err) + } + if err := tmpFile.Close(); err != nil { // flush before rename + return "", fmt.Errorf("while closing temp file: %w", err) + } + if err := os.Rename(tmpName, localPath); err != nil { return "", fmt.Errorf("while renaming temp file to target: %w", err) }