Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/hypervisor/firecracker/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,12 @@ func toSnapshotCreateParams(snapshotDir string) snapshotCreateParams {
}
}

func toSnapshotLoadParams(snapshotDir string, networkOverrides []networkOverride) snapshotLoadParams {
func toSnapshotLoadParams(snapshotDir string, networkOverrides []networkOverride, resumeVM bool) snapshotLoadParams {
return snapshotLoadParams{
MemFilePath: snapshotMemoryPath(snapshotDir),
SnapshotPath: snapshotStatePath(snapshotDir),
EnableDiffSnapshots: true,
ResumeVM: false,
ResumeVM: resumeVM,
NetworkOverrides: networkOverrides,
}
}
Expand Down
4 changes: 2 additions & 2 deletions lib/hypervisor/firecracker/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,11 @@ func TestSnapshotParamPaths(t *testing.T) {

load := toSnapshotLoadParams("/tmp/snapshot-latest", []networkOverride{
{IfaceID: "eth0", HostDevName: "hype-abc123"},
})
}, true)
assert.Equal(t, "/tmp/snapshot-latest/state", load.SnapshotPath)
assert.Equal(t, "/tmp/snapshot-latest/memory", load.MemFilePath)
assert.True(t, load.EnableDiffSnapshots)
assert.False(t, load.ResumeVM)
assert.True(t, load.ResumeVM)
require.Len(t, load.NetworkOverrides, 1)
}

