Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions cmd/atelet/internal/ategcs/objects.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,20 @@ func FetchFromGCS(ctx context.Context, client ObjectStorage, gsURL string) ([]by
return content, nil
}

// Open resolves gsURL to a bucket and object name and asks client for a
// reader over that object. The object is streamed, not buffered in memory as
// FetchFromGCS does, so ownership of the reader passes to the caller, who must
// Close it.
//
// An error is returned when gsURL cannot be parsed or when the client fails to
// produce the object; in both cases the returned reader is nil.
func Open(ctx context.Context, client ObjectStorage, gsURL string) (io.ReadCloser, error) {
	bucket, object, parseErr := parseGCSURL(gsURL)
	if parseErr != nil {
		return nil, fmt.Errorf("while parsing url: %w", parseErr)
	}

	reader, getErr := client.GetObject(ctx, bucket, object)
	if getErr != nil {
		return nil, fmt.Errorf("while getting object bucket=%q object=%q: %w", bucket, object, getErr)
	}

	return reader, nil
}

// SendBytesToGCS uploads the given bytes (uncompressed) to gsURL. Intended for
// small objects such as the snapshot manifest.
func SendBytesToGCS(ctx context.Context, client ObjectStorage, gsURL string, content []byte) error {
Expand Down
73 changes: 73 additions & 0 deletions cmd/atelet/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,15 @@
package main

import (
"bytes"
"context"
"crypto/sha256"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"testing"

"github.com/agent-substrate/substrate/internal/ateompath"
Expand Down Expand Up @@ -259,6 +265,73 @@ func TestFetchAssetRejectsBadHash(t *testing.T) {
}
}

// fakeObjectStorage is a test stub for the object store used by fetchAsset:
// every GetObject call yields the same canned payload (or the canned error,
// when one is set) regardless of the bucket and object requested.
type fakeObjectStorage struct {
	data []byte // bytes served by a successful GetObject
	err  error  // when non-nil, GetObject fails with this error instead
}

// GetObject ignores its arguments and returns a fresh reader over f.data, or
// a nil reader and f.err when f.err is set.
func (f fakeObjectStorage) GetObject(_ context.Context, _, _ string) (io.ReadCloser, error) {
	if f.err == nil {
		return io.NopCloser(bytes.NewReader(f.data)), nil
	}
	return nil, f.err
}

// PutObject discards the upload and always reports success.
func (fakeObjectStorage) PutObject(_ context.Context, _, _ string, _ io.Reader) error {
	return nil
}

// TestFetchAssetStreaming drives fetchAsset against a fake object store in
// three scenarios: an asset whose digest matches is readable at the returned
// path, an asset longer than maxAssetBytes is refused, and an asset whose
// digest differs from the requested SHA-256 is refused. After each refusal the
// test checks that nothing exists at ateompath.RunSCBinaryPath for that hash.
func TestFetchAssetStreaming(t *testing.T) {
	// Both package-level settings are overwritten by the subtests below; put
	// the original values back once the whole test finishes.
	savedDir, savedCap := ateompath.StaticFilesDir, maxAssetBytes
	t.Cleanup(func() {
		ateompath.StaticFilesDir = savedDir
		maxAssetBytes = savedCap
	})

	const url = "gs://test-bucket/asset"
	payload := []byte("micro-vm kernel bytes")
	digest := sha256.Sum256(payload)
	goodHash := fmt.Sprintf("%x", digest)

	// fetch runs fetchAsset on a herder whose store always serves payload,
	// asking for the given expected hash.
	fetch := func(hash string) (string, error) {
		herder := &AteomHerder{anonGCSClient: fakeObjectStorage{data: payload}}
		return herder.fetchAsset(context.Background(), assetEntry{URL: url, SHA256: hash})
	}

	t.Run("good asset is cached", func(t *testing.T) {
		ateompath.StaticFilesDir = t.TempDir()

		cachedPath, err := fetch(goodHash)
		if err != nil {
			t.Fatalf("fetchAsset: %v", err)
		}
		cached, err := os.ReadFile(cachedPath)
		if err != nil {
			t.Fatalf("reading cached asset: %v", err)
		}
		if !bytes.Equal(cached, payload) {
			t.Errorf("cached bytes = %q, want %q", cached, payload)
		}
	})

	t.Run("over-cap asset rejected, cache not written", func(t *testing.T) {
		ateompath.StaticFilesDir = t.TempDir()
		maxAssetBytes = 4 // payload is longer than this

		_, err := fetch(goodHash)
		if err == nil {
			t.Fatal("fetchAsset accepted an over-cap asset")
		}
		_, statErr := os.Stat(ateompath.RunSCBinaryPath(goodHash))
		if !errors.Is(statErr, os.ErrNotExist) {
			t.Errorf("over-cap download left a file at the cache path (stat err = %v)", statErr)
		}
	})

	t.Run("hash mismatch rejected, cache not written", func(t *testing.T) {
		ateompath.StaticFilesDir = t.TempDir()
		maxAssetBytes = savedCap // undo the tiny cap set by the previous subtest

		// Well-formed (64 hex digits) but not the digest of payload.
		wrongHash := strings.Repeat("a", 64)

		_, err := fetch(wrongHash)
		if err == nil {
			t.Fatal("fetchAsset accepted a hash mismatch")
		}
		_, statErr := os.Stat(ateompath.RunSCBinaryPath(wrongHash))
		if !errors.Is(statErr, os.ErrNotExist) {
			t.Errorf("mismatched download left a file at the cache path (stat err = %v)", statErr)
		}
	})
}

// TestRPCBoundariesReject confirms each of the three RPCs validates path inputs
// before touching its (here nil) dependencies. A traversal value must be
// rejected as InvalidArgument rather than panicking or surfacing as
Expand Down
58 changes: 34 additions & 24 deletions cmd/atelet/sandbox_assets.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"runtime"
Expand All @@ -40,6 +41,10 @@ import (
// so a Restore — possibly on another node — is self-describing.
const sandboxManifestName = "manifest.json"

// maxAssetBytes is the largest asset download, in bytes, that is accepted; it
// protects the disk from a URL that serves an unbounded amount of data. It is
// a var, not a const, so that tests can lower it.
// ponytail: 8GiB ceiling, make it a flag if a rootfs ever needs more.
var maxAssetBytes = int64(8 << 30)

// assetEntry is one content-addressed sandbox asset (url + sha256).
type assetEntry struct {
URL string `json:"url"`
Expand Down Expand Up @@ -112,41 +117,46 @@ func (s *AteomHerder) fetchAsset(ctx context.Context, entry assetEntry) (string,
// gVisor's runsc lives in the public gs://gvisor bucket, so the anonymous
// client suffices. TODO: drive authenticated asset fetches from atelet
// configuration for assets in private buckets.
content, err := ategcs.FetchFromGCS(ctx, s.anonGCSClient, entry.URL)
rc, err := ategcs.Open(ctx, s.anonGCSClient, entry.URL)
if err != nil {
return "", fmt.Errorf("while fetching %v: %w", entry.URL, err)
}
defer rc.Close()

sum := sha256.Sum256(content)
wantSum, err := hex.DecodeString(entry.SHA256)
if err != nil {
return "", fmt.Errorf("while parsing sha256 hash: %w", err)
}
if !bytes.Equal(sum[:], wantSum) {
return "", fmt.Errorf("sha256 mismatch; got=%s want=%s", hex.EncodeToString(sum[:]), entry.SHA256)
}

tmpFileName, err := func() (string, error) {
localDir := filepath.Dir(localPath)
tmpFile, err := os.CreateTemp(localDir, filepath.Base(localPath)+"-download-")
if err != nil {
return "", fmt.Errorf("while temp file: %w", err)
}
defer tmpFile.Close()

if _, err := tmpFile.Write(content); err != nil {
return "", fmt.Errorf("while writing content to temp file: %w", err)
}
if err := tmpFile.Chmod(0o755); err != nil {
return "", fmt.Errorf("while setting file mode: %w", err)
}
return tmpFile.Name(), nil
}()

tmpFile, err := os.CreateTemp(filepath.Dir(localPath), filepath.Base(localPath)+"-download-")
if err != nil {
return "", fmt.Errorf("while populating temp file: %w", err)
return "", fmt.Errorf("while creating temp file: %w", err)
}
tmpName := tmpFile.Name()
defer os.Remove(tmpName) // partial-download cleanup; no-op after rename
defer tmpFile.Close()

if err := os.Rename(tmpFileName, localPath); err != nil {
// Stream to disk, hashing as we go; +1 lets an over-cap asset trip n > cap.
// Verify-after-copy keeps a bad download at the temp path, never the cache.
hasher := sha256.New()
n, err := io.Copy(io.MultiWriter(tmpFile, hasher), io.LimitReader(rc, maxAssetBytes+1))
if err != nil {
return "", fmt.Errorf("while downloading %v: %w", entry.URL, err)
}
if n > maxAssetBytes {
return "", fmt.Errorf("asset %v exceeds %d-byte cap", entry.URL, maxAssetBytes)
}
if got := hasher.Sum(nil); !bytes.Equal(got, wantSum) {
return "", fmt.Errorf("sha256 mismatch; got=%x want=%s", got, entry.SHA256)
}

if err := tmpFile.Chmod(0o755); err != nil {
return "", fmt.Errorf("while setting file mode: %w", err)
}
if err := tmpFile.Close(); err != nil { // flush before rename
return "", fmt.Errorf("while closing temp file: %w", err)
}
if err := os.Rename(tmpName, localPath); err != nil {
return "", fmt.Errorf("while renaming temp file to target: %w", err)
}

Expand Down
Loading