diff --git a/cmd/atelet/internal/ategcs/objects.go b/cmd/atelet/internal/ategcs/objects.go index 0a9c40f51..3fa9f0e61 100644 --- a/cmd/atelet/internal/ategcs/objects.go +++ b/cmd/atelet/internal/ategcs/objects.go @@ -58,6 +58,20 @@ func FetchFromGCS(ctx context.Context, client ObjectStorage, gsURL string) ([]by return content, nil } +// Open streams the object at gsURL; the caller must Close the returned reader. +// Unlike FetchFromGCS it does not buffer the whole object in memory. +func Open(ctx context.Context, client ObjectStorage, gsURL string) (io.ReadCloser, error) { + bucket, object, err := parseGCSURL(gsURL) + if err != nil { + return nil, fmt.Errorf("while parsing url: %w", err) + } + rc, err := client.GetObject(ctx, bucket, object) + if err != nil { + return nil, fmt.Errorf("while getting object bucket=%q object=%q: %w", bucket, object, err) + } + return rc, nil +} + // SendBytesToGCS uploads the given bytes (uncompressed) to gsURL. Intended for // small objects such as the snapshot manifest. func SendBytesToGCS(ctx context.Context, client ObjectStorage, gsURL string, content []byte) error { diff --git a/cmd/atelet/main_test.go b/cmd/atelet/main_test.go index e76953636..6be78fa51 100644 --- a/cmd/atelet/main_test.go +++ b/cmd/atelet/main_test.go @@ -15,9 +15,15 @@ package main import ( + "bytes" "context" + "crypto/sha256" + "errors" + "fmt" + "io" "os" "path/filepath" + "strings" "testing" "github.com/agent-substrate/substrate/internal/ateompath" @@ -259,6 +265,73 @@ func TestFetchAssetRejectsBadHash(t *testing.T) { } } +// fakeObjectStorage serves fixed bytes for GetObject so fetchAsset can be tested. +type fakeObjectStorage struct { + data []byte + err error +} + +func (f fakeObjectStorage) GetObject(_ context.Context, _, _ string) (io.ReadCloser, error) { + if f.err != nil { + return nil, f.err + } + return io.NopCloser(bytes.NewReader(f.data)), nil +} + +func (fakeObjectStorage) PutObject(_ context.Context, _, _ string, _ io.Reader) error { return nil } + +// TestFetchAssetStreaming covers the streamed download: good asset cached, +// over-cap rejected, hash mismatch rejected (failures leave no cache file). +func TestFetchAssetStreaming(t *testing.T) { + origDir, origCap := ateompath.StaticFilesDir, maxAssetBytes + t.Cleanup(func() { ateompath.StaticFilesDir, maxAssetBytes = origDir, origCap }) + + content := []byte("micro-vm kernel bytes") + goodHash := fmt.Sprintf("%x", sha256.Sum256(content)) + const url = "gs://test-bucket/asset" + + t.Run("good asset is cached", func(t *testing.T) { + ateompath.StaticFilesDir = t.TempDir() + s := &AteomHerder{anonGCSClient: fakeObjectStorage{data: content}} + path, err := s.fetchAsset(context.Background(), assetEntry{URL: url, SHA256: goodHash}) + if err != nil { + t.Fatalf("fetchAsset: %v", err) + } + got, err := os.ReadFile(path) + if err != nil { + t.Fatalf("reading cached asset: %v", err) + } + if !bytes.Equal(got, content) { + t.Errorf("cached bytes = %q, want %q", got, content) + } + }) + + t.Run("over-cap asset rejected, cache not written", func(t *testing.T) { + ateompath.StaticFilesDir = t.TempDir() + maxAssetBytes = 4 // content is longer than this + s := &AteomHerder{anonGCSClient: fakeObjectStorage{data: content}} + if _, err := s.fetchAsset(context.Background(), assetEntry{URL: url, SHA256: goodHash}); err == nil { + t.Fatal("fetchAsset accepted an over-cap asset") + } + if _, err := os.Stat(ateompath.RunSCBinaryPath(goodHash)); !errors.Is(err, os.ErrNotExist) { + t.Errorf("over-cap download left a file at the cache path (stat err = %v)", err) + } + }) + + t.Run("hash mismatch rejected, cache not written", func(t *testing.T) { + ateompath.StaticFilesDir = t.TempDir() + maxAssetBytes = origCap + wrongHash := strings.Repeat("a", 64) // valid 64-hex format, wrong value + s := &AteomHerder{anonGCSClient: fakeObjectStorage{data: content}} + if _, err := s.fetchAsset(context.Background(), assetEntry{URL: url, SHA256: wrongHash}); err == nil { + t.Fatal("fetchAsset accepted a hash mismatch") + } + if _, err := os.Stat(ateompath.RunSCBinaryPath(wrongHash)); !errors.Is(err, os.ErrNotExist) { + t.Errorf("mismatched download left a file at the cache path (stat err = %v)", err) + } + }) +} + // TestRPCBoundariesReject confirms each of the three RPCs validates path inputs // before touching its (here nil) dependencies. A traversal value must be // rejected as InvalidArgument rather than panicking or surfacing as diff --git a/cmd/atelet/sandbox_assets.go b/cmd/atelet/sandbox_assets.go index 9b06d2a7d..876a4b9f0 100644 --- a/cmd/atelet/sandbox_assets.go +++ b/cmd/atelet/sandbox_assets.go @@ -22,6 +22,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "os" "path/filepath" "runtime" @@ -40,6 +41,10 @@ import ( // so a Restore — possibly on another node — is self-describing. const sandboxManifestName = "manifest.json" +// maxAssetBytes guards disk against an unbounded download URL; a var so tests can lower it. +// ponytail: 8GiB ceiling, make it a flag if a rootfs ever needs more. +var maxAssetBytes int64 = 8 << 30 + // assetEntry is one content-addressed sandbox asset (url + sha256). type assetEntry struct { URL string `json:"url"` @@ -112,41 +117,46 @@ func (s *AteomHerder) fetchAsset(ctx context.Context, entry assetEntry) (string, // gVisor's runsc lives in the public gs://gvisor bucket, so the anonymous // client suffices. TODO: drive authenticated asset fetches from atelet // configuration for assets in private buckets. - content, err := ategcs.FetchFromGCS(ctx, s.anonGCSClient, entry.URL) + rc, err := ategcs.Open(ctx, s.anonGCSClient, entry.URL) if err != nil { return "", fmt.Errorf("while fetching %v: %w", entry.URL, err) } + defer rc.Close() - sum := sha256.Sum256(content) wantSum, err := hex.DecodeString(entry.SHA256) if err != nil { return "", fmt.Errorf("while parsing sha256 hash: %w", err) } - if !bytes.Equal(sum[:], wantSum) { - return "", fmt.Errorf("sha256 mismatch; got=%s want=%s", hex.EncodeToString(sum[:]), entry.SHA256) - } - - tmpFileName, err := func() (string, error) { - localDir := filepath.Dir(localPath) - tmpFile, err := os.CreateTemp(localDir, filepath.Base(localPath)+"-download-") - if err != nil { - return "", fmt.Errorf("while temp file: %w", err) - } - defer tmpFile.Close() - - if _, err := tmpFile.Write(content); err != nil { - return "", fmt.Errorf("while writing content to temp file: %w", err) - } - if err := tmpFile.Chmod(0o755); err != nil { - return "", fmt.Errorf("while setting file mode: %w", err) - } - return tmpFile.Name(), nil - }() + + tmpFile, err := os.CreateTemp(filepath.Dir(localPath), filepath.Base(localPath)+"-download-") if err != nil { - return "", fmt.Errorf("while populating temp file: %w", err) + return "", fmt.Errorf("while creating temp file: %w", err) } + tmpName := tmpFile.Name() + defer os.Remove(tmpName) // partial-download cleanup; no-op after rename + defer tmpFile.Close() - if err := os.Rename(tmpFileName, localPath); err != nil { + // Stream to disk, hashing as we go; +1 lets an over-cap asset trip n > cap. + // Verify-after-copy keeps a bad download at the temp path, never the cache. + hasher := sha256.New() + n, err := io.Copy(io.MultiWriter(tmpFile, hasher), io.LimitReader(rc, maxAssetBytes+1)) + if err != nil { + return "", fmt.Errorf("while downloading %v: %w", entry.URL, err) + } + if n > maxAssetBytes { + return "", fmt.Errorf("asset %v exceeds %d-byte cap", entry.URL, maxAssetBytes) + } + if got := hasher.Sum(nil); !bytes.Equal(got, wantSum) { + return "", fmt.Errorf("sha256 mismatch; got=%x want=%s", got, entry.SHA256) + } + + if err := tmpFile.Chmod(0o755); err != nil { + return "", fmt.Errorf("while setting file mode: %w", err) + } + if err := tmpFile.Close(); err != nil { // flush before rename + return "", fmt.Errorf("while closing temp file: %w", err) + } + if err := os.Rename(tmpName, localPath); err != nil { return "", fmt.Errorf("while renaming temp file to target: %w", err) }