Expand Down
13 changes: 9 additions & 4 deletions lib/hypervisor/firecracker/firecracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ type apiError struct {

// Firecracker implements hypervisor.Hypervisor for the Firecracker VMM.
type Firecracker struct {
socketPath string
client *http.Client
socketPath string
client *http.Client
restoredResumed bool
}

func New(socketPath string) (*Firecracker, error) {
Expand All @@ -50,6 +51,10 @@ func New(socketPath string) (*Firecracker, error) {

var _ hypervisor.Hypervisor = (*Firecracker)(nil)

func (f *Firecracker) RestoredResumed() bool {
return f != nil && f.restoredResumed
}

func (f *Firecracker) Capabilities() hypervisor.Capabilities {
return capabilities()
}
Expand Down Expand Up @@ -223,8 +228,8 @@ func (f *Firecracker) instanceStart(ctx context.Context) error {
return f.postAction(ctx, "InstanceStart")
}

func (f *Firecracker) loadSnapshot(ctx context.Context, snapshotDir string, networkOverrides []networkOverride) error {
params := toSnapshotLoadParams(snapshotDir, networkOverrides)
func (f *Firecracker) loadSnapshot(ctx context.Context, snapshotDir string, networkOverrides []networkOverride, resumeVM bool) error {
params := toSnapshotLoadParams(snapshotDir, networkOverrides, resumeVM)
if _, err := f.do(ctx, http.MethodPut, "/snapshot/load", params, http.StatusNoContent); err != nil {
return err
}
Expand Down
45 changes: 33 additions & 12 deletions lib/hypervisor/firecracker/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ import (
)

const (
socketWaitTimeout = 10 * time.Second
socketPollEvery = 50 * time.Millisecond
socketDialTimeout = 100 * time.Millisecond
socketWaitTimeout = 10 * time.Second
socketReadyRetryEvery = 1 * time.Millisecond
socketDialTimeout = 100 * time.Millisecond
restoreResumeOnLoadEnv = "HYPEMAN_FIRECRACKER_RESTORE_RESUME_ON_LOAD"
)

func init() {
Expand Down Expand Up @@ -115,16 +116,18 @@ func (s *Starter) RestoreVM(ctx context.Context, p *paths.Paths, version string,
if err != nil {
return 0, nil, fmt.Errorf("load firecracker restore metadata: %w", err)
}
resumeOnLoad := shouldResumeOnSnapshotLoad()
err = func() error {
snapshotSourceAliasMu.Lock()
defer snapshotSourceAliasMu.Unlock()
return withSnapshotSourceDirAlias(meta, filepath.Dir(socketPath), func() error {
return hv.loadSnapshot(ctx, snapshotPath, meta.NetworkOverrides)
return hv.loadSnapshot(ctx, snapshotPath, meta.NetworkOverrides, resumeOnLoad)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Guest runs before alias cleanup

Medium Severity

With resume-on-load enabled, snapshot load can start the guest while still inside withSnapshotSourceDirAlias, before the temporary source-data symlink is removed. Restore then skips the separate Resume call when RestoredResumed is set, so alias restores no longer guarantee the guest stays paused until after that teardown finishes.

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit b7c7c05. Configure here.

})
}()
if err != nil {
return 0, nil, fmt.Errorf("load firecracker snapshot: %w", err)
}
hv.restoredResumed = resumeOnLoad
if meta.SnapshotSourceDataDir != "" && !meta.RetainSnapshotSourceDataDirAlias {
meta.SnapshotSourceDataDir = ""
if err := saveRestoreMetadataState(filepath.Dir(socketPath), meta); err != nil {
Expand Down Expand Up @@ -244,23 +247,41 @@ func (s *Starter) startProcess(_ context.Context, p *paths.Paths, version string
}

func isSocketInUse(socketPath string) bool {
conn, err := net.DialTimeout("unix", socketPath, socketDialTimeout)
if err != nil {
return tryDialUnixSocket(socketPath) == nil
}

func shouldResumeOnSnapshotLoad() bool {
if envBoolDisabled(os.Getenv(restoreResumeOnLoadEnv)) {
return false
}
_ = conn.Close()
return true
}

func waitForSocket(path string, timeout time.Duration) error {
func envBoolDisabled(value string) bool {
switch strings.ToLower(strings.TrimSpace(value)) {
case "0", "false", "no", "off":
return true
default:
return false
}
}

func tryDialUnixSocket(path string) error {
conn, err := net.DialTimeout("unix", path, socketDialTimeout)
if err != nil {
return err
}
_ = conn.Close()
return nil
}

func waitForSocketByPolling(path string, timeout time.Duration) error {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
conn, err := net.DialTimeout("unix", path, socketDialTimeout)
if err == nil {
_ = conn.Close()
if err := tryDialUnixSocket(path); err == nil {
return nil
}
time.Sleep(socketPollEvery)
time.Sleep(socketReadyRetryEvery)
}
return fmt.Errorf("timeout waiting for socket")
}
37 changes: 37 additions & 0 deletions lib/hypervisor/firecracker/process_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package firecracker

import (
"errors"
"net"
"os"
"path/filepath"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -78,3 +80,38 @@ func TestWithSnapshotSourceDirAlias_RejectsNestedPaths(t *testing.T) {
require.Error(t, err)
assert.Contains(t, err.Error(), "must not be nested")
}

func TestShouldResumeOnSnapshotLoad(t *testing.T) {
t.Setenv(restoreResumeOnLoadEnv, "")
assert.True(t, shouldResumeOnSnapshotLoad())

t.Setenv(restoreResumeOnLoadEnv, "0")
assert.False(t, shouldResumeOnSnapshotLoad())
}

func TestWaitForSocketReturnsWhenSocketAppears(t *testing.T) {
tmp, err := os.MkdirTemp("/tmp", "fcwait-")
require.NoError(t, err)
t.Cleanup(func() { _ = os.RemoveAll(tmp) })
socketPath := filepath.Join(tmp, "fc.sock")
done := make(chan struct{})
errCh := make(chan error, 1)
go func() {
defer close(done)
time.Sleep(10 * time.Millisecond)
listener, err := net.Listen("unix", socketPath)
if err != nil {
errCh <- err
return
}
errCh <- nil
defer listener.Close()
<-time.After(50 * time.Millisecond)
}()

start := time.Now()
require.NoError(t, waitForSocket(socketPath, time.Second))
assert.Less(t, time.Since(start), 250*time.Millisecond)
require.NoError(t, <-errCh)
<-done
}
96 changes: 96 additions & 0 deletions lib/hypervisor/firecracker/socket_wait_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
//go:build linux

package firecracker

import (
"fmt"
"os"
"path/filepath"
"time"

"golang.org/x/sys/unix"
)

func waitForSocket(path string, timeout time.Duration) error {
if err := tryDialUnixSocket(path); err == nil {
return nil
}

parent := filepath.Dir(path)
fd, err := unix.InotifyInit1(unix.IN_CLOEXEC | unix.IN_NONBLOCK)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we attempt to not wait via polling, then we fall back to polling

if err != nil {
return waitForSocketByPolling(path, timeout)
}
defer unix.Close(fd)

wd, err := unix.InotifyAddWatch(fd, parent, unix.IN_CREATE|unix.IN_MOVED_TO|unix.IN_ATTRIB)
if err != nil {
return waitForSocketByPolling(path, timeout)
}
defer unix.InotifyRmWatch(fd, uint32(wd))

deadline := time.Now().Add(timeout)
buf := make([]byte, 4096)
for {
if err := tryDialUnixSocket(path); err == nil {
return nil
}
remaining := time.Until(deadline)
if remaining <= 0 {
return fmt.Errorf("timeout waiting for socket")
}

pollTimeout := remaining
if socketPathExists(path) {
pollTimeout = minDuration(pollTimeout, socketReadyRetryEvery)
}
events := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLIN}}
n, err := unix.Poll(events, durationMillisCeil(pollTimeout))
if err != nil {
if err == unix.EINTR {
continue
}
return waitForSocketByPolling(path, remaining)
}
if n > 0 {
for {
n, err := unix.Read(fd, buf)
if err != nil {
if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
break
}
return waitForSocketByPolling(path, time.Until(deadline))
}
if n == 0 {
break
}
}
}
}
}

func socketPathExists(path string) bool {
_, err := os.Lstat(path)
return err == nil
}

func durationMillisCeil(d time.Duration) int {
if d <= 0 {
return 0
}
ms := d / time.Millisecond
if d%time.Millisecond != 0 {
ms++
}
if int64(ms) > int64(^uint(0)>>1) {
return int(^uint(0) >> 1)
}
return int(ms)
}

func minDuration(a, b time.Duration) time.Duration {
if a < b {
return a
}
return b
}
9 changes: 9 additions & 0 deletions lib/hypervisor/firecracker/socket_wait_other.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
//go:build !linux

package firecracker

import "time"

func waitForSocket(path string, timeout time.Duration) error {
return waitForSocketByPolling(path, timeout)
}
13 changes: 12 additions & 1 deletion lib/hypervisor/hypervisor.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ type VMStarter interface {
// Each hypervisor implements its own restore flow:
// - Cloud Hypervisor: starts process, calls Restore API
// - QEMU: would start with -incoming or -loadvm flags (not yet implemented)
// Returns the process ID and a Hypervisor client. The VM is in paused state after restore.
// Returns the process ID and a Hypervisor client. The VM is usually paused
// after restore, unless the returned client reports RestoredResumed.
RestoreVM(ctx context.Context, p *paths.Paths, version string, socketPath string, snapshotPath string) (pid int, hv Hypervisor, err error)

// PrepareFork allows hypervisors to prepare forked instance state.
Expand Down Expand Up @@ -202,6 +203,16 @@ type Hypervisor interface {
Capabilities() Capabilities
}

type restoredResumedHypervisor interface {
RestoredResumed() bool
}

// RestoredResumed reports whether RestoreVM already resumed guest execution.
func RestoredResumed(hv Hypervisor) bool {
resumed, ok := hv.(restoredResumedHypervisor)
return ok && resumed.RestoredResumed()
}

// Capabilities indicates which optional features a hypervisor supports.
// Callers should check these before calling optional methods.
type Capabilities struct {
Expand Down
4 changes: 4 additions & 0 deletions lib/hypervisor/tracing.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ func (h *tracingHypervisor) Capabilities() Capabilities {
return h.next.Capabilities()
}

func (h *tracingHypervisor) RestoredResumed() bool {
return RestoredResumed(h.next)
}

func (h *tracingHypervisor) spanAttrs(attrs ...attribute.KeyValue) []attribute.KeyValue {
out := make([]attribute.KeyValue, 0, len(h.attrs)+len(attrs))
out = append(out, h.attrs...)
Expand Down
9 changes: 9 additions & 0 deletions lib/hypervisor/tracing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ import (

type fakeHypervisor struct{}
type fakeHypervisorGetVMInfoError struct{}
type fakeRestoredResumedHypervisor struct {
fakeHypervisor
}

func (fakeHypervisor) DeleteVM(context.Context) error { return nil }
func (fakeHypervisor) Shutdown(context.Context) error { return nil }
Expand All @@ -36,6 +39,7 @@ func (fakeHypervisor) GetTargetGuestMemoryBytes(context.Context) (int64, error)
return 0, nil
}
func (fakeHypervisor) Capabilities() Capabilities { return Capabilities{} }
func (fakeRestoredResumedHypervisor) RestoredResumed() bool { return true }
func (fakeHypervisorGetVMInfoError) DeleteVM(context.Context) error { return nil }
func (fakeHypervisorGetVMInfoError) Shutdown(context.Context) error { return nil }
func (fakeHypervisorGetVMInfoError) GetVMInfo(context.Context) (*VMInfo, error) {
Expand Down Expand Up @@ -123,6 +127,11 @@ func TestWrapVMStarterWrapsReturnedHypervisor(t *testing.T) {
assert.Equal(t, string(TypeCloudHypervisor), attrs["hypervisor"])
}

func TestWrapHypervisorPreservesRestoredResumed(t *testing.T) {
hv := WrapHypervisor(TypeFirecracker, fakeRestoredResumedHypervisor{})
require.True(t, RestoredResumed(hv))
}

func TestWrapHypervisorSkipsGetVMInfoTraceByDefault(t *testing.T) {
recorder, _ := newTestTracerProvider(t)

Expand Down
4 changes: 3 additions & 1 deletion lib/instances/restore.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,9 @@ func (m *manager) restoreInstance(
attribute.String("operation", "resume_vm"),
)
log.InfoContext(ctx, "resuming VM", "instance_id", id)
if err := hv.Resume(resumeCtx); err != nil {
if hypervisor.RestoredResumed(hv) {
log.InfoContext(ctx, "VM was resumed during snapshot load", "instance_id", id)
} else if err := hv.Resume(resumeCtx); err != nil {
resumeSpanEnd(err)
log.ErrorContext(ctx, "failed to resume VM", "instance_id", id, "error", err)
// Cleanup on failure
Expand Down
Loading