diff --git a/internal/controller/linuxcontainer/container.go b/internal/controller/linuxcontainer/container.go new file mode 100644 index 0000000000..0fc0e9c21f --- /dev/null +++ b/internal/controller/linuxcontainer/container.go @@ -0,0 +1,605 @@ +//go:build windows && lcow + +package linuxcontainer + +import ( + "context" + "fmt" + "sync" + "time" + + runhcsopts "github.com/Microsoft/hcsshim/cmd/containerd-shim-runhcs-v1/options" + "github.com/Microsoft/hcsshim/cmd/containerd-shim-runhcs-v1/stats" + "github.com/Microsoft/hcsshim/internal/controller/process" + "github.com/Microsoft/hcsshim/internal/gcs" + "github.com/Microsoft/hcsshim/internal/hcs/schema1" + hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" + "github.com/Microsoft/hcsshim/internal/log" + "github.com/Microsoft/hcsshim/internal/logfields" + "github.com/Microsoft/hcsshim/internal/oci" + "github.com/Microsoft/hcsshim/internal/protocol/guestrequest" + "github.com/Microsoft/hcsshim/internal/protocol/guestresource" + "github.com/Microsoft/hcsshim/internal/signals" + "github.com/Microsoft/hcsshim/internal/vm/vmutils" + + "github.com/Microsoft/go-winio/pkg/guid" + eventstypes "github.com/containerd/containerd/api/events" + "github.com/containerd/containerd/api/runtime/task/v2" + containerdtypes "github.com/containerd/containerd/api/types/task" + "github.com/containerd/errdefs" + "github.com/containerd/typeurl/v2" + "github.com/opencontainers/runtime-spec/specs-go" + "github.com/sirupsen/logrus" + "google.golang.org/protobuf/types/known/timestamppb" +) + +// Controller is the concrete implementation of the LCOW container controller. +// It manages the full lifecycle of a single LCOW container. +type Controller struct { + // mu guards all mutable fields in this struct. + mu sync.RWMutex + + // vmID is the identifier of the utility VM that hosts this container. + vmID string + + // gcsPodID is the sandbox/pod identifier within the GCS. 
+ gcsPodID string + + // containerID is the unique identifier for this container. + // This is the containerd-visible identifier. + containerID string + + // gcsContainerID is the identifier for the container used + // while interacting with GCS. + gcsContainerID string + + // guest is used to create and manage the GCS container entity. + guest guest + + // scsi manages SCSI disk attachments for the container. + scsi scsiController + + // plan9 manages Plan9 file-share mounts for the container. + plan9 plan9Controller + + // vpci manages virtual PCI device assignments for the container. + vpci vPCIController + + // Host-side resource reservations released during teardown. + layers *scsiLayers + scsiResources []guid.GUID + plan9Resources []guid.GUID + devices []guid.GUID + + // container is the GCS container handle used for lifecycle operations. + container *gcs.Container + + // state tracks the current lifecycle state of the container. + // Access must be guarded by mu. + state State + + // closeOnce ensures closeContainer executes its teardown exactly once. + closeOnce sync.Once + + // terminatedCh is closed exactly once when the container is closed. + // All callers of Wait block on this channel, and closing it unblocks + // every waiter simultaneously. + terminatedCh chan struct{} + + // processes maps exec IDs to their process controllers. + // The init process is stored with exec ID "". + // Access must be guarded by mu. + processes map[string]*process.Controller + + // ioRetryTimeout is the duration to retry IO relay operations before giving up. + ioRetryTimeout time.Duration +} + +// New creates a ready-to-use Controller. +func New( + vmID string, + gcsPodID string, + containerID string, + guestMgr guest, + scsiCtrl scsiController, + plan9Ctrl plan9Controller, + vpci vPCIController, +) *Controller { + return &Controller{ + vmID: vmID, + gcsPodID: gcsPodID, + containerID: containerID, + // Same id is used as the container. 
Post migration, we can always + // change the primary ID while gcs uses the original ID. + gcsContainerID: containerID, + guest: guestMgr, + scsi: scsiCtrl, + plan9: plan9Ctrl, + vpci: vpci, + processes: make(map[string]*process.Controller), + state: StateNotCreated, + terminatedCh: make(chan struct{}), + } +} + +// Create allocates host-side resources, creates the container in the guest, +// and sets up the init process. +func (c *Controller) Create(ctx context.Context, spec *specs.Spec, opts *task.CreateTaskRequest, copts *CreateContainerOpts) (err error) { + ctx, _ = log.WithContext(ctx, logrus.WithField(logfields.GCSContainerID, c.gcsContainerID)) + log.G(ctx).Debug("creating container") + + c.mu.Lock() + defer c.mu.Unlock() + + if c.state != StateNotCreated { + return fmt.Errorf("container %s is in state %s; cannot create: %w", c.containerID, c.state, errdefs.ErrFailedPrecondition) + } + + // Parse the runtime options from the request. + shimOpts, err := vmutils.UnmarshalRuntimeOptions(ctx, opts.Options) + if err != nil { + return fmt.Errorf("unmarshal runtime options: %w", err) + } + + // Apply any updates to the OCI spec based on the shim options. + *spec = oci.UpdateSpecFromOptions(*spec, shimOpts) + + // Expand annotations after defaults have been loaded in from options. + // Since annotation expansion is used to toggle security features, + // raise the error rather than suppress and move on. + if err = oci.ProcessAnnotations(ctx, spec.Annotations); err != nil { + return fmt.Errorf("process OCI spec annotations: %w", err) + } + + // Upon any failure from this point onwards, perform a teardown + // of container and set state as invalid. + defer func() { + if err != nil { + c.state = StateInvalid + c.closeContainer(ctx) + } + }() + + // Allocate all host-side resources and build the GCS container document. 
+ gcsDocument, err := c.generateContainerDocument(ctx, spec, opts.Rootfs, copts.IsScratchEncryptionEnabled) + if err != nil { + return fmt.Errorf("generate container document: %w", err) + } + + // Create the container within the UVM. + c.container, err = c.guest.CreateContainer(ctx, c.gcsContainerID, gcsDocument) + if err != nil { + return fmt.Errorf("create container in guest: %w", err) + } + + // Default to an infinite timeout (zero value). + if shimOpts != nil { + c.ioRetryTimeout = time.Duration(shimOpts.IoRetryTimeoutInSec) * time.Second + } + + // Create the initial process controller with exec ID "". + initProcess := process.New(c.containerID, "", c.container, c.ioRetryTimeout) + if err = initProcess.Create(ctx, &process.CreateOptions{ + Bundle: opts.Bundle, + Terminal: opts.Terminal, + Stdin: opts.Stdin, + Stdout: opts.Stdout, + Stderr: opts.Stderr, + }); err != nil { + return fmt.Errorf("create init process: %w", err) + } + c.processes[""] = initProcess + + c.state = StateCreated + return nil +} + +// closeContainer performs full container teardown exactly once. +func (c *Controller) closeContainer(ctx context.Context) { + c.closeOnce.Do(func() { + c.releaseResources(ctx) + + // Delete guest-side container state after resources have been released. + if c.container != nil { + // Delete the container state if supported. + if c.guest.Capabilities().IsDeleteContainerStateSupported() { + if err := c.guest.DeleteContainerState(ctx, c.gcsContainerID); err != nil { + log.G(ctx).WithError(err).Error("failed to delete container state") + } + } + + // Close the container handle. + _ = c.container.Close() + } + + // Release all waiters. + close(c.terminatedCh) + }) +} + +// releaseResources undoes each allocation in reverse order. +// It is idempotent — subsequent calls after the first are no-ops. +func (c *Controller) releaseResources(ctx context.Context) { + // Combined layers must be removed before unmapping the underlying SCSI + // layer devices. 
+ if c.layers != nil && c.layers.layersCombined { + var hcsLayers []hcsschema.Layer + for _, layer := range c.layers.roLayers { + hcsLayers = append(hcsLayers, hcsschema.Layer{Path: layer.guestPath}) + } + + if err := c.guest.RemoveLCOWCombinedLayers(ctx, guestresource.LCOWCombinedLayers{ + ContainerID: c.gcsContainerID, + ContainerRootPath: c.layers.rootfsPath, + Layers: hcsLayers, + ScratchPath: c.layers.scratch.guestPath, + }); err != nil { + log.G(ctx).WithError(err).Error("failed to remove combined layers from guest") + } + } + + // Unmap layers (scratch + RO layers). + if c.layers != nil { + if err := c.scsi.UnmapFromGuest(ctx, c.layers.scratch.id); err != nil { + log.G(ctx).WithError(err).Error("failed to unmap scratch layer") + } + + for _, layer := range c.layers.roLayers { + if err := c.scsi.UnmapFromGuest(ctx, layer.id); err != nil { + log.G(ctx).WithError(err).Error("failed to unmap ro layer") + } + } + } + + // Unmap additional SCSI mounts. + for _, id := range c.scsiResources { + if err := c.scsi.UnmapFromGuest(ctx, id); err != nil { + log.G(ctx).WithError(err).Error("failed to unmap scsi resource") + } + } + + // Unmap Plan9 shares. + for _, id := range c.plan9Resources { + if err := c.plan9.UnmapFromGuest(ctx, id); err != nil { + log.G(ctx).WithError(err).Error("failed to unmap plan9 share") + } + } + + // Remove VPCI devices. + for _, id := range c.devices { + if err := c.vpci.RemoveFromVM(ctx, id); err != nil { + log.G(ctx).WithError(err).Error("failed to remove vpci device") + } + } + + // Clear all resource references so a repeated call is a no-op, and the + // GC can reclaim the slices. + c.layers = nil + c.scsiResources = nil + c.plan9Resources = nil + c.devices = nil +} + +// Start starts the container and its init process, returning the init PID. 
+func (c *Controller) Start(ctx context.Context, events chan interface{}) (uint32, error) { + ctx, _ = log.WithContext(ctx, logrus.WithField(logfields.GCSContainerID, c.gcsContainerID)) + log.G(ctx).Debug("starting container") + + c.mu.Lock() + defer c.mu.Unlock() + + if c.state != StateCreated { + return 1, fmt.Errorf("container %s is in state %s; cannot start: %w", c.containerID, c.state, errdefs.ErrFailedPrecondition) + } + + // Start the container. + if err := c.container.Start(ctx); err != nil { + c.state = StateInvalid + c.closeContainer(ctx) + return 1, fmt.Errorf("start container %s: %w", c.containerID, err) + } + + // Start the init process. Pass nil for sendEvent because the init + // process exit event is published by handleInitProcessExit after + // full container teardown. + initProcess := c.processes[""] + pid, err := initProcess.Start(ctx, nil) + if err != nil { + c.state = StateInvalid + c.closeContainer(ctx) + return 1, fmt.Errorf("start init process: %w", err) + } + + c.state = StateRunning + go c.handleInitProcessExit(ctx, initProcess, events) + + return uint32(pid), nil +} + +// handleInitProcessExit blocks until the init process exits, then tears down +// the container, marks it stopped, and publishes the exit event. +func (c *Controller) handleInitProcessExit(ctx context.Context, initProcess *process.Controller, events chan interface{}) { + // Detach from the caller's context so upstream cancellation/timeout does + // not abort the background teardown. + ctx = context.WithoutCancel(ctx) + + // Block until the init process exits. + initProcess.Wait(ctx) + + c.mu.Lock() + c.state = StateStopped + c.closeContainer(ctx) + c.mu.Unlock() + + // Publish the exit event after teardown is complete. 
+ if events != nil { + status := initProcess.Status(true) + events <- &eventstypes.TaskExit{ + ContainerID: c.containerID, + ID: status.ExecID, + Pid: status.Pid, + ExitStatus: status.ExitStatus, + ExitedAt: status.ExitedAt, + } + } +} + +// Wait blocks until the container has fully terminated. +func (c *Controller) Wait(ctx context.Context) { + select { + case <-c.terminatedCh: + case <-ctx.Done(): + log.G(ctx).WithError(ctx.Err()).Error("wait for container to exit failed") + } +} + +// Update modifies the container's resource constraints. +func (c *Controller) Update(ctx context.Context, resources interface{}) error { + ctx, _ = log.WithContext(ctx, logrus.WithField(logfields.GCSContainerID, c.gcsContainerID)) + log.G(ctx).Debug("updating container") + + c.mu.Lock() + defer c.mu.Unlock() + + if c.state != StateRunning { + return fmt.Errorf("container %s is in state %s; cannot update: %w", c.containerID, c.state, errdefs.ErrFailedPrecondition) + } + + linuxRes, ok := resources.(*specs.LinuxResources) + if !ok { + return fmt.Errorf("invalid container resources: expected *specs.LinuxResources, got %T", resources) + } + + return c.container.Modify(ctx, guestrequest.ModificationRequest{ + ResourceType: guestresource.ResourceTypeContainerConstraints, + RequestType: guestrequest.RequestTypeUpdate, + Settings: guestresource.LCOWContainerConstraints{ + Linux: *linuxRes, + }, + }) +} + +// NewProcess creates a new exec process controller in the container. 
+func (c *Controller) NewProcess(execID string) (*process.Controller, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.state != StateRunning { + return nil, fmt.Errorf("container %s is in state %s; cannot create new process: %w", c.containerID, c.state, errdefs.ErrFailedPrecondition) + } + + if _, exists := c.processes[execID]; exists { + return nil, fmt.Errorf("exec process %q already exists in container %s", execID, c.containerID) + } + + newProcess := process.New(c.containerID, execID, c.container, c.ioRetryTimeout) + c.processes[execID] = newProcess + + return newProcess, nil +} + +// GetProcess returns the process controller for the given exec ID. +func (c *Controller) GetProcess(execID string) (*process.Controller, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + return c.getProcess(execID) +} + +// getProcess retrieves a process by exec ID. +func (c *Controller) getProcess(execID string) (*process.Controller, error) { + // Must be called with c.mu held (read or write). + proc, ok := c.processes[execID] + if !ok { + return nil, fmt.Errorf("process %q not found in container %s: %w", + execID, c.containerID, errdefs.ErrNotFound) + } + return proc, nil +} + +// ListProcesses returns all exec processes (excluding the init process). +func (c *Controller) ListProcesses() (map[string]*process.Controller, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + result := make(map[string]*process.Controller, len(c.processes)) + for id, proc := range c.processes { + if id == "" { + continue + } + result[id] = proc + } + return result, nil +} + +// Pids queries the guest for the full process list and annotates each entry +// with the exec ID from the local process registry. 
+func (c *Controller) Pids(ctx context.Context) ([]*containerdtypes.ProcessInfo, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + if c.state != StateRunning { + return nil, fmt.Errorf("container %s is in state %s; cannot query pids: %w", c.containerID, c.state, errdefs.ErrFailedPrecondition) + } + + // Build a pid→execID lookup from locally tracked processes. + pidMap := make(map[int]string, len(c.processes)) + for execID, proc := range c.processes { + pidMap[proc.Pid()] = execID + } + + // Query the guest for the actual process list. + props, err := c.container.Properties(ctx, schema1.PropertyTypeProcessList) + if err != nil { + return nil, fmt.Errorf("fetch container properties: %w", err) + } + + // Build ProcessDetails for each process in the guest. + processes := make([]*containerdtypes.ProcessInfo, len(props.ProcessList)) + for i, proc := range props.ProcessList { + pd := &runhcsopts.ProcessDetails{ + ImageName: proc.ImageName, + CreatedAt: timestamppb.New(proc.CreateTimestamp), + KernelTime_100Ns: proc.KernelTime100ns, + MemoryCommitBytes: proc.MemoryCommitBytes, + MemoryWorkingSetPrivateBytes: proc.MemoryWorkingSetPrivateBytes, + MemoryWorkingSetSharedBytes: proc.MemoryWorkingSetSharedBytes, + ProcessID: proc.ProcessId, + UserTime_100Ns: proc.UserTime100ns, + } + if execID, ok := pidMap[int(proc.ProcessId)]; ok { + pd.ExecID = execID + } + + anyVal, err := typeurl.MarshalAny(pd) + if err != nil { + return nil, fmt.Errorf("marshal process details for exec %s in container %s: %w", pd.ExecID, c.containerID, err) + } + processes[i] = &containerdtypes.ProcessInfo{ + Pid: pd.ProcessID, + Info: typeurl.MarshalProto(anyVal), + } + } + return processes, nil +} + +// Stats returns the runtime statistics for the container. 
+func (c *Controller) Stats(ctx context.Context) (*stats.Statistics, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + if c.state != StateRunning { + return nil, fmt.Errorf("container %s is in state %s; cannot fetch stats: %w", c.containerID, c.state, errdefs.ErrFailedPrecondition) + } + + props, err := c.container.PropertiesV2(ctx, hcsschema.PTStatistics) + if err != nil { + return nil, fmt.Errorf("fetch container statistics: %w", err) + } + + containerStats := &stats.Statistics{} + if props != nil { + containerStats.Container = &stats.Statistics_Linux{Linux: props.Metrics} + } + return containerStats, nil +} + +// KillProcess delivers a signal to the specified process or all processes in the container. +func (c *Controller) KillProcess(ctx context.Context, execID string, signal uint32, all bool) error { + if all && execID != "" { + return fmt.Errorf("cannot signal all for non-empty exec %q: %w", execID, errdefs.ErrFailedPrecondition) + } + + signalsSupported := c.guest.Capabilities().IsSignalProcessSupported() + signalOptions, err := signals.ValidateLCOW(int(signal), signalsSupported) + if err != nil { + return fmt.Errorf("validate signal %d for container %s: %w", signal, c.containerID, err) + } + + c.mu.Lock() + defer c.mu.Unlock() + + // The container must have been created for any process to exist. + if c.state == StateNotCreated { + return fmt.Errorf("container %s is in state %s; cannot kill: %w", c.containerID, c.state, errdefs.ErrFailedPrecondition) + } + + // When "all" is requested, deliver the signal to every additional exec + // on a best-effort basis. Errors are logged but do not prevent the + // target process from being signaled. + if all { + for eid, proc := range c.processes { + if eid == "" { + // The init process is signaled as the explicit target below. 
+ continue + } + if killErr := proc.Kill(ctx, signalOptions); killErr != nil { + log.G(ctx).WithError(killErr).WithField(logfields.ExecID, eid).Warn("failed to kill exec in container") + } + } + } + + // Now signal the actual process identified by execID. + targetProcess, err := c.getProcess(execID) + if err != nil { + return err + } + return targetProcess.Kill(ctx, signalOptions) +} + +// DeleteProcess removes the process identified by execID and returns its last status. +func (c *Controller) DeleteProcess(ctx context.Context, execID string) (*task.StateResponse, error) { + // When deleting the init process, wait for handleInitProcessExit to + // complete container teardown first. + // In short, this prevents race of DeleteProcess with handleInitProcessExit. + if execID == "" { + c.mu.RLock() + isStarted := c.state == StateRunning || c.state == StateStopped + c.mu.RUnlock() + + if isStarted { + waitCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + c.Wait(waitCtx) + if waitCtx.Err() != nil { + return nil, fmt.Errorf("wait for container %s resource cleanup: %w", c.containerID, waitCtx.Err()) + } + } + } + + c.mu.Lock() + defer c.mu.Unlock() + + // The container must have been created for any process to exist. + if c.state == StateNotCreated { + return nil, fmt.Errorf("container %s is in state %s; cannot delete process: %w", c.containerID, c.state, errdefs.ErrFailedPrecondition) + } + + proc, err := c.getProcess(execID) + if err != nil { + return nil, err + } + + // Move the process into deleted state. + if err = proc.Delete(ctx); err != nil { + return nil, err + } + + // Capture the process status before removing the entry from map. + status := proc.Status(true) + + // Deleting the init process (execID "") means the container itself is + // being torn down. + if execID == "" { + // For containers that were created but never started, handleInitProcessExit + // was never launched, so closeContainer was never called. 
Perform full + // teardown now. For already-stopped containers, closeOnce makes this a no-op. + c.closeContainer(ctx) + } + + // Remove the process entry only after all fallible operations have + // succeeded, so that a retry can still locate the process. + delete(c.processes, execID) + + return status, nil +} diff --git a/internal/controller/linuxcontainer/container_test.go b/internal/controller/linuxcontainer/container_test.go new file mode 100644 index 0000000000..258ed96c8b --- /dev/null +++ b/internal/controller/linuxcontainer/container_test.go @@ -0,0 +1,845 @@ +//go:build windows && lcow + +package linuxcontainer + +import ( + "context" + "errors" + "strings" + "sync" + "testing" + "time" + + "github.com/Microsoft/hcsshim/internal/controller/linuxcontainer/mocks" + "github.com/Microsoft/hcsshim/internal/controller/process" + "github.com/Microsoft/hcsshim/internal/gcs" + hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" + "github.com/Microsoft/hcsshim/internal/protocol/guestresource" + "github.com/Microsoft/hcsshim/internal/signals" + + "github.com/Microsoft/go-winio/pkg/guid" + "github.com/containerd/errdefs" + "go.uber.org/mock/gomock" +) + +const ( + testVMID = "test-vm" + testPodID = "test-pod" + testContainerID = "test-ctr" +) + +var ( + errUnmapSCSI = errors.New("unmap scsi failed") + errUnmapPlan9 = errors.New("unmap plan9 failed") + errRemoveVPCI = errors.New("remove vpci failed") +) + +// newContainerTestController creates a Controller wired to fresh mock +// controllers for scsi, plan9, vpci, and guest. 
+func newContainerTestController(t *testing.T) ( + *Controller, + *mocks.MockscsiController, + *mocks.Mockplan9Controller, + *mocks.MockvPCIController, + *mocks.Mockguest, +) { + t.Helper() + ctrl := gomock.NewController(t) + scsiCtrl := mocks.NewMockscsiController(ctrl) + plan9Ctrl := mocks.NewMockplan9Controller(ctrl) + vpciCtrl := mocks.NewMockvPCIController(ctrl) + guestCtrl := mocks.NewMockguest(ctrl) + + c := New(testVMID, testPodID, testContainerID, guestCtrl, scsiCtrl, plan9Ctrl, vpciCtrl) + return c, scsiCtrl, plan9Ctrl, vpciCtrl, guestCtrl +} + +// --- New --- + +// TestNew_InitializesFields verifies that New sets all fields correctly and +// that the initial state is StateNotCreated. +func TestNew_InitializesFields(t *testing.T) { + t.Parallel() + c, _, _, _, _ := newContainerTestController(t) + + if c.vmID != testVMID { + t.Errorf("vmID = %q, want %q", c.vmID, testVMID) + } + if c.gcsPodID != testPodID { + t.Errorf("gcsPodID = %q, want %q", c.gcsPodID, testPodID) + } + if c.containerID != testContainerID { + t.Errorf("containerID = %q, want %q", c.containerID, testContainerID) + } + if c.gcsContainerID != testContainerID { + t.Errorf("gcsContainerID = %q, want %q", c.gcsContainerID, testContainerID) + } + if c.state != StateNotCreated { + t.Errorf("initial state = %s, want NotCreated", c.state) + } + if c.terminatedCh == nil { + t.Fatal("terminatedCh must not be nil after New") + } + if c.processes == nil { + t.Fatal("processes map must not be nil after New") + } + if len(c.processes) != 0 { + t.Errorf("expected empty processes map, got %d entries", len(c.processes)) + } +} + +// --- Wait --- + +// TestWait_AlreadyClosed verifies that Wait returns immediately when the +// terminated channel is already closed. +func TestWait_AlreadyClosed(t *testing.T) { + t.Parallel() + c, _, _, _, _ := newContainerTestController(t) + close(c.terminatedCh) + + // Should return immediately. 
+ doneCh := make(chan struct{}) + go func() { + c.Wait(t.Context()) + close(doneCh) + }() + + select { + case <-doneCh: + case <-time.After(time.Second): + t.Fatal("Wait did not return for already-closed channel") + } +} + +// TestWait_ContextCancellation verifies that Wait respects context cancellation. +func TestWait_ContextCancellation(t *testing.T) { + t.Parallel() + c, _, _, _, _ := newContainerTestController(t) + + ctx, cancel := context.WithCancel(t.Context()) + doneCh := make(chan struct{}) + go func() { + c.Wait(ctx) + close(doneCh) + }() + + // Wait should block until we cancel. + cancel() + + select { + case <-doneCh: + case <-time.After(time.Second): + t.Fatal("Wait did not return after context cancellation") + } +} + +// TestWait_UnblocksAllWaiters verifies that closing the terminated channel +// unblocks multiple concurrent waiters. +func TestWait_UnblocksAllWaiters(t *testing.T) { + t.Parallel() + c, _, _, _, _ := newContainerTestController(t) + + const numWaiters = 5 + var waitGroup sync.WaitGroup + waitGroup.Add(numWaiters) + for range numWaiters { + go func() { + defer waitGroup.Done() + c.Wait(t.Context()) + }() + } + + // Close the terminated channel to unblock all waiters. + close(c.terminatedCh) + + doneCh := make(chan struct{}) + go func() { + waitGroup.Wait() + close(doneCh) + }() + + select { + case <-doneCh: + case <-time.After(time.Second): + t.Fatal("not all waiters were unblocked") + } +} + +// --- NewProcess --- + +// TestNewProcess_WrongState verifies that NewProcess rejects calls outside StateRunning. 
+func TestNewProcess_WrongState(t *testing.T) { + t.Parallel() + invalidStates := []State{StateNotCreated, StateCreated, StateStopped, StateInvalid} + + for _, state := range invalidStates { + t.Run(state.String(), func(t *testing.T) { + t.Parallel() + c, _, _, _, _ := newContainerTestController(t) + c.state = state + + _, err := c.NewProcess("exec-1") + if !errors.Is(err, errdefs.ErrFailedPrecondition) { + t.Errorf("NewProcess() error = %v, want ErrFailedPrecondition", err) + } + }) + } +} + +// TestNewProcess_DuplicateExecID verifies that NewProcess rejects a duplicate exec ID. +func TestNewProcess_DuplicateExecID(t *testing.T) { + t.Parallel() + c, _, _, _, _ := newContainerTestController(t) + c.state = StateRunning + + // Pre-populate the exec ID. + c.processes["exec-1"] = process.New(testContainerID, "exec-1", nil, 0) + + _, err := c.NewProcess("exec-1") + if err == nil { + t.Fatal("expected error for duplicate exec ID") + } +} + +// TestNewProcess_Success verifies that NewProcess creates and tracks a new +// process controller. +func TestNewProcess_Success(t *testing.T) { + t.Parallel() + c, _, _, _, _ := newContainerTestController(t) + c.state = StateRunning + + proc, err := c.NewProcess("exec-1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if proc == nil { + t.Fatal("expected non-nil process controller") + } + if len(c.processes) != 1 { + t.Errorf("expected 1 tracked process, got %d", len(c.processes)) + } + if c.processes["exec-1"] != proc { + t.Error("tracked process does not match returned process") + } +} + +// --- GetProcess --- + +// TestGetProcess_Found verifies that GetProcess returns the correct process. 
+func TestGetProcess_Found(t *testing.T) { + t.Parallel() + c, _, _, _, _ := newContainerTestController(t) + + initProc := process.New(testContainerID, "", nil, 0) + c.processes[""] = initProc + + got, err := c.GetProcess("") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != initProc { + t.Error("returned process does not match stored process") + } +} + +// TestGetProcess_NotFound verifies that GetProcess returns ErrNotFound for +// an unknown exec ID. +func TestGetProcess_NotFound(t *testing.T) { + t.Parallel() + c, _, _, _, _ := newContainerTestController(t) + + _, err := c.GetProcess("nonexistent") + if !errors.Is(err, errdefs.ErrNotFound) { + t.Errorf("GetProcess() error = %v, want ErrNotFound", err) + } +} + +// --- ListProcesses --- + +// TestListProcesses_Empty verifies that ListProcesses returns an empty map +// when only the init process is registered. +func TestListProcesses_Empty(t *testing.T) { + t.Parallel() + c, _, _, _, _ := newContainerTestController(t) + c.processes[""] = process.New(testContainerID, "", nil, 0) + + result, err := c.ListProcesses() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(result) != 0 { + t.Errorf("expected 0 exec processes, got %d", len(result)) + } +} + +// TestListProcesses_ExcludesInit verifies that the init process (exec ID "") +// is excluded from the result while all other execs are returned. 
+func TestListProcesses_ExcludesInit(t *testing.T) { + t.Parallel() + c, _, _, _, _ := newContainerTestController(t) + + initProc := process.New(testContainerID, "", nil, 0) + exec1 := process.New(testContainerID, "exec-1", nil, 0) + exec2 := process.New(testContainerID, "exec-2", nil, 0) + c.processes[""] = initProc + c.processes["exec-1"] = exec1 + c.processes["exec-2"] = exec2 + + result, err := c.ListProcesses() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(result) != 2 { + t.Fatalf("expected 2 exec processes, got %d", len(result)) + } + if result["exec-1"] != exec1 { + t.Error("exec-1 not found or mismatched") + } + if result["exec-2"] != exec2 { + t.Error("exec-2 not found or mismatched") + } +} + +// TestListProcesses_NoProcesses verifies that ListProcesses returns an empty +// map when the processes map is completely empty. +func TestListProcesses_NoProcesses(t *testing.T) { + t.Parallel() + c, _, _, _, _ := newContainerTestController(t) + + result, err := c.ListProcesses() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(result) != 0 { + t.Errorf("expected 0 processes, got %d", len(result)) + } +} + +// --- KillProcess --- + +// TestKillProcess_AllWithNonEmptyExecID verifies that KillProcess rejects the +// combination of all=true with a non-empty exec ID. +func TestKillProcess_AllWithNonEmptyExecID(t *testing.T) { + t.Parallel() + c, _, _, _, _ := newContainerTestController(t) + c.state = StateRunning + + err := c.KillProcess(t.Context(), "exec-1", 15, true) + if !errors.Is(err, errdefs.ErrFailedPrecondition) { + t.Errorf("KillProcess() error = %v, want ErrFailedPrecondition", err) + } +} + +// TestKillProcess_InvalidSignal verifies that KillProcess rejects invalid +// signals before checking container state. 
+func TestKillProcess_InvalidSignal(t *testing.T) { + t.Parallel() + c, _, _, _, guestCtrl := newContainerTestController(t) + c.state = StateRunning + + // Signal process is supported but signal value is invalid. + guestCtrl.EXPECT(). + Capabilities(). + Return(&gcs.LCOWGuestDefinedCapabilities{}) + + err := c.KillProcess(t.Context(), "", 999, false) + if !errors.Is(err, signals.ErrInvalidSignal) { + t.Errorf("KillProcess() error = %v, want ErrInvalidSignal", err) + } +} + +// TestKillProcess_NotCreatedState verifies that KillProcess rejects calls +// when the container has not been created yet. +func TestKillProcess_NotCreatedState(t *testing.T) { + t.Parallel() + c, _, _, _, guestCtrl := newContainerTestController(t) + c.state = StateNotCreated + + // SIGTERM (15) with no signal support returns nil signal options. + guestCtrl.EXPECT(). + Capabilities(). + Return(&gcs.LCOWGuestDefinedCapabilities{}) + + err := c.KillProcess(t.Context(), "", 15, false) + if !errors.Is(err, errdefs.ErrFailedPrecondition) { + t.Errorf("KillProcess() error = %v, want ErrFailedPrecondition", err) + } +} + +// TestKillProcess_ProcessNotFound verifies that KillProcess returns ErrNotFound +// when the target exec ID does not exist. +func TestKillProcess_ProcessNotFound(t *testing.T) { + t.Parallel() + c, _, _, _, guestCtrl := newContainerTestController(t) + c.state = StateRunning + + guestCtrl.EXPECT(). + Capabilities(). + Return(&gcs.LCOWGuestDefinedCapabilities{}) + + err := c.KillProcess(t.Context(), "nonexistent", 15, false) + if !errors.Is(err, errdefs.ErrNotFound) { + t.Errorf("KillProcess() error = %v, want ErrNotFound", err) + } +} + +// --- DeleteProcess --- + +// TestDeleteProcess_NotCreatedState verifies that DeleteProcess rejects calls +// when the container has not been created yet. 
+func TestDeleteProcess_NotCreatedState(t *testing.T) { + t.Parallel() + c, _, _, _, _ := newContainerTestController(t) + c.state = StateNotCreated + + _, err := c.DeleteProcess(t.Context(), "exec-1") + if !errors.Is(err, errdefs.ErrFailedPrecondition) { + t.Errorf("DeleteProcess() error = %v, want ErrFailedPrecondition", err) + } +} + +// TestDeleteProcess_ProcessNotFound verifies that DeleteProcess returns +// ErrNotFound when the target exec ID does not exist. +func TestDeleteProcess_ProcessNotFound(t *testing.T) { + t.Parallel() + c, _, _, _, _ := newContainerTestController(t) + c.state = StateCreated + + _, err := c.DeleteProcess(t.Context(), "nonexistent") + if !errors.Is(err, errdefs.ErrNotFound) { + t.Errorf("DeleteProcess() error = %v, want ErrNotFound", err) + } +} + +// TestDeleteProcess_InitProcessNotStarted verifies that deleting the init +// process on a created-but-never-started container triggers closeContainer. +func TestDeleteProcess_InitProcessNotStarted(t *testing.T) { + t.Parallel() + c, _, _, _, _ := newContainerTestController(t) + c.state = StateCreated + + // Create a process controller in terminated state so Delete succeeds. + // process.New starts in StateNotCreated; Delete from NotCreated returns error. + // We need a process in StateCreated or StateTerminated for Delete to succeed. + // Since we can't directly set the state of process.Controller from outside + // the package, we use the fact that Kill on a StateCreated process aborts it + // into StateTerminated. + initProc := process.New(testContainerID, "", nil, 0) + c.processes[""] = initProc + + // The init process is in StateNotCreated. Delete on a process in + // StateNotCreated hits the default case and returns ErrFailedPrecondition. + _, err := c.DeleteProcess(t.Context(), "") + if err == nil { + t.Fatal("expected error deleting init process in StateNotCreated") + } +} + +// --- Update --- + +// TestUpdate_WrongState verifies that Update rejects calls outside StateRunning. 
+func TestUpdate_WrongState(t *testing.T) { + t.Parallel() + invalidStates := []State{StateNotCreated, StateCreated, StateStopped, StateInvalid} + + for _, state := range invalidStates { + t.Run(state.String(), func(t *testing.T) { + t.Parallel() + c, _, _, _, _ := newContainerTestController(t) + c.state = state + + err := c.Update(t.Context(), nil) + if !errors.Is(err, errdefs.ErrFailedPrecondition) { + t.Errorf("Update() error = %v, want ErrFailedPrecondition", err) + } + }) + } +} + +// TestUpdate_InvalidResourceType verifies that Update rejects resources that +// are not *specs.LinuxResources. +func TestUpdate_InvalidResourceType(t *testing.T) { + t.Parallel() + c, _, _, _, _ := newContainerTestController(t) + c.state = StateRunning + + err := c.Update(t.Context(), "not-linux-resources") + if err == nil { + t.Fatal("expected error for invalid resource type") + } +} + +// --- releaseResources --- + +// TestReleaseResources_AllResourceTypes verifies that releaseResources unmaps +// layers, SCSI mounts, Plan9 shares, and VPCI devices in order. +func TestReleaseResources_AllResourceTypes(t *testing.T) { + t.Parallel() + c, scsiCtrl, plan9Ctrl, vpciCtrl, guestCtrl := newContainerTestController(t) + + roGUID, _ := guid.NewV4() + scratchGUID, _ := guid.NewV4() + scsiGUID, _ := guid.NewV4() + plan9GUID, _ := guid.NewV4() + deviceGUID, _ := guid.NewV4() + + c.layers = &scsiLayers{ + layersCombined: true, + rootfsPath: "/rootfs", + scratch: scsiReservation{id: scratchGUID, guestPath: "/dev/scratch"}, + roLayers: []scsiReservation{{id: roGUID, guestPath: "/dev/ro0"}}, + } + c.scsiResources = []guid.GUID{scsiGUID} + c.plan9Resources = []guid.GUID{plan9GUID} + c.devices = []guid.GUID{deviceGUID} + + // Expect combined layers removal. + guestCtrl.EXPECT(). 
+ RemoveLCOWCombinedLayers(gomock.Any(), guestresource.LCOWCombinedLayers{ + ContainerID: c.gcsContainerID, + ContainerRootPath: "/rootfs", + Layers: []hcsschema.Layer{{Path: "/dev/ro0"}}, + ScratchPath: "/dev/scratch", + }). + Return(nil) + + // Expect scratch + RO layer unmaps. + scsiCtrl.EXPECT().UnmapFromGuest(gomock.Any(), scratchGUID).Return(nil) + scsiCtrl.EXPECT().UnmapFromGuest(gomock.Any(), roGUID).Return(nil) + + // Expect additional SCSI resource unmap. + scsiCtrl.EXPECT().UnmapFromGuest(gomock.Any(), scsiGUID).Return(nil) + + // Expect Plan9 share unmap. + plan9Ctrl.EXPECT().UnmapFromGuest(gomock.Any(), plan9GUID).Return(nil) + + // Expect VPCI device removal. + vpciCtrl.EXPECT().RemoveFromVM(gomock.Any(), deviceGUID).Return(nil) + + c.releaseResources(t.Context()) + + // All resource slices must be nil after release. + if c.layers != nil { + t.Error("layers should be nil after releaseResources") + } + if c.scsiResources != nil { + t.Error("scsiResources should be nil after releaseResources") + } + if c.plan9Resources != nil { + t.Error("plan9Resources should be nil after releaseResources") + } + if c.devices != nil { + t.Error("devices should be nil after releaseResources") + } +} + +// TestReleaseResources_NoLayers verifies that releaseResources handles the +// case where no layers were allocated (only additional resources). +func TestReleaseResources_NoLayers(t *testing.T) { + t.Parallel() + c, scsiCtrl, _, _, _ := newContainerTestController(t) + + scsiGUID, _ := guid.NewV4() + c.scsiResources = []guid.GUID{scsiGUID} + + scsiCtrl.EXPECT().UnmapFromGuest(gomock.Any(), scsiGUID).Return(nil) + + c.releaseResources(t.Context()) + + if c.scsiResources != nil { + t.Error("scsiResources should be nil after releaseResources") + } +} + +// TestReleaseResources_LayersNotCombined verifies that when layers exist but +// were not combined, RemoveLCOWCombinedLayers is not called. 
+func TestReleaseResources_LayersNotCombined(t *testing.T) { + t.Parallel() + c, scsiCtrl, _, _, _ := newContainerTestController(t) + + roGUID, _ := guid.NewV4() + scratchGUID, _ := guid.NewV4() + + c.layers = &scsiLayers{ + layersCombined: false, + scratch: scsiReservation{id: scratchGUID}, + roLayers: []scsiReservation{{id: roGUID}}, + } + + // Only unmaps expected; no RemoveLCOWCombinedLayers call. + scsiCtrl.EXPECT().UnmapFromGuest(gomock.Any(), scratchGUID).Return(nil) + scsiCtrl.EXPECT().UnmapFromGuest(gomock.Any(), roGUID).Return(nil) + + c.releaseResources(t.Context()) + + if c.layers != nil { + t.Error("layers should be nil after releaseResources") + } +} + +// TestReleaseResources_Idempotent verifies that a second call to +// releaseResources is a no-op. +func TestReleaseResources_Idempotent(t *testing.T) { + t.Parallel() + c, scsiCtrl, _, _, _ := newContainerTestController(t) + + scsiGUID, _ := guid.NewV4() + c.scsiResources = []guid.GUID{scsiGUID} + + scsiCtrl.EXPECT().UnmapFromGuest(gomock.Any(), scsiGUID).Return(nil).Times(1) + + c.releaseResources(t.Context()) + // Second call should be a no-op (no mock calls expected). + c.releaseResources(t.Context()) +} + +// TestReleaseResources_ErrorsContinue verifies that releaseResources continues +// releasing remaining resources even when individual unmaps fail. +func TestReleaseResources_ErrorsContinue(t *testing.T) { + t.Parallel() + c, scsiCtrl, plan9Ctrl, vpciCtrl, _ := newContainerTestController(t) + + scsiGUID, _ := guid.NewV4() + plan9GUID, _ := guid.NewV4() + deviceGUID, _ := guid.NewV4() + + c.scsiResources = []guid.GUID{scsiGUID} + c.plan9Resources = []guid.GUID{plan9GUID} + c.devices = []guid.GUID{deviceGUID} + + // Each unmap fails, but releaseResources should still attempt all. 
+ scsiCtrl.EXPECT().UnmapFromGuest(gomock.Any(), scsiGUID).Return(errUnmapSCSI) + plan9Ctrl.EXPECT().UnmapFromGuest(gomock.Any(), plan9GUID).Return(errUnmapPlan9) + vpciCtrl.EXPECT().RemoveFromVM(gomock.Any(), deviceGUID).Return(errRemoveVPCI) + + // Should not panic; errors are logged. + c.releaseResources(t.Context()) + + // Slices still cleared even on errors. + if c.scsiResources != nil { + t.Error("scsiResources should be nil after releaseResources") + } + if c.plan9Resources != nil { + t.Error("plan9Resources should be nil after releaseResources") + } + if c.devices != nil { + t.Error("devices should be nil after releaseResources") + } +} + +// TestReleaseResources_MultipleROLayers verifies that all read-only layers +// are individually unmapped. +func TestReleaseResources_MultipleROLayers(t *testing.T) { + t.Parallel() + c, scsiCtrl, _, _, guestCtrl := newContainerTestController(t) + + scratchGUID, _ := guid.NewV4() + roGUID0, _ := guid.NewV4() + roGUID1, _ := guid.NewV4() + roGUID2, _ := guid.NewV4() + + c.layers = &scsiLayers{ + layersCombined: true, + rootfsPath: "/rootfs", + scratch: scsiReservation{id: scratchGUID, guestPath: "/dev/scratch"}, + roLayers: []scsiReservation{ + {id: roGUID0, guestPath: "/dev/ro0"}, + {id: roGUID1, guestPath: "/dev/ro1"}, + {id: roGUID2, guestPath: "/dev/ro2"}, + }, + } + + guestCtrl.EXPECT(). + RemoveLCOWCombinedLayers(gomock.Any(), gomock.Any()). + Return(nil) + + scsiCtrl.EXPECT().UnmapFromGuest(gomock.Any(), scratchGUID).Return(nil) + scsiCtrl.EXPECT().UnmapFromGuest(gomock.Any(), roGUID0).Return(nil) + scsiCtrl.EXPECT().UnmapFromGuest(gomock.Any(), roGUID1).Return(nil) + scsiCtrl.EXPECT().UnmapFromGuest(gomock.Any(), roGUID2).Return(nil) + + c.releaseResources(t.Context()) + + if c.layers != nil { + t.Error("layers should be nil after releaseResources") + } +} + +// TestReleaseResources_NoResources verifies that releaseResources is a no-op +// when no resources were allocated. 
+func TestReleaseResources_NoResources(t *testing.T) { + t.Parallel() + c, _, _, _, _ := newContainerTestController(t) + + // No mock calls expected. + c.releaseResources(t.Context()) +} + +// --- closeContainer --- + +// TestCloseContainer_IdempotentViaSyncOnce verifies that closeContainer +// executes teardown exactly once even when called multiple times. +func TestCloseContainer_IdempotentViaSyncOnce(t *testing.T) { + t.Parallel() + c, _, _, _, _ := newContainerTestController(t) + + // No container, no layers — closeContainer should just close terminatedCh. + c.closeContainer(t.Context()) + + // Verify terminatedCh is closed. + select { + case <-c.terminatedCh: + default: + t.Fatal("terminatedCh should be closed after closeContainer") + } + + // Second call should be a no-op (no panic from double-close). + c.closeContainer(t.Context()) +} + +// TestCloseContainer_NilContainer verifies that closeContainer succeeds +// without panicking when the container handle is nil. +func TestCloseContainer_NilContainer(t *testing.T) { + t.Parallel() + c, _, _, _, _ := newContainerTestController(t) + c.container = nil + + c.closeContainer(t.Context()) + + select { + case <-c.terminatedCh: + default: + t.Fatal("terminatedCh should be closed after closeContainer") + } +} + +// --- Start --- + +// TestStart_WrongState verifies that Start rejects calls outside StateCreated. +func TestStart_WrongState(t *testing.T) { + t.Parallel() + invalidStates := []State{StateNotCreated, StateRunning, StateStopped, StateInvalid} + + for _, state := range invalidStates { + t.Run(state.String(), func(t *testing.T) { + t.Parallel() + c, _, _, _, _ := newContainerTestController(t) + c.state = state + + _, err := c.Start(t.Context(), nil) + if !errors.Is(err, errdefs.ErrFailedPrecondition) { + t.Errorf("Start() error = %v, want ErrFailedPrecondition", err) + } + }) + } +} + +// --- Create --- + +// TestCreate_WrongState verifies that Create rejects calls outside StateNotCreated. 
+func TestCreate_WrongState(t *testing.T) { + t.Parallel() + invalidStates := []State{StateCreated, StateRunning, StateStopped, StateInvalid} + + for _, state := range invalidStates { + t.Run(state.String(), func(t *testing.T) { + t.Parallel() + c, _, _, _, _ := newContainerTestController(t) + c.state = state + + err := c.Create(t.Context(), nil, nil, nil) + if !errors.Is(err, errdefs.ErrFailedPrecondition) { + t.Errorf("Create() error = %v, want ErrFailedPrecondition", err) + } + }) + } +} + +// --- Pids --- + +// TestPids_WrongState verifies that Pids rejects calls outside StateRunning. +func TestPids_WrongState(t *testing.T) { + t.Parallel() + invalidStates := []State{StateNotCreated, StateCreated, StateStopped, StateInvalid} + + for _, state := range invalidStates { + t.Run(state.String(), func(t *testing.T) { + t.Parallel() + c, _, _, _, _ := newContainerTestController(t) + c.state = state + + _, err := c.Pids(t.Context()) + if !errors.Is(err, errdefs.ErrFailedPrecondition) { + t.Errorf("Pids() error = %v, want ErrFailedPrecondition", err) + } + }) + } +} + +// --- Stats --- + +// TestStats_WrongState verifies that Stats rejects calls outside StateRunning. +func TestStats_WrongState(t *testing.T) { + t.Parallel() + invalidStates := []State{StateNotCreated, StateCreated, StateStopped, StateInvalid} + + for _, state := range invalidStates { + t.Run(state.String(), func(t *testing.T) { + t.Parallel() + c, _, _, _, _ := newContainerTestController(t) + c.state = state + + _, err := c.Stats(t.Context()) + if !errors.Is(err, errdefs.ErrFailedPrecondition) { + t.Errorf("Stats() error = %v, want ErrFailedPrecondition", err) + } + }) + } +} + +// --- KillProcess (additional state and flow tests) --- + +// TestKillProcess_AllowedInCreatedState verifies that KillProcess does not +// reject containers in StateCreated. 
The downstream error from the process +// controller (which is in StateNotCreated) is expected but should not be +// confused with a container-level state rejection. +func TestKillProcess_AllowedInCreatedState(t *testing.T) { + t.Parallel() + c, _, _, _, guestCtrl := newContainerTestController(t) + c.state = StateCreated + + // Add a process controller in its initial (NotCreated) state. + c.processes[""] = process.New(testContainerID, "", nil, 0) + + guestCtrl.EXPECT(). + Capabilities(). + Return(&gcs.LCOWGuestDefinedCapabilities{}) + + // SIGTERM (15) with no signal support returns nil options. + err := c.KillProcess(t.Context(), "", 15, false) + // An error from the process controller is expected (process not started), + // but the container-level state check should not fire. + if err != nil && strings.Contains(err.Error(), "cannot kill") { + t.Errorf("KillProcess should not reject StateCreated containers, got: %v", err) + } +} + +// TestKillProcess_AllowedInStoppedState verifies that KillProcess does not +// reject containers in StateStopped. +func TestKillProcess_AllowedInStoppedState(t *testing.T) { + t.Parallel() + c, _, _, _, guestCtrl := newContainerTestController(t) + c.state = StateStopped + + c.processes[""] = process.New(testContainerID, "", nil, 0) + + guestCtrl.EXPECT(). + Capabilities(). + Return(&gcs.LCOWGuestDefinedCapabilities{}) + + err := c.KillProcess(t.Context(), "", 15, false) + // Container state check should pass; any error should come from the process. 
+ if err != nil && strings.Contains(err.Error(), "cannot kill") { + t.Errorf("KillProcess should not reject StateStopped containers, got: %v", err) + } +} diff --git a/internal/controller/linuxcontainer/devices.go b/internal/controller/linuxcontainer/devices.go new file mode 100644 index 0000000000..e3ee248860 --- /dev/null +++ b/internal/controller/linuxcontainer/devices.go @@ -0,0 +1,56 @@ +//go:build windows && lcow + +package linuxcontainer + +import ( + "context" + "fmt" + + "github.com/Microsoft/hcsshim/internal/controller/device/vpci" + "github.com/Microsoft/hcsshim/internal/log" + "github.com/Microsoft/hcsshim/internal/logfields" + + "github.com/opencontainers/runtime-spec/specs-go" + "github.com/sirupsen/logrus" +) + +// allocateDevices reserves and maps vPCI devices for the container. +func (c *Controller) allocateDevices(ctx context.Context, spec *specs.Spec) error { + for idx := range spec.Windows.Devices { + device := &spec.Windows.Devices[idx] + + if !vpci.IsValidDeviceType(device.IDType) { + return fmt.Errorf("reserve device %s: unsupported type %s", device.ID, device.IDType) + } + + // Parse the device path into a PCI ID and optional virtual function index. + pciID, virtualFunctionIndex := vpci.GetDeviceInfoFromPath(device.ID) + + // Reserve the device on the host. + vmBusGUID, err := c.vpci.Reserve(ctx, vpci.Device{ + DeviceInstanceID: pciID, + VirtualFunctionIndex: virtualFunctionIndex, + }) + if err != nil { + return fmt.Errorf("reserve device %s: %w", device.ID, err) + } + + // Map the device into the VM. + if err = c.vpci.AddToVM(ctx, vmBusGUID); err != nil { + return fmt.Errorf("add device %s to vm: %w", device.ID, err) + } + + log.G(ctx).WithFields(logrus.Fields{ + logfields.DeviceID: pciID, + logfields.VFIndex: virtualFunctionIndex, + logfields.VMBusGUID: vmBusGUID.String(), + }).Trace("reserved and mapped vPCI device") + + // Rewrite the spec entry so GCS references the VMBus GUID. 
+ device.ID = vmBusGUID.String() + c.devices = append(c.devices, vmBusGUID) + } + + log.G(ctx).Debug("all vPCI devices allocated successfully") + return nil +} diff --git a/internal/controller/linuxcontainer/devices_test.go b/internal/controller/linuxcontainer/devices_test.go new file mode 100644 index 0000000000..1a90c1d76a --- /dev/null +++ b/internal/controller/linuxcontainer/devices_test.go @@ -0,0 +1,340 @@ +//go:build windows && lcow + +package linuxcontainer + +import ( + "errors" + "testing" + + "github.com/Microsoft/hcsshim/internal/controller/device/vpci" + "github.com/Microsoft/hcsshim/internal/controller/linuxcontainer/mocks" + + "github.com/Microsoft/go-winio/pkg/guid" + "github.com/opencontainers/runtime-spec/specs-go" + "go.uber.org/mock/gomock" +) + +// newTestControllerAndSpec creates a Controller wired to a fresh vPCIController +// mock alongside a minimal OCI spec populated with the provided Windows devices. +func newTestControllerAndSpec(t *testing.T, devices ...specs.WindowsDevice) (*Controller, *specs.Spec, *mocks.MockvPCIController) { + t.Helper() + ctrl := gomock.NewController(t) + vpciCtrl := mocks.NewMockvPCIController(ctrl) + return &Controller{vpci: vpciCtrl}, &specs.Spec{ + Windows: &specs.Windows{ + Devices: devices, + }, + }, vpciCtrl +} + +var ( + errReserve = errors.New("reserve failed") + errAddToVM = errors.New("add to vm failed") +) + +// TestAllocateDevices_NoDevices verifies that allocateDevices succeeds without +// any vPCI calls when the spec contains no Windows devices. 
+func TestAllocateDevices_NoDevices(t *testing.T) { + t.Parallel() + c, spec, _ := newTestControllerAndSpec(t) + + if err := c.allocateDevices(t.Context(), spec); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(c.devices) != 0 { + t.Errorf("expected 0 tracked devices, got %d", len(c.devices)) + } +} + +// TestAllocateDevices_InvalidDeviceType verifies that allocateDevices returns an +// error for unsupported device types, regardless of position in the device list. +func TestAllocateDevices_InvalidDeviceType(t *testing.T) { + t.Parallel() + tests := []struct { + name string + devices []specs.WindowsDevice + }{ + { + name: "single-invalid", + devices: []specs.WindowsDevice{ + {ID: "PCI\\VEN_1234&DEV_5678\\0", IDType: "unsupported-type"}, + }, + }, + { + name: "invalid-before-valid", + devices: []specs.WindowsDevice{ + {ID: "PCI\\VEN_AAAA&DEV_1111\\0", IDType: "bad-type"}, + {ID: "PCI\\VEN_BBBB&DEV_2222\\0", IDType: vpci.DeviceIDType}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + c, spec, _ := newTestControllerAndSpec(t, tt.devices...) + + // No Reserve or AddToVM calls expected. + + if err := c.allocateDevices(t.Context(), spec); err == nil { + t.Fatal("expected error for unsupported device type") + } + if len(c.devices) != 0 { + t.Errorf("expected 0 tracked devices, got %d", len(c.devices)) + } + }) + } +} + +// TestAllocateDevices_SingleDevice verifies the Reserve → AddToVM flow for each +// supported device type, including VF index parsing and spec ID rewrite. 
+func TestAllocateDevices_SingleDevice(t *testing.T) { + t.Parallel() + tests := []struct { + name string + deviceID string + idType string + expectPCI string + expectVF uint16 + }{ + { + name: "vpci-instance-id", + deviceID: "PCI\\VEN_1234&DEV_5678\\0", + idType: vpci.DeviceIDType, + expectPCI: "PCI\\VEN_1234&DEV_5678\\0", + expectVF: 0, + }, + { + name: "vpci-legacy-with-vf-index", + deviceID: "PCI\\VEN_1234&DEV_5678\\0/3", + idType: vpci.DeviceIDTypeLegacy, + expectPCI: "PCI\\VEN_1234&DEV_5678\\0", + expectVF: 3, + }, + { + name: "gpu", + deviceID: "PCI\\VEN_ABCD&DEV_9876\\0", + idType: vpci.GpuDeviceIDType, + expectPCI: "PCI\\VEN_ABCD&DEV_9876\\0", + expectVF: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + c, spec, vpciCtrl := newTestControllerAndSpec(t, specs.WindowsDevice{ + ID: tt.deviceID, + IDType: tt.idType, + }) + + testGUID, _ := guid.NewV4() + + vpciCtrl.EXPECT(). + Reserve(gomock.Any(), vpci.Device{ + DeviceInstanceID: tt.expectPCI, + VirtualFunctionIndex: tt.expectVF, + }). + Return(testGUID, nil) + vpciCtrl.EXPECT(). + AddToVM(gomock.Any(), testGUID). + Return(nil) + + if err := c.allocateDevices(t.Context(), spec); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify the spec entry was rewritten to the VMBus GUID. + if got := spec.Windows.Devices[0].ID; got != testGUID.String() { + t.Errorf("spec device ID = %q, want %q", got, testGUID.String()) + } + + // Verify the GUID was tracked. + if len(c.devices) != 1 || c.devices[0] != testGUID { + t.Errorf("tracked devices = %v, want [%v]", c.devices, testGUID) + } + }) + } +} + +// TestAllocateDevices_SingleDeviceFailure verifies that Reserve and AddToVM +// failures are propagated and no device is tracked. 
+func TestAllocateDevices_SingleDeviceFailure(t *testing.T) { + t.Parallel() + tests := []struct { + name string + reserveErr error + addToVMErr error + wantWrapped error + }{ + { + name: "reserve-fails", + reserveErr: errReserve, + wantWrapped: errReserve, + }, + { + name: "add-to-vm-fails", + addToVMErr: errAddToVM, + wantWrapped: errAddToVM, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + c, spec, vpciCtrl := newTestControllerAndSpec(t, specs.WindowsDevice{ + ID: "PCI\\VEN_1234&DEV_5678\\0", + IDType: vpci.DeviceIDType, + }) + + testGUID, _ := guid.NewV4() + vpciCtrl.EXPECT(). + Reserve(gomock.Any(), gomock.Any()). + Return(testGUID, tt.reserveErr) + + // AddToVM is only called when Reserve succeeds. + if tt.reserveErr == nil { + vpciCtrl.EXPECT(). + AddToVM(gomock.Any(), testGUID). + Return(tt.addToVMErr) + } + + err := c.allocateDevices(t.Context(), spec) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, tt.wantWrapped) { + t.Errorf("error = %v, want wrapping %v", err, tt.wantWrapped) + } + if len(c.devices) != 0 { + t.Errorf("expected 0 tracked devices, got %d", len(c.devices)) + } + }) + } +} + +// TestAllocateDevices_MultipleDevices verifies that allocateDevices correctly +// handles multiple devices, reserving and adding each one independently. +func TestAllocateDevices_MultipleDevices(t *testing.T) { + t.Parallel() + c, spec, vpciCtrl := newTestControllerAndSpec(t, + specs.WindowsDevice{ID: "PCI\\VEN_AAAA&DEV_1111\\0", IDType: vpci.DeviceIDType}, + specs.WindowsDevice{ID: "PCI\\VEN_BBBB&DEV_2222\\0", IDType: vpci.GpuDeviceIDType}, + ) + + guidA, _ := guid.NewV4() + guidB, _ := guid.NewV4() + + vpciCtrl.EXPECT(). + Reserve(gomock.Any(), vpci.Device{ + DeviceInstanceID: "PCI\\VEN_AAAA&DEV_1111\\0", + VirtualFunctionIndex: 0, + }). + Return(guidA, nil) + vpciCtrl.EXPECT(). + AddToVM(gomock.Any(), guidA). + Return(nil) + + vpciCtrl.EXPECT(). 
+ Reserve(gomock.Any(), vpci.Device{ + DeviceInstanceID: "PCI\\VEN_BBBB&DEV_2222\\0", + VirtualFunctionIndex: 0, + }). + Return(guidB, nil) + vpciCtrl.EXPECT(). + AddToVM(gomock.Any(), guidB). + Return(nil) + + if err := c.allocateDevices(t.Context(), spec); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(c.devices) != 2 { + t.Fatalf("expected 2 tracked devices, got %d", len(c.devices)) + } + if c.devices[0] != guidA || c.devices[1] != guidB { + t.Errorf("tracked GUIDs = %v, %v; want %v, %v", c.devices[0], c.devices[1], guidA, guidB) + } + if spec.Windows.Devices[0].ID != guidA.String() { + t.Errorf("first device ID = %q, want %q", spec.Windows.Devices[0].ID, guidA.String()) + } + if spec.Windows.Devices[1].ID != guidB.String() { + t.Errorf("second device ID = %q, want %q", spec.Windows.Devices[1].ID, guidB.String()) + } +} + +// TestAllocateDevices_MultipleDevicesPartialFailure verifies that when the +// second device fails (at Reserve or AddToVM), the first device is tracked +// but the overall call returns the expected error. +func TestAllocateDevices_MultipleDevicesPartialFailure(t *testing.T) { + t.Parallel() + tests := []struct { + name string + reserveErr error + addToVMErr error + wantWrapped error + }{ + { + name: "second-reserve-fails", + reserveErr: errReserve, + wantWrapped: errReserve, + }, + { + name: "second-add-to-vm-fails", + addToVMErr: errAddToVM, + wantWrapped: errAddToVM, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + c, spec, vpciCtrl := newTestControllerAndSpec(t, + specs.WindowsDevice{ID: "PCI\\VEN_AAAA&DEV_1111\\0", IDType: vpci.DeviceIDType}, + specs.WindowsDevice{ID: "PCI\\VEN_BBBB&DEV_2222\\0", IDType: vpci.DeviceIDType}, + ) + + guidA, _ := guid.NewV4() + guidB, _ := guid.NewV4() + + // First device always succeeds. + vpciCtrl.EXPECT(). + Reserve(gomock.Any(), vpci.Device{ + DeviceInstanceID: "PCI\\VEN_AAAA&DEV_1111\\0", + VirtualFunctionIndex: 0, + }). 
+ Return(guidA, nil) + vpciCtrl.EXPECT(). + AddToVM(gomock.Any(), guidA). + Return(nil) + + // Second device fails at the configured step. + vpciCtrl.EXPECT(). + Reserve(gomock.Any(), vpci.Device{ + DeviceInstanceID: "PCI\\VEN_BBBB&DEV_2222\\0", + VirtualFunctionIndex: 0, + }). + Return(guidB, tt.reserveErr) + + // AddToVM for the second device is only called when its Reserve succeeds. + if tt.reserveErr == nil { + vpciCtrl.EXPECT(). + AddToVM(gomock.Any(), guidB). + Return(tt.addToVMErr) + } + + err := c.allocateDevices(t.Context(), spec) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, tt.wantWrapped) { + t.Errorf("error = %v, want wrapping %v", err, tt.wantWrapped) + } + + // First device was already allocated before the second failed. + if len(c.devices) != 1 { + t.Errorf("expected 1 tracked device after partial failure, got %d", len(c.devices)) + } + }) + } +} diff --git a/internal/controller/linuxcontainer/doc.go b/internal/controller/linuxcontainer/doc.go new file mode 100644 index 0000000000..fc36d43939 --- /dev/null +++ b/internal/controller/linuxcontainer/doc.go @@ -0,0 +1,67 @@ +//go:build windows && lcow + +// Package linuxcontainer provides a controller for managing the full lifecycle of +// a single LCOW (Linux Containers on Windows) container running inside a Utility VM (UVM). +// +// It coordinates host-side resource allocation (SCSI layers, Plan9 shares, vPCI devices), +// guest-side container creation via the GCS (Guest Compute Service), and process management. +// +// # Lifecycle +// +// A container follows the state machine below. 
//
//	            ┌──────────────────┐
//	            │ StateNotCreated  │
//	            └───┬──────────┬───┘
//	      Create ok │          │ Create fails
//	                ▼          ▼
//	        ┌──────────────┐ ┌──────────────┐
//	        │ StateCreated │ │ StateInvalid │
//	        └───┬──────┬───┘ └──────────────┘
//	   Start ok │      │ Start fails
//	            ▼      ▼
//	┌──────────────┐ ┌──────────────┐
//	│ StateRunning │ │ StateInvalid │
//	└──────┬───────┘ └──────────────┘
//	       │ init process exits
//	       ▼
//	┌──────────────┐
//	│ StateStopped │
//	└──────────────┘
//
// State descriptions:
//
//   - [StateNotCreated]: initial state; no resources have been allocated.
//   - [StateCreated]: after [Controller.Create] succeeds; host-side resources are
//     allocated, the GCS container exists, and the init process is ready but not started.
//   - [StateRunning]: after [Controller.Start] succeeds; the init process is executing.
//     Exec processes may be added via [Controller.NewProcess].
//   - [StateStopped]: terminal state reached when the init process exits;
//     all host-side resources have been released.
//   - [StateInvalid]: entered when [Controller.Create] or [Controller.Start] fails
//     mid-way; host-side resources are released. If the failure occurred after the
//     GCS container was successfully created, guest-side state may still require
//     cleanup via [Controller.DeleteProcess].
//
// # Resource Allocation
//
// During [Controller.Create], three categories of host-side resources are allocated
// and mapped into the guest:
//
//   - Layers: read-only image layers and the writable scratch layer are attached
//     via SCSI and combined inside the guest to form the container rootfs.
//   - Mounts: OCI spec mounts are dispatched by type — disk-backed mounts go through
//     SCSI, host-directory bind mounts go through Plan9 shares, and guest-internal
//     or unknown types pass through unmodified.
//   - Devices: Windows vPCI devices listed in the OCI spec are reserved on the host,
//     added to the VM, and their spec entries are rewritten to VMBus GUIDs.
//
// All allocated resources are released in reverse order during container teardown.
//
// # Process Management
//
// The init process (exec ID "") is created during [Controller.Create] and started
// during [Controller.Start]. Additional exec processes can be added to a running
// container via [Controller.NewProcess]. When the init process exits, the controller
// tears down all host-side resources and transitions to [StateStopped].
package linuxcontainer
diff --git a/internal/controller/linuxcontainer/document.go b/internal/controller/linuxcontainer/document.go
new file mode 100644
index 0000000000..c5b0e893f7
--- /dev/null
+++ b/internal/controller/linuxcontainer/document.go
@@ -0,0 +1,152 @@
//go:build windows && lcow

package linuxcontainer

import (
	"context"
	"encoding/json"
	"fmt"

	containerdtypes "github.com/containerd/containerd/api/types"
	"github.com/opencontainers/runtime-spec/specs-go"

	"github.com/Microsoft/hcsshim/internal/guestpath"
	hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2"
	"github.com/Microsoft/hcsshim/internal/oci"
	"github.com/Microsoft/hcsshim/internal/ospath"
	"github.com/Microsoft/hcsshim/internal/schemaversion"
	"github.com/Microsoft/hcsshim/pkg/annotations"
)

// vmHostedContainerSettingsV2 defines the portion of the container
// configuration sent via a V2 GCS call for LCOW containers.
//
// SchemaVersion and ScratchDirPath carry no JSON tag, so encoding/json emits
// them under their Go field names.
type vmHostedContainerSettingsV2 struct {
	// SchemaVersion is the HCS schema version attached to the document.
	SchemaVersion *hcsschema.Version
	// OCIBundlePath is the bundle path for the container; serialized as
	// "OciBundlePath" and omitted when empty.
	OCIBundlePath string `json:"OciBundlePath,omitempty"`
	// OCISpecification is the OCI spec handed to the guest; serialized as
	// "OciSpecification" and omitted when nil.
	OCISpecification *specs.Spec `json:"OciSpecification,omitempty"`
	// ScratchDirPath is the path inside the UVM where the container scratch
	// directory is present. Usually the mount path of the scratch VHD, but
	// with scratch sharing it becomes a sub-directory under the UVM scratch.
	ScratchDirPath string
}

// generateContainerDocument allocates all host-side resources (layers, mounts, devices)
// and returns the GCS container configuration document.
+func (c *Controller) generateContainerDocument( + ctx context.Context, + spec *specs.Spec, + rootfs []*containerdtypes.Mount, + isScratchEncryptionEnabled bool, +) (*vmHostedContainerSettingsV2, error) { + if spec.Linux == nil { + return nil, fmt.Errorf("linux section must be present for lcow container") + } + + // If windows section is not present, add an empty section + // to avoid nil dereference in downstream code. + if spec.Windows == nil { + spec.Windows = &specs.Windows{} + } + + // Allocate host-side resources: layers, mounts, and devices. + + if err := c.allocateLayers(ctx, spec.Windows.LayerFolders, rootfs, isScratchEncryptionEnabled); err != nil { + return nil, fmt.Errorf("allocate layers: %w", err) + } + + if err := c.allocateMounts(ctx, spec); err != nil { + return nil, fmt.Errorf("allocate mounts: %w", err) + } + + if err := c.allocateDevices(ctx, spec); err != nil { + return nil, fmt.Errorf("allocate devices: %w", err) + } + + // Set the rootfs path for the container within guest. + if spec.Root == nil { + spec.Root = &specs.Root{} + } + spec.Root.Path = c.layers.rootfsPath + + // Build a sanitized deep copy of the spec for the guest. + linuxSpec, err := sanitizeSpec(ctx, spec) + if err != nil { + return nil, fmt.Errorf("sanitize spec: %w", err) + } + + return &vmHostedContainerSettingsV2{ + SchemaVersion: schemaversion.SchemaV21(), + OCIBundlePath: ospath.Join("linux", guestpath.LCOWV2RootPrefixInVM, c.gcsPodID, c.gcsContainerID), + OCISpecification: linuxSpec, + ScratchDirPath: c.layers.scratch.guestPath, + }, nil +} + +// sanitizeSpec deep-copies the OCI spec and strips fields unsupported by the GCS. +func sanitizeSpec(ctx context.Context, origSpec *specs.Spec) (*specs.Spec, error) { + // Deep copy via JSON round-trip so mutations do not affect the caller. 
+ raw, err := json.Marshal(origSpec) + if err != nil { + return nil, fmt.Errorf("marshal spec: %w", err) + } + spec := &specs.Spec{} + if err = json.Unmarshal(raw, spec); err != nil { + return nil, fmt.Errorf("unmarshal spec: %w", err) + } + + // Preserve only the network namespace and assigned devices from the Windows section. + spec.Windows = extractWindowsFields(origSpec) + + // Hooks are executed on the host, not inside the guest. + spec.Hooks = nil + + // Apply safe CPU defaults when values are explicitly zeroed. + if spec.Linux.Resources != nil && spec.Linux.Resources.CPU != nil { + cpu := spec.Linux.Resources.CPU + if cpu.Period != nil && *cpu.Period == 0 { + *cpu.Period = 100000 + } + if cpu.Quota != nil && *cpu.Quota == 0 { + *cpu.Quota = -1 + } + } + + // Clear resource types the GCS manages on its own. + spec.Linux.CgroupsPath = "" + if spec.Linux.Resources != nil { + spec.Linux.Resources.Devices = nil + spec.Linux.Resources.Pids = nil + spec.Linux.Resources.BlockIO = nil + spec.Linux.Resources.HugepageLimits = nil + spec.Linux.Resources.Network = nil + } + + // Disable seccomp for privileged containers. + if oci.ParseAnnotationsBool(ctx, spec.Annotations, annotations.LCOWPrivileged, false) { + spec.Linux.Seccomp = nil + } + + return spec, nil +} + +// extractWindowsFields keeps only the network namespace and assigned devices. 
+func extractWindowsFields(origSpec *specs.Spec) *specs.Windows { + var win *specs.Windows + + if origSpec.Windows.Network != nil && origSpec.Windows.Network.NetworkNamespace != "" { + win = &specs.Windows{ + Network: &specs.WindowsNetwork{ + NetworkNamespace: origSpec.Windows.Network.NetworkNamespace, + }, + } + } + + if len(origSpec.Windows.Devices) > 0 { + if win == nil { + win = &specs.Windows{} + } + win.Devices = origSpec.Windows.Devices + } + + return win +} diff --git a/internal/controller/linuxcontainer/document_test.go b/internal/controller/linuxcontainer/document_test.go new file mode 100644 index 0000000000..5b1b1c628f --- /dev/null +++ b/internal/controller/linuxcontainer/document_test.go @@ -0,0 +1,302 @@ +//go:build windows && lcow + +package linuxcontainer + +import ( + "testing" + + "github.com/opencontainers/runtime-spec/specs-go" + + "github.com/Microsoft/hcsshim/pkg/annotations" +) + +// TestSanitizeSpec_CPUDefaults verifies that explicitly zeroed CPU period and +// quota are replaced with safe defaults, while non-zero values pass through. 
+func TestSanitizeSpec_CPUDefaults(t *testing.T) { + t.Parallel() + + period := func(v uint64) *uint64 { return &v } + quota := func(v int64) *int64 { return &v } + + tests := []struct { + name string + period *uint64 + quota *int64 + wantPeriod uint64 + wantQuota int64 + }{ + { + name: "zero-period-and-quota", + period: period(0), + quota: quota(0), + wantPeriod: 100000, + wantQuota: -1, + }, + { + name: "non-zero-values-unchanged", + period: period(50000), + quota: quota(25000), + wantPeriod: 50000, + wantQuota: 25000, + }, + { + name: "nil-period-and-quota", + period: nil, + quota: nil, + wantPeriod: 0, // unused; checked via nil + wantQuota: 0, // unused; checked via nil + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + spec := &specs.Spec{Linux: &specs.Linux{}, Windows: &specs.Windows{}} + spec.Linux.Resources = &specs.LinuxResources{ + CPU: &specs.LinuxCPU{Period: tt.period, Quota: tt.quota}, + } + + got, err := sanitizeSpec(t.Context(), spec) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + cpu := got.Linux.Resources.CPU + if tt.period == nil { + if cpu.Period != nil { + t.Errorf("Period = %v, want nil", *cpu.Period) + } + } else if *cpu.Period != tt.wantPeriod { + t.Errorf("Period = %d, want %d", *cpu.Period, tt.wantPeriod) + } + + if tt.quota == nil { + if cpu.Quota != nil { + t.Errorf("Quota = %v, want nil", *cpu.Quota) + } + } else if *cpu.Quota != tt.wantQuota { + t.Errorf("Quota = %d, want %d", *cpu.Quota, tt.wantQuota) + } + }) + } +} + +// TestSanitizeSpec_ClearsGCSManagedResources verifies that cgroups path and +// resource types managed by the GCS are removed from the sanitized spec. 
+func TestSanitizeSpec_ClearsGCSManagedResources(t *testing.T) { + t.Parallel() + classID := uint32(1) + cpuShares := uint64(512) + spec := &specs.Spec{Linux: &specs.Linux{}, Windows: &specs.Windows{}} + spec.Linux.CgroupsPath = "/sys/fs/cgroup/test" + spec.Linux.Resources = &specs.LinuxResources{ + Devices: []specs.LinuxDeviceCgroup{{Allow: true}}, + Pids: &specs.LinuxPids{Limit: 100}, + BlockIO: &specs.LinuxBlockIO{}, + HugepageLimits: []specs.LinuxHugepageLimit{{Pagesize: "2MB", Limit: 1024}}, + Network: &specs.LinuxNetwork{ClassID: &classID}, + // CPU should be preserved. + CPU: &specs.LinuxCPU{Shares: &cpuShares}, + } + + got, err := sanitizeSpec(t.Context(), spec) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if got.Linux.CgroupsPath != "" { + t.Errorf("CgroupsPath = %q, want empty", got.Linux.CgroupsPath) + } + res := got.Linux.Resources + if res.Devices != nil { + t.Error("Devices should be nil") + } + if res.Pids != nil { + t.Error("Pids should be nil") + } + if res.BlockIO != nil { + t.Error("BlockIO should be nil") + } + if res.HugepageLimits != nil { + t.Error("HugepageLimits should be nil") + } + if res.Network != nil { + t.Error("Network should be nil") + } + // CPU must survive the clear. + if res.CPU == nil || res.CPU.Shares == nil || *res.CPU.Shares != 512 { + t.Error("CPU.Shares should be preserved as 512") + } +} + +// TestSanitizeSpec_NilsHooksAndSeccomp verifies that hooks are always removed +// and seccomp is removed only for privileged containers. 
+func TestSanitizeSpec_NilsHooksAndSeccomp(t *testing.T) { + t.Parallel() + tests := []struct { + name string + privileged bool + wantSeccomp bool + }{ + {name: "non-privileged", privileged: false, wantSeccomp: true}, + {name: "privileged", privileged: true, wantSeccomp: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + spec := &specs.Spec{Linux: &specs.Linux{}, Windows: &specs.Windows{}} + spec.Hooks = &specs.Hooks{ + CreateRuntime: []specs.Hook{{Path: "/bin/hook"}}, + } + spec.Linux.Seccomp = &specs.LinuxSeccomp{ + DefaultAction: specs.ActErrno, + } + if tt.privileged { + spec.Annotations = map[string]string{ + annotations.LCOWPrivileged: "true", + } + } + + got, err := sanitizeSpec(t.Context(), spec) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if got.Hooks != nil { + t.Error("Hooks should be nil") + } + hasSeccomp := got.Linux.Seccomp != nil + if hasSeccomp != tt.wantSeccomp { + t.Errorf("Seccomp present = %v, want %v", hasSeccomp, tt.wantSeccomp) + } + }) + } +} + +// TestSanitizeSpec_DeepCopy confirms mutations to the sanitized spec do not +// affect the original. +func TestSanitizeSpec_DeepCopy(t *testing.T) { + t.Parallel() + orig := &specs.Spec{Linux: &specs.Linux{}, Windows: &specs.Windows{}} + orig.Hostname = "original" + + got, err := sanitizeSpec(t.Context(), orig) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + got.Hostname = "mutated" + if orig.Hostname != "original" { + t.Error("mutation of sanitized spec leaked to the original") + } +} + +// TestSanitizeSpec_NilResources verifies sanitizeSpec succeeds when +// Linux.Resources is nil (no CPU/resource rewriting needed). 
+func TestSanitizeSpec_NilResources(t *testing.T) { + t.Parallel() + spec := &specs.Spec{Linux: &specs.Linux{}, Windows: &specs.Windows{}} + spec.Linux.Resources = nil + + got, err := sanitizeSpec(t.Context(), spec) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.Linux.Resources != nil { + t.Error("Resources should remain nil") + } +} + +// TestExtractWindowsFields verifies that only the network namespace and assigned +// devices are preserved from the Windows section. +func TestExtractWindowsFields(t *testing.T) { + t.Parallel() + tests := []struct { + name string + windows *specs.Windows + wantNil bool + wantNetwork bool + wantDevices int + }{ + { + name: "empty-windows", + windows: &specs.Windows{}, + wantNil: true, + }, + { + name: "network-only", + windows: &specs.Windows{ + Network: &specs.WindowsNetwork{NetworkNamespace: "ns1"}, + }, + wantNetwork: true, + }, + { + name: "devices-only", + windows: &specs.Windows{ + Devices: []specs.WindowsDevice{{ID: "dev1"}}, + }, + wantDevices: 1, + }, + { + name: "network-and-devices", + windows: &specs.Windows{ + Network: &specs.WindowsNetwork{NetworkNamespace: "ns2"}, + Devices: []specs.WindowsDevice{{ID: "d1"}, {ID: "d2"}}, + }, + wantNetwork: true, + wantDevices: 2, + }, + { + name: "empty-network-namespace-ignored", + windows: &specs.Windows{ + Network: &specs.WindowsNetwork{NetworkNamespace: ""}, + }, + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + spec := &specs.Spec{Windows: tt.windows} + got := extractWindowsFields(spec) + + if tt.wantNil { + if got != nil { + t.Fatalf("expected nil, got %+v", got) + } + return + } + + if got == nil { + t.Fatal("expected non-nil Windows") + } + + if tt.wantNetwork { + if got.Network == nil || got.Network.NetworkNamespace == "" { + t.Error("expected network namespace to be preserved") + } + } else if got.Network != nil { + t.Error("expected no network") + } + + if len(got.Devices) != 
tt.wantDevices { + t.Errorf("Devices count = %d, want %d", len(got.Devices), tt.wantDevices) + } + }) + } +} + +// TestGenerateContainerDocument_NilLinux verifies that generateContainerDocument +// returns an error when the Linux section is absent. +func TestGenerateContainerDocument_NilLinux(t *testing.T) { + t.Parallel() + c := &Controller{gcsPodID: "pod", gcsContainerID: "ctr"} + spec := &specs.Spec{} // no Linux section + + _, err := c.generateContainerDocument(t.Context(), spec, nil, false) + if err == nil { + t.Fatal("expected error for nil Linux section") + } +} diff --git a/internal/controller/linuxcontainer/layers.go b/internal/controller/linuxcontainer/layers.go new file mode 100644 index 0000000000..a8169334ae --- /dev/null +++ b/internal/controller/linuxcontainer/layers.go @@ -0,0 +1,157 @@ +//go:build windows && lcow + +package linuxcontainer + +import ( + "context" + "fmt" + + "github.com/Microsoft/hcsshim/internal/controller/device/scsi/disk" + scsiMount "github.com/Microsoft/hcsshim/internal/controller/device/scsi/mount" + "github.com/Microsoft/hcsshim/internal/guestpath" + hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" + "github.com/Microsoft/hcsshim/internal/layers" + "github.com/Microsoft/hcsshim/internal/log" + "github.com/Microsoft/hcsshim/internal/ospath" + "github.com/Microsoft/hcsshim/internal/protocol/guestresource" + "github.com/Microsoft/hcsshim/internal/wclayer" + + "github.com/Microsoft/go-winio/pkg/fs" + "github.com/Microsoft/go-winio/pkg/guid" + containerdtypes "github.com/containerd/containerd/api/types" +) + +// scsiReservation pairs a SCSI reservation ID with its resolved guest path. +type scsiReservation struct { + id guid.GUID + guestPath string +} + +// scsiLayers holds the SCSI reservations for all container layers. 
+type scsiLayers struct { + roLayers []scsiReservation + scratch scsiReservation + layersCombined bool + rootfsPath string +} + +// grantVMAccess and resolvePath are converted to vars for unit testing. +var ( + grantVMAccess = wclayer.GrantVmAccess + resolvePath = fs.ResolvePath +) + +// allocateLayers parses, reserves, and maps all LCOW layers into the guest. +func (c *Controller) allocateLayers( + ctx context.Context, + layerFolders []string, + rootfs []*containerdtypes.Mount, + isScratchEncryptionEnabled bool, +) error { + log.G(ctx).Debug("allocating container layers") + + // Parse the rootfs mounts and layer folders into the canonical LCOW layer format. + lcowLayers, err := layers.ParseLCOWLayers(rootfs, layerFolders) + if err != nil { + return fmt.Errorf("parse lcow layers: %w", err) + } + + c.layers = &scsiLayers{} + + // Reserve and map each read-only layer. + for _, roLayer := range lcowLayers.Layers { + // Read-only layers come from the containerd snapshotter with broad read + // permissions (typically via GrantVmGroupAccess), so no per-VM access + // grant is needed here. + + // The layer path may be a symlink; resolve it to a real path before + // handing it to the SCSI reservation. 
+ hostPath, err := resolvePath(roLayer.VHDPath) + if err != nil { + return fmt.Errorf("resolve symlinks for layer %s: %w", roLayer.VHDPath, err) + } + + reservationID, err := c.scsi.Reserve( + ctx, + disk.Config{HostPath: hostPath, ReadOnly: true, Type: disk.TypeVirtualDisk}, + scsiMount.Config{Partition: roLayer.Partition, ReadOnly: true, Options: []string{"ro"}}, + ) + if err != nil { + return fmt.Errorf("reserve scsi slot for layer %s: %w", roLayer.VHDPath, err) + } + + guestPath, err := c.scsi.MapToGuest(ctx, reservationID) + if err != nil { + return fmt.Errorf("map layer %s to guest: %w", roLayer.VHDPath, err) + } + + c.layers.roLayers = append(c.layers.roLayers, scsiReservation{id: reservationID, guestPath: guestPath}) + } + + // Reserve and map the writable scratch layer. + + // The scratch path may be a symlink to a shared sandbox.vhdx from another + // container (e.g. the sandbox container). Resolve it before granting access. + scratchHostPath, err := resolvePath(lcowLayers.ScratchVHDPath) + if err != nil { + return fmt.Errorf("resolve symlinks for scratch %s: %w", lcowLayers.ScratchVHDPath, err) + } + + // Unlike read-only layers, the scratch VHD requires explicit per-VM access. + if err = grantVMAccess(ctx, c.vmID, scratchHostPath); err != nil { + return fmt.Errorf("grant vm access to scratch %s: %w", scratchHostPath, err) + } + + // Encrypted scratch disks use xfs; all others default to ext4. 
+ fileSystem := "ext4" + if isScratchEncryptionEnabled { + fileSystem = "xfs" + } + + scratchID, err := c.scsi.Reserve(ctx, + disk.Config{HostPath: scratchHostPath, ReadOnly: false, Type: disk.TypeVirtualDisk}, + scsiMount.Config{ + Encrypted: isScratchEncryptionEnabled, + EnsureFilesystem: true, + ReadOnly: false, + Filesystem: fileSystem, + }, + ) + if err != nil { + return fmt.Errorf("reserve scsi slot for scratch %s: %w", lcowLayers.ScratchVHDPath, err) + } + + scratchMountPath, err := c.scsi.MapToGuest(ctx, scratchID) + if err != nil { + return fmt.Errorf("map scratch to guest: %w", err) + } + + // When sharing a scratch disk across multiple containers, derive a unique + // sub-path per container to prevent upper/work directory collisions. + c.layers.scratch = scsiReservation{ + id: scratchID, + guestPath: ospath.Join("linux", scratchMountPath, "scratch", c.gcsPodID, c.gcsContainerID), + } + c.layers.rootfsPath = ospath.Join("linux", guestpath.LCOWV2RootPrefixInVM, c.gcsPodID, c.gcsContainerID, guestpath.RootfsPath) + + // Combine the mapped layers as the final step. + hcsLayers := make([]hcsschema.Layer, len(c.layers.roLayers)) + for i, roLayer := range c.layers.roLayers { + hcsLayers[i] = hcsschema.Layer{Path: roLayer.guestPath} + } + + if err = c.guest.AddLCOWCombinedLayers(ctx, guestresource.LCOWCombinedLayers{ + ContainerID: c.gcsContainerID, + ContainerRootPath: c.layers.rootfsPath, + Layers: hcsLayers, + ScratchPath: c.layers.scratch.guestPath, + }); err != nil { + return fmt.Errorf("add combined layers: %w", err) + } + + // Set the layersCombined flag so that we can uncombine them during teardown. 
+ c.layers.layersCombined = true + + log.G(ctx).WithField("layers", log.Format(ctx, c.layers)).Trace("all LCOW layers reserved and mapped") + return nil +} diff --git a/internal/controller/linuxcontainer/layers_test.go b/internal/controller/linuxcontainer/layers_test.go new file mode 100644 index 0000000000..279d104048 --- /dev/null +++ b/internal/controller/linuxcontainer/layers_test.go @@ -0,0 +1,579 @@ +//go:build windows && lcow + +package linuxcontainer + +import ( + "context" + "errors" + "testing" + + "github.com/Microsoft/hcsshim/internal/controller/device/scsi/disk" + scsiMount "github.com/Microsoft/hcsshim/internal/controller/device/scsi/mount" + "github.com/Microsoft/hcsshim/internal/controller/linuxcontainer/mocks" + "github.com/Microsoft/hcsshim/internal/guestpath" + hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" + "github.com/Microsoft/hcsshim/internal/ospath" + "github.com/Microsoft/hcsshim/internal/protocol/guestresource" + + "github.com/Microsoft/go-winio/pkg/guid" + containerdtypes "github.com/containerd/containerd/api/types" + "go.uber.org/mock/gomock" +) + +var ( + errResolvePath = errors.New("resolve path failed") + errGrantVMAccess = errors.New("grant vm access failed") + errScsiReserve = errors.New("scsi reserve failed") + errMapToGuest = errors.New("map to guest failed") + errCombineLayers = errors.New("combine layers failed") +) + +// newLayersTestController creates a Controller wired to mock SCSI and guest +// controllers, with stubbed resolvePath and grantVMAccess functions. 
+func newLayersTestController(t *testing.T) ( + *Controller, + *mocks.MockscsiController, + *mocks.Mockguest, +) { + t.Helper() + ctrl := gomock.NewController(t) + scsiCtrl := mocks.NewMockscsiController(ctrl) + guestCtrl := mocks.NewMockguest(ctrl) + c := &Controller{ + vmID: "test-vm", + gcsPodID: "test-pod", + gcsContainerID: "test-ctr", + scsi: scsiCtrl, + guest: guestCtrl, + } + return c, scsiCtrl, guestCtrl +} + +// stubResolvePath replaces the package-level resolvePath with an identity +// function and restores the original when the test completes. +func stubResolvePath(t *testing.T) { + t.Helper() + orig := resolvePath + resolvePath = func(path string) (string, error) { return path, nil } + t.Cleanup(func() { resolvePath = orig }) +} + +// stubGrantVMAccess replaces the package-level grantVMAccess with a no-op and +// restores the original when the test completes. +func stubGrantVMAccess(t *testing.T) { + t.Helper() + orig := grantVMAccess + grantVMAccess = func(_ context.Context, _, _ string) error { return nil } + t.Cleanup(func() { grantVMAccess = orig }) +} + +// legacyLayerFolders returns layer folders in the legacy containerd format: +// [parent0, parent1, ..., scratch]. +func legacyLayerFolders(parentPaths []string, scratchDir string) []string { + folders := make([]string, 0, len(parentPaths)+1) + folders = append(folders, parentPaths...) + folders = append(folders, scratchDir) + return folders +} + +// TestAllocateLayers_SingleReadOnlyLayer verifies the full Reserve → MapToGuest +// → CombineLayers flow for a container with one read-only layer and a scratch. +func TestAllocateLayers_SingleReadOnlyLayer(t *testing.T) { + t.Parallel() + stubResolvePath(t) + stubGrantVMAccess(t) + + c, scsiCtrl, guestCtrl := newLayersTestController(t) + layerFolders := legacyLayerFolders([]string{`C:\layers\base`}, `C:\scratch`) + + roGUID, _ := guid.NewV4() + scratchGUID, _ := guid.NewV4() + + // Read-only layer: Reserve → MapToGuest. + scsiCtrl.EXPECT(). 
+ Reserve(gomock.Any(), disk.Config{HostPath: `C:\layers\base\layer.vhd`, ReadOnly: true, Type: disk.TypeVirtualDisk}, scsiMount.Config{Partition: 0, ReadOnly: true, Options: []string{"ro"}}). + Return(roGUID, nil) + scsiCtrl.EXPECT(). + MapToGuest(gomock.Any(), roGUID). + Return("/dev/sda", nil) + + // Scratch layer: Reserve → MapToGuest. + scsiCtrl.EXPECT(). + Reserve(gomock.Any(), disk.Config{HostPath: `C:\scratch\sandbox.vhdx`, ReadOnly: false, Type: disk.TypeVirtualDisk}, scsiMount.Config{EnsureFilesystem: true, Filesystem: "ext4"}). + Return(scratchGUID, nil) + scsiCtrl.EXPECT(). + MapToGuest(gomock.Any(), scratchGUID). + Return("/dev/sdb", nil) + + // Combine layers. + expectedScratchPath := ospath.Join("linux", "/dev/sdb", "scratch", c.gcsPodID, c.gcsContainerID) + expectedRootfsPath := ospath.Join("linux", guestpath.LCOWV2RootPrefixInVM, c.gcsPodID, c.gcsContainerID, guestpath.RootfsPath) + + guestCtrl.EXPECT(). + AddLCOWCombinedLayers(gomock.Any(), guestresource.LCOWCombinedLayers{ + ContainerID: c.gcsContainerID, + ContainerRootPath: expectedRootfsPath, + Layers: []hcsschema.Layer{{Path: "/dev/sda"}}, + ScratchPath: expectedScratchPath, + }). 
+ Return(nil) + + if err := c.allocateLayers(t.Context(), layerFolders, nil, false); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(c.layers.roLayers) != 1 { + t.Errorf("expected 1 read-only layer, got %d", len(c.layers.roLayers)) + } + if c.layers.roLayers[0].id != roGUID { + t.Errorf("ro layer GUID = %v, want %v", c.layers.roLayers[0].id, roGUID) + } + if c.layers.scratch.id != scratchGUID { + t.Errorf("scratch GUID = %v, want %v", c.layers.scratch.id, scratchGUID) + } + if !c.layers.layersCombined { + t.Error("expected layersCombined to be true") + } + if c.layers.rootfsPath != expectedRootfsPath { + t.Errorf("rootfsPath = %q, want %q", c.layers.rootfsPath, expectedRootfsPath) + } +} + +// TestAllocateLayers_MultipleReadOnlyLayers verifies that multiple read-only +// layers are each reserved, mapped, and passed to CombineLayers in order. +func TestAllocateLayers_MultipleReadOnlyLayers(t *testing.T) { + t.Parallel() + stubResolvePath(t) + stubGrantVMAccess(t) + + c, scsiCtrl, guestCtrl := newLayersTestController(t) + layerFolders := legacyLayerFolders( + []string{`C:\layers\layer0`, `C:\layers\layer1`}, + `C:\scratch`, + ) + + roGUID0, _ := guid.NewV4() + roGUID1, _ := guid.NewV4() + scratchGUID, _ := guid.NewV4() + + // Read-only layer 0. + scsiCtrl.EXPECT(). + Reserve(gomock.Any(), disk.Config{HostPath: `C:\layers\layer0\layer.vhd`, ReadOnly: true, Type: disk.TypeVirtualDisk}, scsiMount.Config{Partition: 0, ReadOnly: true, Options: []string{"ro"}}). + Return(roGUID0, nil) + scsiCtrl.EXPECT(). + MapToGuest(gomock.Any(), roGUID0). + Return("/dev/sda", nil) + + // Read-only layer 1. + scsiCtrl.EXPECT(). + Reserve(gomock.Any(), disk.Config{HostPath: `C:\layers\layer1\layer.vhd`, ReadOnly: true, Type: disk.TypeVirtualDisk}, scsiMount.Config{Partition: 0, ReadOnly: true, Options: []string{"ro"}}). + Return(roGUID1, nil) + scsiCtrl.EXPECT(). + MapToGuest(gomock.Any(), roGUID1). + Return("/dev/sdb", nil) + + // Scratch layer. + scsiCtrl.EXPECT(). 
+ Reserve(gomock.Any(), gomock.Any(), gomock.Any()). + Return(scratchGUID, nil) + scsiCtrl.EXPECT(). + MapToGuest(gomock.Any(), scratchGUID). + Return("/dev/sdc", nil) + + // Combine layers with both read-only layers. + guestCtrl.EXPECT(). + AddLCOWCombinedLayers(gomock.Any(), gomock.Any()). + Return(nil) + + if err := c.allocateLayers(t.Context(), layerFolders, nil, false); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(c.layers.roLayers) != 2 { + t.Fatalf("expected 2 read-only layers, got %d", len(c.layers.roLayers)) + } + if c.layers.roLayers[0].id != roGUID0 { + t.Errorf("ro layer 0 GUID = %v, want %v", c.layers.roLayers[0].id, roGUID0) + } + if c.layers.roLayers[1].id != roGUID1 { + t.Errorf("ro layer 1 GUID = %v, want %v", c.layers.roLayers[1].id, roGUID1) + } +} + +// TestAllocateLayers_ScratchEncryption verifies that when scratch encryption is +// enabled, the scratch disk is reserved with xfs and the encrypted flag set. +func TestAllocateLayers_ScratchEncryption(t *testing.T) { + t.Parallel() + stubResolvePath(t) + stubGrantVMAccess(t) + + c, scsiCtrl, guestCtrl := newLayersTestController(t) + layerFolders := legacyLayerFolders([]string{`C:\layers\base`}, `C:\scratch`) + + roGUID, _ := guid.NewV4() + scratchGUID, _ := guid.NewV4() + + // Read-only layer. + scsiCtrl.EXPECT(). + Reserve(gomock.Any(), gomock.Any(), gomock.Any()). + Return(roGUID, nil) + scsiCtrl.EXPECT(). + MapToGuest(gomock.Any(), roGUID). + Return("/dev/sda", nil) + + // Scratch layer: must use xfs and the Encrypted flag. + scsiCtrl.EXPECT(). + Reserve(gomock.Any(), + disk.Config{HostPath: `C:\scratch\sandbox.vhdx`, ReadOnly: false, Type: disk.TypeVirtualDisk}, + scsiMount.Config{ + Encrypted: true, + EnsureFilesystem: true, + ReadOnly: false, + Filesystem: "xfs", + }). + Return(scratchGUID, nil) + scsiCtrl.EXPECT(). + MapToGuest(gomock.Any(), scratchGUID). + Return("/dev/sdb", nil) + + guestCtrl.EXPECT(). + AddLCOWCombinedLayers(gomock.Any(), gomock.Any()). 
+ Return(nil) + + if err := c.allocateLayers(t.Context(), layerFolders, nil, true); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !c.layers.layersCombined { + t.Error("expected layersCombined to be true") + } +} + +// TestAllocateLayers_ResolvePathFailure verifies that a resolvePath failure on +// a read-only layer propagates the error. +func TestAllocateLayers_ResolvePathFailure(t *testing.T) { + t.Parallel() + stubGrantVMAccess(t) + + orig := resolvePath + resolvePath = func(_ string) (string, error) { return "", errResolvePath } + t.Cleanup(func() { resolvePath = orig }) + + c, _, _ := newLayersTestController(t) + layerFolders := legacyLayerFolders([]string{`C:\layers\base`}, `C:\scratch`) + + err := c.allocateLayers(t.Context(), layerFolders, nil, false) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, errResolvePath) { + t.Errorf("error = %v, want wrapping %v", err, errResolvePath) + } +} + +// TestAllocateLayers_ROLayerReserveFailure verifies that a SCSI Reserve failure +// for a read-only layer propagates the error. +func TestAllocateLayers_ROLayerReserveFailure(t *testing.T) { + t.Parallel() + stubResolvePath(t) + stubGrantVMAccess(t) + + c, scsiCtrl, _ := newLayersTestController(t) + layerFolders := legacyLayerFolders([]string{`C:\layers\base`}, `C:\scratch`) + + scsiCtrl.EXPECT(). + Reserve(gomock.Any(), gomock.Any(), gomock.Any()). + Return(guid.GUID{}, errScsiReserve) + + err := c.allocateLayers(t.Context(), layerFolders, nil, false) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, errScsiReserve) { + t.Errorf("error = %v, want wrapping %v", err, errScsiReserve) + } +} + +// TestAllocateLayers_ROLayerMapToGuestFailure verifies that a MapToGuest +// failure for a read-only layer propagates the error. 
+func TestAllocateLayers_ROLayerMapToGuestFailure(t *testing.T) { + t.Parallel() + stubResolvePath(t) + stubGrantVMAccess(t) + + c, scsiCtrl, _ := newLayersTestController(t) + layerFolders := legacyLayerFolders([]string{`C:\layers\base`}, `C:\scratch`) + + roGUID, _ := guid.NewV4() + scsiCtrl.EXPECT(). + Reserve(gomock.Any(), gomock.Any(), gomock.Any()). + Return(roGUID, nil) + scsiCtrl.EXPECT(). + MapToGuest(gomock.Any(), roGUID). + Return("", errMapToGuest) + + err := c.allocateLayers(t.Context(), layerFolders, nil, false) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, errMapToGuest) { + t.Errorf("error = %v, want wrapping %v", err, errMapToGuest) + } +} + +// TestAllocateLayers_GrantVMAccessFailure verifies that a grantVMAccess failure +// on the scratch layer propagates the error. +func TestAllocateLayers_GrantVMAccessFailure(t *testing.T) { + t.Parallel() + stubResolvePath(t) + + orig := grantVMAccess + grantVMAccess = func(_ context.Context, _, _ string) error { return errGrantVMAccess } + t.Cleanup(func() { grantVMAccess = orig }) + + c, scsiCtrl, _ := newLayersTestController(t) + layerFolders := legacyLayerFolders([]string{`C:\layers\base`}, `C:\scratch`) + + roGUID, _ := guid.NewV4() + + // Read-only layer succeeds. + scsiCtrl.EXPECT(). + Reserve(gomock.Any(), gomock.Any(), gomock.Any()). + Return(roGUID, nil) + scsiCtrl.EXPECT(). + MapToGuest(gomock.Any(), roGUID). + Return("/dev/sda", nil) + + err := c.allocateLayers(t.Context(), layerFolders, nil, false) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, errGrantVMAccess) { + t.Errorf("error = %v, want wrapping %v", err, errGrantVMAccess) + } +} + +// TestAllocateLayers_ScratchReserveFailure verifies that a SCSI Reserve failure +// on the scratch layer propagates the error. 
+func TestAllocateLayers_ScratchReserveFailure(t *testing.T) { + t.Parallel() + stubResolvePath(t) + stubGrantVMAccess(t) + + c, scsiCtrl, _ := newLayersTestController(t) + layerFolders := legacyLayerFolders([]string{`C:\layers\base`}, `C:\scratch`) + + roGUID, _ := guid.NewV4() + // Read-only layer succeeds. + scsiCtrl.EXPECT(). + Reserve(gomock.Any(), gomock.Any(), gomock.Any()). + Return(roGUID, nil) + scsiCtrl.EXPECT(). + MapToGuest(gomock.Any(), roGUID). + Return("/dev/sda", nil) + + // Scratch Reserve fails. We use DoAndReturn to distinguish the second + // Reserve call (scratch) from the first (read-only layer) above. + scsiCtrl.EXPECT(). + Reserve(gomock.Any(), gomock.Any(), gomock.Any()). + Return(guid.GUID{}, errScsiReserve) + + err := c.allocateLayers(t.Context(), layerFolders, nil, false) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, errScsiReserve) { + t.Errorf("error = %v, want wrapping %v", err, errScsiReserve) + } +} + +// TestAllocateLayers_ScratchMapToGuestFailure verifies that a MapToGuest +// failure on the scratch layer propagates the error. +func TestAllocateLayers_ScratchMapToGuestFailure(t *testing.T) { + t.Parallel() + stubResolvePath(t) + stubGrantVMAccess(t) + + c, scsiCtrl, _ := newLayersTestController(t) + layerFolders := legacyLayerFolders([]string{`C:\layers\base`}, `C:\scratch`) + + roGUID, _ := guid.NewV4() + scratchGUID, _ := guid.NewV4() + + // Read-only layer succeeds. + scsiCtrl.EXPECT(). + Reserve(gomock.Any(), gomock.Any(), gomock.Any()). + Return(roGUID, nil) + scsiCtrl.EXPECT(). + MapToGuest(gomock.Any(), roGUID). + Return("/dev/sda", nil) + + // Scratch Reserve succeeds but MapToGuest fails. + scsiCtrl.EXPECT(). + Reserve(gomock.Any(), gomock.Any(), gomock.Any()). + Return(scratchGUID, nil) + scsiCtrl.EXPECT(). + MapToGuest(gomock.Any(), scratchGUID). 
+ Return("", errMapToGuest) + + err := c.allocateLayers(t.Context(), layerFolders, nil, false) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, errMapToGuest) { + t.Errorf("error = %v, want wrapping %v", err, errMapToGuest) + } +} + +// TestAllocateLayers_CombineLayersFailure verifies that an +// AddLCOWCombinedLayers failure propagates the error and layersCombined remains +// false. +func TestAllocateLayers_CombineLayersFailure(t *testing.T) { + t.Parallel() + stubResolvePath(t) + stubGrantVMAccess(t) + + c, scsiCtrl, guestCtrl := newLayersTestController(t) + layerFolders := legacyLayerFolders([]string{`C:\layers\base`}, `C:\scratch`) + + roGUID, _ := guid.NewV4() + scratchGUID, _ := guid.NewV4() + + // Read-only layer. + scsiCtrl.EXPECT(). + Reserve(gomock.Any(), gomock.Any(), gomock.Any()). + Return(roGUID, nil) + scsiCtrl.EXPECT(). + MapToGuest(gomock.Any(), roGUID). + Return("/dev/sda", nil) + + // Scratch layer. + scsiCtrl.EXPECT(). + Reserve(gomock.Any(), gomock.Any(), gomock.Any()). + Return(scratchGUID, nil) + scsiCtrl.EXPECT(). + MapToGuest(gomock.Any(), scratchGUID). + Return("/dev/sdb", nil) + + // CombineLayers fails. + guestCtrl.EXPECT(). + AddLCOWCombinedLayers(gomock.Any(), gomock.Any()). + Return(errCombineLayers) + + err := c.allocateLayers(t.Context(), layerFolders, nil, false) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, errCombineLayers) { + t.Errorf("error = %v, want wrapping %v", err, errCombineLayers) + } + if c.layers.layersCombined { + t.Error("expected layersCombined to be false after combine failure") + } +} + +// TestAllocateLayers_ScratchResolvePathFailure verifies that a resolvePath +// failure on the scratch VHD propagates the error. 
+func TestAllocateLayers_ScratchResolvePathFailure(t *testing.T) { + t.Parallel() + stubGrantVMAccess(t) + + callCount := 0 + orig := resolvePath + resolvePath = func(path string) (string, error) { + callCount++ + // Let read-only layer resolve succeed, fail on scratch. + if callCount == 1 { + return path, nil + } + return "", errResolvePath + } + t.Cleanup(func() { resolvePath = orig }) + + c, scsiCtrl, _ := newLayersTestController(t) + layerFolders := legacyLayerFolders([]string{`C:\layers\base`}, `C:\scratch`) + + roGUID, _ := guid.NewV4() + + // Read-only layer succeeds. + scsiCtrl.EXPECT(). + Reserve(gomock.Any(), gomock.Any(), gomock.Any()). + Return(roGUID, nil) + scsiCtrl.EXPECT(). + MapToGuest(gomock.Any(), roGUID). + Return("/dev/sda", nil) + + err := c.allocateLayers(t.Context(), layerFolders, nil, false) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, errResolvePath) { + t.Errorf("error = %v, want wrapping %v", err, errResolvePath) + } +} + +// TestAllocateLayers_RootfsMount verifies allocateLayers works with a rootfs +// mount instead of legacy layer folders. +func TestAllocateLayers_RootfsMount(t *testing.T) { + t.Parallel() + stubResolvePath(t) + stubGrantVMAccess(t) + + c, scsiCtrl, guestCtrl := newLayersTestController(t) + + rootfs := []*containerdtypes.Mount{ + { + Type: "lcow-layer", + Source: `C:\scratch`, + Options: []string{ + `parentLayerPaths=["C:\\layers\\base"]`, + }, + }, + } + + roGUID, _ := guid.NewV4() + scratchGUID, _ := guid.NewV4() + + // Read-only layer. + scsiCtrl.EXPECT(). + Reserve(gomock.Any(), gomock.Any(), gomock.Any()). + Return(roGUID, nil) + scsiCtrl.EXPECT(). + MapToGuest(gomock.Any(), roGUID). + Return("/dev/sda", nil) + + // Scratch layer. + scsiCtrl.EXPECT(). + Reserve(gomock.Any(), gomock.Any(), gomock.Any()). + Return(scratchGUID, nil) + scsiCtrl.EXPECT(). + MapToGuest(gomock.Any(), scratchGUID). + Return("/dev/sdb", nil) + + guestCtrl.EXPECT(). 
+ AddLCOWCombinedLayers(gomock.Any(), gomock.Any()). + Return(nil) + + if err := c.allocateLayers(t.Context(), nil, rootfs, false); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(c.layers.roLayers) != 1 { + t.Errorf("expected 1 read-only layer, got %d", len(c.layers.roLayers)) + } + if !c.layers.layersCombined { + t.Error("expected layersCombined to be true") + } +} + +// TestAllocateLayers_InvalidLayerFolders verifies that allocateLayers returns +// an error when both rootfs and layerFolders are empty. +func TestAllocateLayers_InvalidLayerFolders(t *testing.T) { + t.Parallel() + c, _, _ := newLayersTestController(t) + + err := c.allocateLayers(t.Context(), nil, nil, false) + if err == nil { + t.Fatal("expected error for empty layers") + } +} diff --git a/internal/controller/linuxcontainer/mocks/mock_types.go b/internal/controller/linuxcontainer/mocks/mock_types.go new file mode 100644 index 0000000000..e47d2ee379 --- /dev/null +++ b/internal/controller/linuxcontainer/mocks/mock_types.go @@ -0,0 +1,325 @@ +//go:build windows && lcow + +// Code generated by MockGen. DO NOT EDIT. +// Source: types.go +// +// Generated by this command: +// +// mockgen -build_flags=-tags=lcow -build_constraint=windows && lcow -source types.go -package mocks -destination mocks/mock_types.go +// + +// Package mocks is a generated GoMock package. 
+package mocks + +import ( + context "context" + reflect "reflect" + + guid "github.com/Microsoft/go-winio/pkg/guid" + mount "github.com/Microsoft/hcsshim/internal/controller/device/plan9/mount" + share "github.com/Microsoft/hcsshim/internal/controller/device/plan9/share" + disk "github.com/Microsoft/hcsshim/internal/controller/device/scsi/disk" + mount0 "github.com/Microsoft/hcsshim/internal/controller/device/scsi/mount" + vpci "github.com/Microsoft/hcsshim/internal/controller/device/vpci" + gcs "github.com/Microsoft/hcsshim/internal/gcs" + guestresource "github.com/Microsoft/hcsshim/internal/protocol/guestresource" + gomock "go.uber.org/mock/gomock" +) + +// Mockguest is a mock of guest interface. +type Mockguest struct { + ctrl *gomock.Controller + recorder *MockguestMockRecorder + isgomock struct{} +} + +// MockguestMockRecorder is the mock recorder for Mockguest. +type MockguestMockRecorder struct { + mock *Mockguest +} + +// NewMockguest creates a new mock instance. +func NewMockguest(ctrl *gomock.Controller) *Mockguest { + mock := &Mockguest{ctrl: ctrl} + mock.recorder = &MockguestMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *Mockguest) EXPECT() *MockguestMockRecorder { + return m.recorder +} + +// AddLCOWCombinedLayers mocks base method. +func (m *Mockguest) AddLCOWCombinedLayers(ctx context.Context, settings guestresource.LCOWCombinedLayers) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddLCOWCombinedLayers", ctx, settings) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddLCOWCombinedLayers indicates an expected call of AddLCOWCombinedLayers. +func (mr *MockguestMockRecorder) AddLCOWCombinedLayers(ctx, settings any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddLCOWCombinedLayers", reflect.TypeOf((*Mockguest)(nil).AddLCOWCombinedLayers), ctx, settings) +} + +// Capabilities mocks base method. 
+func (m *Mockguest) Capabilities() gcs.GuestDefinedCapabilities { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Capabilities") + ret0, _ := ret[0].(gcs.GuestDefinedCapabilities) + return ret0 +} + +// Capabilities indicates an expected call of Capabilities. +func (mr *MockguestMockRecorder) Capabilities() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Capabilities", reflect.TypeOf((*Mockguest)(nil).Capabilities)) +} + +// CreateContainer mocks base method. +func (m *Mockguest) CreateContainer(ctx context.Context, cid string, config any) (*gcs.Container, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateContainer", ctx, cid, config) + ret0, _ := ret[0].(*gcs.Container) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateContainer indicates an expected call of CreateContainer. +func (mr *MockguestMockRecorder) CreateContainer(ctx, cid, config any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateContainer", reflect.TypeOf((*Mockguest)(nil).CreateContainer), ctx, cid, config) +} + +// DeleteContainerState mocks base method. +func (m *Mockguest) DeleteContainerState(ctx context.Context, cid string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteContainerState", ctx, cid) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteContainerState indicates an expected call of DeleteContainerState. +func (mr *MockguestMockRecorder) DeleteContainerState(ctx, cid any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteContainerState", reflect.TypeOf((*Mockguest)(nil).DeleteContainerState), ctx, cid) +} + +// RemoveLCOWCombinedLayers mocks base method. 
+func (m *Mockguest) RemoveLCOWCombinedLayers(ctx context.Context, settings guestresource.LCOWCombinedLayers) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoveLCOWCombinedLayers", ctx, settings) + ret0, _ := ret[0].(error) + return ret0 +} + +// RemoveLCOWCombinedLayers indicates an expected call of RemoveLCOWCombinedLayers. +func (mr *MockguestMockRecorder) RemoveLCOWCombinedLayers(ctx, settings any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveLCOWCombinedLayers", reflect.TypeOf((*Mockguest)(nil).RemoveLCOWCombinedLayers), ctx, settings) +} + +// MockscsiController is a mock of scsiController interface. +type MockscsiController struct { + ctrl *gomock.Controller + recorder *MockscsiControllerMockRecorder + isgomock struct{} +} + +// MockscsiControllerMockRecorder is the mock recorder for MockscsiController. +type MockscsiControllerMockRecorder struct { + mock *MockscsiController +} + +// NewMockscsiController creates a new mock instance. +func NewMockscsiController(ctrl *gomock.Controller) *MockscsiController { + mock := &MockscsiController{ctrl: ctrl} + mock.recorder = &MockscsiControllerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockscsiController) EXPECT() *MockscsiControllerMockRecorder { + return m.recorder +} + +// MapToGuest mocks base method. +func (m *MockscsiController) MapToGuest(ctx context.Context, id guid.GUID) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MapToGuest", ctx, id) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MapToGuest indicates an expected call of MapToGuest. 
+func (mr *MockscsiControllerMockRecorder) MapToGuest(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MapToGuest", reflect.TypeOf((*MockscsiController)(nil).MapToGuest), ctx, id) +} + +// Reserve mocks base method. +func (m *MockscsiController) Reserve(ctx context.Context, diskConfig disk.Config, mountConfig mount0.Config) (guid.GUID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Reserve", ctx, diskConfig, mountConfig) + ret0, _ := ret[0].(guid.GUID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Reserve indicates an expected call of Reserve. +func (mr *MockscsiControllerMockRecorder) Reserve(ctx, diskConfig, mountConfig any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Reserve", reflect.TypeOf((*MockscsiController)(nil).Reserve), ctx, diskConfig, mountConfig) +} + +// UnmapFromGuest mocks base method. +func (m *MockscsiController) UnmapFromGuest(ctx context.Context, reservation guid.GUID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UnmapFromGuest", ctx, reservation) + ret0, _ := ret[0].(error) + return ret0 +} + +// UnmapFromGuest indicates an expected call of UnmapFromGuest. +func (mr *MockscsiControllerMockRecorder) UnmapFromGuest(ctx, reservation any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnmapFromGuest", reflect.TypeOf((*MockscsiController)(nil).UnmapFromGuest), ctx, reservation) +} + +// Mockplan9Controller is a mock of plan9Controller interface. +type Mockplan9Controller struct { + ctrl *gomock.Controller + recorder *Mockplan9ControllerMockRecorder + isgomock struct{} +} + +// Mockplan9ControllerMockRecorder is the mock recorder for Mockplan9Controller. +type Mockplan9ControllerMockRecorder struct { + mock *Mockplan9Controller +} + +// NewMockplan9Controller creates a new mock instance. 
+func NewMockplan9Controller(ctrl *gomock.Controller) *Mockplan9Controller { + mock := &Mockplan9Controller{ctrl: ctrl} + mock.recorder = &Mockplan9ControllerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *Mockplan9Controller) EXPECT() *Mockplan9ControllerMockRecorder { + return m.recorder +} + +// MapToGuest mocks base method. +func (m *Mockplan9Controller) MapToGuest(ctx context.Context, id guid.GUID) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MapToGuest", ctx, id) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MapToGuest indicates an expected call of MapToGuest. +func (mr *Mockplan9ControllerMockRecorder) MapToGuest(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MapToGuest", reflect.TypeOf((*Mockplan9Controller)(nil).MapToGuest), ctx, id) +} + +// Reserve mocks base method. +func (m *Mockplan9Controller) Reserve(ctx context.Context, shareConfig share.Config, mountConfig mount.Config) (guid.GUID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Reserve", ctx, shareConfig, mountConfig) + ret0, _ := ret[0].(guid.GUID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Reserve indicates an expected call of Reserve. +func (mr *Mockplan9ControllerMockRecorder) Reserve(ctx, shareConfig, mountConfig any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Reserve", reflect.TypeOf((*Mockplan9Controller)(nil).Reserve), ctx, shareConfig, mountConfig) +} + +// UnmapFromGuest mocks base method. +func (m *Mockplan9Controller) UnmapFromGuest(ctx context.Context, reservation guid.GUID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UnmapFromGuest", ctx, reservation) + ret0, _ := ret[0].(error) + return ret0 +} + +// UnmapFromGuest indicates an expected call of UnmapFromGuest. 
+func (mr *Mockplan9ControllerMockRecorder) UnmapFromGuest(ctx, reservation any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnmapFromGuest", reflect.TypeOf((*Mockplan9Controller)(nil).UnmapFromGuest), ctx, reservation) +} + +// MockvPCIController is a mock of vPCIController interface. +type MockvPCIController struct { + ctrl *gomock.Controller + recorder *MockvPCIControllerMockRecorder + isgomock struct{} +} + +// MockvPCIControllerMockRecorder is the mock recorder for MockvPCIController. +type MockvPCIControllerMockRecorder struct { + mock *MockvPCIController +} + +// NewMockvPCIController creates a new mock instance. +func NewMockvPCIController(ctrl *gomock.Controller) *MockvPCIController { + mock := &MockvPCIController{ctrl: ctrl} + mock.recorder = &MockvPCIControllerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockvPCIController) EXPECT() *MockvPCIControllerMockRecorder { + return m.recorder +} + +// AddToVM mocks base method. +func (m *MockvPCIController) AddToVM(ctx context.Context, vmBusGUID guid.GUID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddToVM", ctx, vmBusGUID) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddToVM indicates an expected call of AddToVM. +func (mr *MockvPCIControllerMockRecorder) AddToVM(ctx, vmBusGUID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddToVM", reflect.TypeOf((*MockvPCIController)(nil).AddToVM), ctx, vmBusGUID) +} + +// RemoveFromVM mocks base method. +func (m *MockvPCIController) RemoveFromVM(ctx context.Context, vmBusGUID guid.GUID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoveFromVM", ctx, vmBusGUID) + ret0, _ := ret[0].(error) + return ret0 +} + +// RemoveFromVM indicates an expected call of RemoveFromVM. 
+func (mr *MockvPCIControllerMockRecorder) RemoveFromVM(ctx, vmBusGUID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveFromVM", reflect.TypeOf((*MockvPCIController)(nil).RemoveFromVM), ctx, vmBusGUID) +} + +// Reserve mocks base method. +func (m *MockvPCIController) Reserve(ctx context.Context, device vpci.Device) (guid.GUID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Reserve", ctx, device) + ret0, _ := ret[0].(guid.GUID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Reserve indicates an expected call of Reserve. +func (mr *MockvPCIControllerMockRecorder) Reserve(ctx, device any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Reserve", reflect.TypeOf((*MockvPCIController)(nil).Reserve), ctx, device) +} diff --git a/internal/controller/linuxcontainer/mounts.go b/internal/controller/linuxcontainer/mounts.go new file mode 100644 index 0000000000..560c45a310 --- /dev/null +++ b/internal/controller/linuxcontainer/mounts.go @@ -0,0 +1,260 @@ +//go:build windows && lcow + +package linuxcontainer + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + + plan9Mount "github.com/Microsoft/hcsshim/internal/controller/device/plan9/mount" + "github.com/Microsoft/hcsshim/internal/controller/device/plan9/share" + "github.com/Microsoft/hcsshim/internal/controller/device/scsi/disk" + scsiMount "github.com/Microsoft/hcsshim/internal/controller/device/scsi/mount" + "github.com/Microsoft/hcsshim/internal/guestpath" + "github.com/Microsoft/hcsshim/internal/log" + "github.com/opencontainers/runtime-spec/specs-go" +) + +// Mount type constants. +const ( + // mountTypeBind is a regular host-directory bind mount served via a Plan9 share. + mountTypeBind = "bind" + + // mountTypePhysicalDisk hot-adds a physical pass-through disk via the SCSI controller. 
+ mountTypePhysicalDisk = "physical-disk" + + // mountTypeVirtualDisk hot-adds a VHD or VHDX via the SCSI controller. + mountTypeVirtualDisk = "virtual-disk" + + // mountTypeExtensibleVirtualDisk hot-adds an extensible virtual disk via the SCSI controller. + mountTypeExtensibleVirtualDisk = "extensible-virtual-disk" + + // mountTypeNone signals that the mount is a disk-backed device mount whose + // filesystem will be resolved when the guest actually mounts the device. + mountTypeNone = "none" +) + +// allocateMounts reserves and maps host-side resources for each OCI mount, +// rewriting mount sources in the spec to their guest-visible paths. +func (c *Controller) allocateMounts(ctx context.Context, spec *specs.Spec) error { + for idx := range spec.Mounts { + mount := &spec.Mounts[idx] + + if mount.Destination == "" || mount.Source == "" { + return fmt.Errorf("invalid mount: both source and destination are required: %+v", mount) + } + + // Check if the mount is read-only. + isReadOnly := isReadOnlyMount(mount) + + // Dispatch to a mount-type-specific handler. + switch mount.Type { + case mountTypeVirtualDisk, mountTypePhysicalDisk, mountTypeExtensibleVirtualDisk: + if err := c.allocateSCSIMount(ctx, mount, isReadOnly); err != nil { + return err + } + case mountTypeBind: + // Hugepages mounts are backed by a pre-existing mount inside the UVM. + if strings.HasPrefix(mount.Source, guestpath.HugePagesMountPrefix) { + if err := validateHugePageMount(mount.Source); err != nil { + return err + } + continue + } + + // Guest-internal paths resolve entirely inside the UVM. + if isGuestInternalPath(mount.Source) { + continue + } + + // All remaining bind mounts are host directories served via Plan9. + // Allocate them. + if err := c.allocatePlan9Mount(ctx, mount, isReadOnly); err != nil { + return err + } + default: + // Unknown mount types (e.g. tmpfs, devpts, proc) are passed through + // to the guest without host-side resource reservation/allocation. 
+ } + } + + log.G(ctx).Debug("all OCI mounts allocated successfully") + return nil +} + +// allocateSCSIMount resolves the host path, grants VM access, and reserves+maps +// a SCSI slot for any disk-backed mount type. +func (c *Controller) allocateSCSIMount(ctx context.Context, mount *specs.Mount, isReadOnly bool) error { + // Build disk config based on mount type. + var diskConfig disk.Config + switch mount.Type { + case mountTypeVirtualDisk, mountTypePhysicalDisk: + // Resolve any symlinks to get the real host path for the disk. + hostPath, err := resolvePath(mount.Source) + if err != nil { + return fmt.Errorf("resolve symlinks for mount source %s: %w", mount.Source, err) + } + + // The VM needs explicit access to the disk before it can be attached. + if err = grantVMAccess(ctx, c.vmID, hostPath); err != nil { + return fmt.Errorf("grant vm access to %s: %w", hostPath, err) + } + + // Physical disks use pass-through; everything else is a virtual disk. + diskType := disk.TypeVirtualDisk + if mount.Type == mountTypePhysicalDisk { + diskType = disk.TypePassThru + } + + // Create the final disk config. + diskConfig = disk.Config{HostPath: hostPath, ReadOnly: isReadOnly, Type: diskType} + + case mountTypeExtensibleVirtualDisk: + // EVD paths encode the provider type in the source URI. + evdType, sourcePath, err := parseExtensibleVirtualDiskPath(mount.Source) + if err != nil { + return fmt.Errorf("parse extensible virtual disk path: %w", err) + } + + // Resolve any symlinks to get the real host path for the disk. + hostPath, err := resolvePath(sourcePath) + if err != nil { + return fmt.Errorf("resolve symlinks for mount source %s: %w", sourcePath, err) + } + + // Create the final disk config. + diskConfig = disk.Config{HostPath: hostPath, ReadOnly: isReadOnly, Type: disk.TypeExtensibleVirtualDisk, EVDType: evdType} + + default: + return fmt.Errorf("unsupported scsi mount type %q", mount.Type) + } + + // Check if this is a block dev mount. 
+ isBlockDev := strings.HasPrefix(mount.Destination, guestpath.BlockDevMountPrefix) + + // Reserve the mount. + reservationID, err := c.scsi.Reserve(ctx, diskConfig, scsiMount.Config{ + ReadOnly: isReadOnly, + Options: mount.Options, + BlockDev: isBlockDev, + }) + if err != nil { + return fmt.Errorf("reserve scsi mount for %s: %w", mount.Source, err) + } + + // Map the device into the guest. + guestPath, err := c.scsi.MapToGuest(ctx, reservationID) + if err != nil { + return fmt.Errorf("map scsi mount %s to guest: %w", mount.Source, err) + } + + c.scsiResources = append(c.scsiResources, reservationID) + + // Rewrite source to guest path; block-device mounts retain bind type. + mount.Source = guestPath + mount.Type = mountTypeNone + if isBlockDev { + mount.Type = mountTypeBind + } + + return nil +} + +// allocatePlan9Mount reserves and maps a Plan9 share for a host-backed bind mount. +func (c *Controller) allocatePlan9Mount(ctx context.Context, mount *specs.Mount, isReadOnly bool) error { + // Ensure that mount source exists. + fileInfo, err := os.Stat(mount.Source) + if err != nil { + return fmt.Errorf("stat bind mount source %s: %w", mount.Source, err) + } + + shareConfig := share.Config{ + HostPath: mount.Source, + ReadOnly: isReadOnly, + } + + // For single-file mounts, share the containing directory but restrict + // access to the specific file. + if !fileInfo.IsDir() { + hostDir, fileName := filepath.Split(mount.Source) + shareConfig.HostPath = hostDir + shareConfig.Restrict = true + shareConfig.AllowedNames = []string{fileName} + } + + // Reserve the plan9 share. + reservationID, err := c.plan9.Reserve(ctx, shareConfig, plan9Mount.Config{ReadOnly: isReadOnly}) + if err != nil { + return fmt.Errorf("reserve plan9 share for %s: %w", mount.Source, err) + } + + // Map the share into the guest. 
+ guestPath, err := c.plan9.MapToGuest(ctx, reservationID) + if err != nil { + return fmt.Errorf("map plan9 share %s to guest: %w", mount.Source, err) + } + + c.plan9Resources = append(c.plan9Resources, reservationID) + mount.Source = guestPath + return nil +} + +// --- Helpers --- + +// parseExtensibleVirtualDiskPath extracts the EVD type and source path from +// a path with the format "evd:///". +func parseExtensibleVirtualDiskPath(hostPath string) (evdType, sourcePath string, err error) { + const evdPrefix = "evd://" + + if !strings.HasPrefix(hostPath, evdPrefix) { + return "", "", fmt.Errorf("invalid extensible virtual disk path %q: missing %q prefix", hostPath, evdPrefix) + } + + trimmed := strings.TrimPrefix(hostPath, evdPrefix) + idx := strings.Index(trimmed, "/") + if idx <= 0 { + return "", "", fmt.Errorf("invalid extensible virtual disk path %q: expected format %s/", hostPath, evdPrefix) + } + + return trimmed[:idx], trimmed[idx+1:], nil +} + +// validateHugePageMount checks that a hugepages mount source has the expected format. +func validateHugePageMount(source string) error { + // Expected format: "hugepages:///" + parts := strings.Split(strings.TrimPrefix(source, guestpath.HugePagesMountPrefix), "/") + if len(parts) < 2 { + return fmt.Errorf("invalid hugepages mount path %s: expected format %s/", source, guestpath.HugePagesMountPrefix) + } + // Only 2M (megabyte) hugepages are currently supported. + if parts[0] != "2M" { + return fmt.Errorf("unsupported hugepage size %s: only 2M is supported", parts[0]) + } + return nil +} + +// isReadOnlyMount returns true if the mount options contain the "ro" flag. +func isReadOnlyMount(mount *specs.Mount) bool { + for _, option := range mount.Options { + if strings.EqualFold(option, "ro") { + return true + } + } + return false +} + +// isGuestInternalPath reports whether the path uses a UVM-internal prefix +// that resolves inside the guest. 
+func isGuestInternalPath(path string) bool { + // Mounts that map to a path in UVM are specified with a 'sandbox://', 'sandbox-tmp://', or 'uvm://' prefix. + // examples: + // - sandbox:///a/dirInUvm destination:/b/dirInContainer + // - sandbox-tmp:///a/dirInUvm destination:/b/dirInContainer + // - uvm:///a/dirInUvm destination:/b/dirInContainer + return strings.HasPrefix(path, guestpath.SandboxMountPrefix) || + strings.HasPrefix(path, guestpath.SandboxTmpfsMountPrefix) || + strings.HasPrefix(path, guestpath.UVMMountPrefix) +} diff --git a/internal/controller/linuxcontainer/mounts_test.go b/internal/controller/linuxcontainer/mounts_test.go new file mode 100644 index 0000000000..e8029149d0 --- /dev/null +++ b/internal/controller/linuxcontainer/mounts_test.go @@ -0,0 +1,1050 @@ +//go:build windows && lcow + +package linuxcontainer + +import ( + "context" + "errors" + "os" + "path/filepath" + "testing" + + plan9Mount "github.com/Microsoft/hcsshim/internal/controller/device/plan9/mount" + "github.com/Microsoft/hcsshim/internal/controller/device/plan9/share" + "github.com/Microsoft/hcsshim/internal/controller/device/scsi/disk" + scsiMount "github.com/Microsoft/hcsshim/internal/controller/device/scsi/mount" + "github.com/Microsoft/hcsshim/internal/controller/linuxcontainer/mocks" + "github.com/Microsoft/hcsshim/internal/guestpath" + + "github.com/Microsoft/go-winio/pkg/guid" + "github.com/opencontainers/runtime-spec/specs-go" + "go.uber.org/mock/gomock" +) + +var ( + errPlan9Reserve = errors.New("plan9 reserve failed") + errPlan9Map = errors.New("plan9 map to guest failed") +) + +// newMountsTestController creates a Controller wired to mock SCSI and Plan9 +// controllers alongside stubbed resolvePath and grantVMAccess functions. 
+func newMountsTestController(t *testing.T) (
+	*Controller,
+	*mocks.MockscsiController,
+	*mocks.Mockplan9Controller,
+) {
+	t.Helper()
+	// The gomock controller is bound to t, so unmet or unexpected calls are
+	// reported against the calling test.
+	ctrl := gomock.NewController(t)
+	scsiCtrl := mocks.NewMockscsiController(ctrl)
+	plan9Ctrl := mocks.NewMockplan9Controller(ctrl)
+	c := &Controller{
+		vmID:  "test-vm",
+		scsi:  scsiCtrl,
+		plan9: plan9Ctrl,
+	}
+	return c, scsiCtrl, plan9Ctrl
+}
+
+// --- allocateMounts: SCSI mount tests ---
+
+// TestAllocateMounts_VirtualDisk verifies the full Reserve → MapToGuest flow
+// for a virtual-disk mount, including source rewrite and type change.
+func TestAllocateMounts_VirtualDisk(t *testing.T) {
+	// Deliberately not parallel: the code under test calls the package-level
+	// resolvePath and grantVMAccess variables, which other tests in this file
+	// reassign; the stub helpers presumably swap them too (TODO confirm).
+	stubResolvePath(t)
+	stubGrantVMAccess(t)
+
+	c, scsiCtrl, _ := newMountsTestController(t)
+	spec := &specs.Spec{
+		Mounts: []specs.Mount{
+			{
+				Source:      `C:\disks\data.vhdx`,
+				Destination: "/mnt/data",
+				Type:        mountTypeVirtualDisk,
+				Options:     []string{"ro"},
+			},
+		},
+	}
+
+	id, _ := guid.NewV4()
+	scsiCtrl.EXPECT().
+		Reserve(gomock.Any(), disk.Config{
+			HostPath: `C:\disks\data.vhdx`,
+			ReadOnly: true,
+			Type:     disk.TypeVirtualDisk,
+		}, scsiMount.Config{
+			ReadOnly: true,
+			Options:  []string{"ro"},
+		}).
+		Return(id, nil)
+	scsiCtrl.EXPECT().
+		MapToGuest(gomock.Any(), id).
+		Return("/dev/sda", nil)
+
+	if err := c.allocateMounts(t.Context(), spec); err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	if spec.Mounts[0].Source != "/dev/sda" {
+		t.Errorf("mount source = %q, want %q", spec.Mounts[0].Source, "/dev/sda")
+	}
+	if spec.Mounts[0].Type != mountTypeNone {
+		t.Errorf("mount type = %q, want %q", spec.Mounts[0].Type, mountTypeNone)
+	}
+	if len(c.scsiResources) != 1 || c.scsiResources[0] != id {
+		t.Errorf("scsiResources = %v, want [%v]", c.scsiResources, id)
+	}
+}
+
+// TestAllocateMounts_PhysicalDisk verifies that a physical-disk mount uses
+// PassThru disk type and correctly rewrites the spec.
+func TestAllocateMounts_PhysicalDisk(t *testing.T) {
+	// Deliberately not parallel: the code under test calls the package-level
+	// resolvePath and grantVMAccess variables, which other tests in this file
+	// reassign; the stub helpers presumably swap them too (TODO confirm).
+	stubResolvePath(t)
+	stubGrantVMAccess(t)
+
+	c, scsiCtrl, _ := newMountsTestController(t)
+	spec := &specs.Spec{
+		Mounts: []specs.Mount{
+			{
+				Source:      `\\.\PhysicalDrive1`,
+				Destination: "/mnt/disk",
+				Type:        mountTypePhysicalDisk,
+			},
+		},
+	}
+
+	id, _ := guid.NewV4()
+	scsiCtrl.EXPECT().
+		Reserve(gomock.Any(), disk.Config{
+			HostPath: `\\.\PhysicalDrive1`,
+			ReadOnly: false,
+			Type:     disk.TypePassThru,
+		}, scsiMount.Config{}).
+		Return(id, nil)
+	scsiCtrl.EXPECT().
+		MapToGuest(gomock.Any(), id).
+		Return("/dev/sdb", nil)
+
+	if err := c.allocateMounts(t.Context(), spec); err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	if spec.Mounts[0].Source != "/dev/sdb" {
+		t.Errorf("mount source = %q, want %q", spec.Mounts[0].Source, "/dev/sdb")
+	}
+	if spec.Mounts[0].Type != mountTypeNone {
+		t.Errorf("mount type = %q, want %q", spec.Mounts[0].Type, mountTypeNone)
+	}
+}
+
+// TestAllocateMounts_ExtensibleVirtualDisk verifies EVD path parsing and SCSI
+// reservation for an extensible-virtual-disk mount.
+func TestAllocateMounts_ExtensibleVirtualDisk(t *testing.T) {
+	// Deliberately not parallel: the EVD path calls the package-level
+	// resolvePath variable, which other tests in this file reassign; the stub
+	// helper presumably swaps it too (TODO confirm).
+	stubResolvePath(t)
+
+	c, scsiCtrl, _ := newMountsTestController(t)
+	spec := &specs.Spec{
+		Mounts: []specs.Mount{
+			{
+				Source:      `evd://provider-type/C:\disks\data.vhdx`,
+				Destination: "/mnt/evd",
+				Type:        mountTypeExtensibleVirtualDisk,
+			},
+		},
+	}
+
+	id, _ := guid.NewV4()
+	scsiCtrl.EXPECT().
+		Reserve(gomock.Any(), disk.Config{
+			HostPath: `C:\disks\data.vhdx`,
+			ReadOnly: false,
+			Type:     disk.TypeExtensibleVirtualDisk,
+			EVDType:  "provider-type",
+		}, scsiMount.Config{}).
+		Return(id, nil)
+	scsiCtrl.EXPECT().
+		MapToGuest(gomock.Any(), id).
+		Return("/dev/sdc", nil)
+
+	if err := c.allocateMounts(t.Context(), spec); err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	if spec.Mounts[0].Source != "/dev/sdc" {
+		t.Errorf("mount source = %q, want %q", spec.Mounts[0].Source, "/dev/sdc")
+	}
+}
+
+// TestAllocateMounts_BlockDevMount verifies that a block-device mount (indicated
+// by a "blockdev://" destination prefix) keeps the bind type.
+func TestAllocateMounts_BlockDevMount(t *testing.T) {
+	// Deliberately not parallel: the code under test calls the package-level
+	// resolvePath and grantVMAccess variables, which other tests in this file
+	// reassign; the stub helpers presumably swap them too (TODO confirm).
+	stubResolvePath(t)
+	stubGrantVMAccess(t)
+
+	c, scsiCtrl, _ := newMountsTestController(t)
+	spec := &specs.Spec{
+		Mounts: []specs.Mount{
+			{
+				Source:      `C:\disks\block.vhdx`,
+				Destination: guestpath.BlockDevMountPrefix + "/dev/sda",
+				Type:        mountTypeVirtualDisk,
+			},
+		},
+	}
+
+	id, _ := guid.NewV4()
+	scsiCtrl.EXPECT().
+		Reserve(gomock.Any(), gomock.Any(), scsiMount.Config{
+			BlockDev: true,
+		}).
+		Return(id, nil)
+	scsiCtrl.EXPECT().
+		MapToGuest(gomock.Any(), id).
+		Return("/dev/sda", nil)
+
+	if err := c.allocateMounts(t.Context(), spec); err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	// Block dev mounts retain bind type.
+	if spec.Mounts[0].Type != mountTypeBind {
+		t.Errorf("mount type = %q, want %q", spec.Mounts[0].Type, mountTypeBind)
+	}
+}
+
+// TestAllocateMounts_SCSIReserveFailure verifies that a SCSI Reserve failure
+// propagates the error.
+func TestAllocateMounts_SCSIReserveFailure(t *testing.T) {
+	// Deliberately not parallel: the code under test calls the package-level
+	// resolvePath and grantVMAccess variables, which other tests in this file
+	// reassign; the stub helpers presumably swap them too (TODO confirm).
+	stubResolvePath(t)
+	stubGrantVMAccess(t)
+
+	c, scsiCtrl, _ := newMountsTestController(t)
+	spec := &specs.Spec{
+		Mounts: []specs.Mount{
+			{
+				Source:      `C:\disks\data.vhdx`,
+				Destination: "/mnt/data",
+				Type:        mountTypeVirtualDisk,
+			},
+		},
+	}
+
+	scsiCtrl.EXPECT().
+		Reserve(gomock.Any(), gomock.Any(), gomock.Any()).
+		Return(guid.GUID{}, errScsiReserve)
+
+	err := c.allocateMounts(t.Context(), spec)
+	if err == nil {
+		t.Fatal("expected error")
+	}
+	if !errors.Is(err, errScsiReserve) {
+		t.Errorf("error = %v, want wrapping %v", err, errScsiReserve)
+	}
+}
+
+// TestAllocateMounts_SCSIMapToGuestFailure verifies that a SCSI MapToGuest
+// failure propagates the error.
+func TestAllocateMounts_SCSIMapToGuestFailure(t *testing.T) {
+	// Deliberately not parallel: the code under test calls the package-level
+	// resolvePath and grantVMAccess variables, which other tests in this file
+	// reassign; the stub helpers presumably swap them too (TODO confirm).
+	stubResolvePath(t)
+	stubGrantVMAccess(t)
+
+	c, scsiCtrl, _ := newMountsTestController(t)
+	spec := &specs.Spec{
+		Mounts: []specs.Mount{
+			{
+				Source:      `C:\disks\data.vhdx`,
+				Destination: "/mnt/data",
+				Type:        mountTypeVirtualDisk,
+			},
+		},
+	}
+
+	id, _ := guid.NewV4()
+	scsiCtrl.EXPECT().
+		Reserve(gomock.Any(), gomock.Any(), gomock.Any()).
+		Return(id, nil)
+	scsiCtrl.EXPECT().
+		MapToGuest(gomock.Any(), id).
+		Return("", errMapToGuest)
+
+	err := c.allocateMounts(t.Context(), spec)
+	if err == nil {
+		t.Fatal("expected error")
+	}
+	if !errors.Is(err, errMapToGuest) {
+		t.Errorf("error = %v, want wrapping %v", err, errMapToGuest)
+	}
+}
+
+// TestAllocateMounts_VirtualDiskResolvePathFailure verifies that a resolvePath
+// failure for a virtual-disk mount propagates the error.
+func TestAllocateMounts_VirtualDiskResolvePathFailure(t *testing.T) {
+	// Deliberately not parallel: this test overwrites the package-level
+	// resolvePath variable, which would race with any concurrently running
+	// test that calls it and could hand that test the failing stub.
+	stubGrantVMAccess(t)
+
+	orig := resolvePath
+	resolvePath = func(_ string) (string, error) { return "", errResolvePath }
+	t.Cleanup(func() { resolvePath = orig })
+
+	// No mock expectations are set: the failure happens before Reserve.
+	c, _, _ := newMountsTestController(t)
+	spec := &specs.Spec{
+		Mounts: []specs.Mount{
+			{
+				Source:      `C:\disks\data.vhdx`,
+				Destination: "/mnt/data",
+				Type:        mountTypeVirtualDisk,
+			},
+		},
+	}
+
+	err := c.allocateMounts(t.Context(), spec)
+	if err == nil {
+		t.Fatal("expected error")
+	}
+	if !errors.Is(err, errResolvePath) {
+		t.Errorf("error = %v, want wrapping %v", err, errResolvePath)
+	}
+}
+
+// TestAllocateMounts_VirtualDiskGrantVMAccessFailure verifies that a
+// grantVMAccess failure for a virtual-disk mount propagates the error.
+func TestAllocateMounts_VirtualDiskGrantVMAccessFailure(t *testing.T) {
+	// Deliberately not parallel: this test overwrites the package-level
+	// grantVMAccess variable, which would race with any concurrently running
+	// test that calls it and could hand that test the failing stub.
+	stubResolvePath(t)
+
+	orig := grantVMAccess
+	grantVMAccess = func(_ context.Context, _, _ string) error { return errGrantVMAccess }
+	t.Cleanup(func() { grantVMAccess = orig })
+
+	// No mock expectations are set: the failure happens before Reserve.
+	c, _, _ := newMountsTestController(t)
+	spec := &specs.Spec{
+		Mounts: []specs.Mount{
+			{
+				Source:      `C:\disks\data.vhdx`,
+				Destination: "/mnt/data",
+				Type:        mountTypeVirtualDisk,
+			},
+		},
+	}
+
+	err := c.allocateMounts(t.Context(), spec)
+	if err == nil {
+		t.Fatal("expected error")
+	}
+	if !errors.Is(err, errGrantVMAccess) {
+		t.Errorf("error = %v, want wrapping %v", err, errGrantVMAccess)
+	}
+}
+
+// TestAllocateMounts_EVDResolvePathFailure verifies that a resolvePath failure
+// for an extensible-virtual-disk mount propagates the error.
+func TestAllocateMounts_EVDResolvePathFailure(t *testing.T) {
+	// Deliberately not parallel: this test overwrites the package-level
+	// resolvePath variable, which would race with any concurrently running
+	// test that calls it and could hand that test the failing stub.
+	orig := resolvePath
+	resolvePath = func(_ string) (string, error) { return "", errResolvePath }
+	t.Cleanup(func() { resolvePath = orig })
+
+	// No mock expectations are set: the failure happens before Reserve.
+	c, _, _ := newMountsTestController(t)
+	spec := &specs.Spec{
+		Mounts: []specs.Mount{
+			{
+				Source:      `evd://mytype/C:\disk.vhdx`,
+				Destination: "/mnt/evd",
+				Type:        mountTypeExtensibleVirtualDisk,
+			},
+		},
+	}
+
+	err := c.allocateMounts(t.Context(), spec)
+	if err == nil {
+		t.Fatal("expected error")
+	}
+	if !errors.Is(err, errResolvePath) {
+		t.Errorf("error = %v, want wrapping %v", err, errResolvePath)
+	}
+}
+
+// --- allocateMounts: Plan9 bind mount tests ---
+
+// TestAllocateMounts_Plan9BindDirectory verifies the Reserve → MapToGuest flow
+// for a host directory bind mount served via Plan9.
+func TestAllocateMounts_Plan9BindDirectory(t *testing.T) {
+	// Safe to run in parallel: this test only uses its own temp directory
+	// and mocks, and replaces no package-level state.
+	t.Parallel()
+
+	c, _, plan9Ctrl := newMountsTestController(t)
+	// A real directory is needed so the stat of the bind source succeeds.
+	hostDir := t.TempDir()
+
+	spec := &specs.Spec{
+		Mounts: []specs.Mount{
+			{
+				Source:      hostDir,
+				Destination: "/mnt/hostdir",
+				Type:        mountTypeBind,
+			},
+		},
+	}
+
+	id, _ := guid.NewV4()
+	plan9Ctrl.EXPECT().
+		Reserve(gomock.Any(), share.Config{
+			HostPath: hostDir,
+			ReadOnly: false,
+		}, plan9Mount.Config{ReadOnly: false}).
+		Return(id, nil)
+	plan9Ctrl.EXPECT().
+		MapToGuest(gomock.Any(), id).
+		Return("/mnt/plan9/share0", nil)
+
+	if err := c.allocateMounts(t.Context(), spec); err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	if spec.Mounts[0].Source != "/mnt/plan9/share0" {
+		t.Errorf("mount source = %q, want %q", spec.Mounts[0].Source, "/mnt/plan9/share0")
+	}
+	if len(c.plan9Resources) != 1 || c.plan9Resources[0] != id {
+		t.Errorf("plan9Resources = %v, want [%v]", c.plan9Resources, id)
+	}
+}
+
+// TestAllocateMounts_Plan9BindReadOnly verifies that the read-only flag
+// is propagated to both share and mount configs for Plan9 mounts.
+func TestAllocateMounts_Plan9BindReadOnly(t *testing.T) { + t.Parallel() + + c, _, plan9Ctrl := newMountsTestController(t) + hostDir := t.TempDir() + + spec := &specs.Spec{ + Mounts: []specs.Mount{ + { + Source: hostDir, + Destination: "/mnt/readonly", + Type: mountTypeBind, + Options: []string{"ro"}, + }, + }, + } + + id, _ := guid.NewV4() + plan9Ctrl.EXPECT(). + Reserve(gomock.Any(), share.Config{ + HostPath: hostDir, + ReadOnly: true, + }, plan9Mount.Config{ReadOnly: true}). + Return(id, nil) + plan9Ctrl.EXPECT(). + MapToGuest(gomock.Any(), id). + Return("/mnt/plan9/share0", nil) + + if err := c.allocateMounts(t.Context(), spec); err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +// TestAllocateMounts_Plan9BindSingleFile verifies that a single-file bind mount +// shares the parent directory with restrict mode and the file in AllowedNames. +func TestAllocateMounts_Plan9BindSingleFile(t *testing.T) { + t.Parallel() + + c, _, plan9Ctrl := newMountsTestController(t) + + // Create a real temp file so os.Stat succeeds and reports !IsDir(). + dir := t.TempDir() + filePath := filepath.Join(dir, "config.json") + if err := os.WriteFile(filePath, []byte("{}"), 0644); err != nil { + t.Fatalf("create temp file: %v", err) + } + + spec := &specs.Spec{ + Mounts: []specs.Mount{ + { + Source: filePath, + Destination: "/etc/config.json", + Type: mountTypeBind, + }, + }, + } + + id, _ := guid.NewV4() + + // For a single-file mount, the share's HostPath should be the parent + // directory with Restrict enabled and the filename in AllowedNames. + plan9Ctrl.EXPECT(). + Reserve(gomock.Any(), share.Config{ + HostPath: dir + string(filepath.Separator), + ReadOnly: false, + Restrict: true, + AllowedNames: []string{"config.json"}, + }, plan9Mount.Config{ReadOnly: false}). + Return(id, nil) + plan9Ctrl.EXPECT(). + MapToGuest(gomock.Any(), id). 
+ Return("/mnt/plan9/file0", nil) + + if err := c.allocateMounts(t.Context(), spec); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if spec.Mounts[0].Source != "/mnt/plan9/file0" { + t.Errorf("mount source = %q, want %q", spec.Mounts[0].Source, "/mnt/plan9/file0") + } +} + +// TestAllocateMounts_Plan9ReserveFailure verifies that a Plan9 Reserve failure +// propagates the error. +func TestAllocateMounts_Plan9ReserveFailure(t *testing.T) { + t.Parallel() + + c, _, plan9Ctrl := newMountsTestController(t) + hostDir := t.TempDir() + + spec := &specs.Spec{ + Mounts: []specs.Mount{ + { + Source: hostDir, + Destination: "/mnt/fail", + Type: mountTypeBind, + }, + }, + } + + plan9Ctrl.EXPECT(). + Reserve(gomock.Any(), gomock.Any(), gomock.Any()). + Return(guid.GUID{}, errPlan9Reserve) + + err := c.allocateMounts(t.Context(), spec) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, errPlan9Reserve) { + t.Errorf("error = %v, want wrapping %v", err, errPlan9Reserve) + } +} + +// TestAllocateMounts_Plan9MapToGuestFailure verifies that a Plan9 MapToGuest +// failure propagates the error. +func TestAllocateMounts_Plan9MapToGuestFailure(t *testing.T) { + t.Parallel() + + c, _, plan9Ctrl := newMountsTestController(t) + hostDir := t.TempDir() + + spec := &specs.Spec{ + Mounts: []specs.Mount{ + { + Source: hostDir, + Destination: "/mnt/fail", + Type: mountTypeBind, + }, + }, + } + + id, _ := guid.NewV4() + plan9Ctrl.EXPECT(). + Reserve(gomock.Any(), gomock.Any(), gomock.Any()). + Return(id, nil) + plan9Ctrl.EXPECT(). + MapToGuest(gomock.Any(), id). + Return("", errPlan9Map) + + err := c.allocateMounts(t.Context(), spec) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, errPlan9Map) { + t.Errorf("error = %v, want wrapping %v", err, errPlan9Map) + } +} + +// TestAllocateMounts_Plan9StatFailure verifies that allocateMounts returns an +// error when os.Stat fails for a bind mount source. 
+func TestAllocateMounts_Plan9StatFailure(t *testing.T) { + t.Parallel() + + c, _, _ := newMountsTestController(t) + spec := &specs.Spec{ + Mounts: []specs.Mount{ + { + Source: `C:\nonexistent\path\does\not\exist`, + Destination: "/mnt/missing", + Type: mountTypeBind, + }, + }, + } + + err := c.allocateMounts(t.Context(), spec) + if err == nil { + t.Fatal("expected error for stat of nonexistent path") + } +} + +// --- allocateMounts: skip / passthrough tests --- + +// TestAllocateMounts_HugePagesSkipped verifies that hugepages mounts are +// validated but skipped (no Plan9 or SCSI calls). +func TestAllocateMounts_HugePagesSkipped(t *testing.T) { + t.Parallel() + + c, _, _ := newMountsTestController(t) + spec := &specs.Spec{ + Mounts: []specs.Mount{ + { + Source: guestpath.HugePagesMountPrefix + "2M/hugepages0", + Destination: "/dev/hugepages", + Type: mountTypeBind, + }, + }, + } + + // No SCSI or Plan9 calls expected. + if err := c.allocateMounts(t.Context(), spec); err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +// TestAllocateMounts_HugePagesInvalidSize verifies that only 2M hugepages are +// accepted. +func TestAllocateMounts_HugePagesInvalidSize(t *testing.T) { + t.Parallel() + + c, _, _ := newMountsTestController(t) + spec := &specs.Spec{ + Mounts: []specs.Mount{ + { + Source: guestpath.HugePagesMountPrefix + "1G/hugepages0", + Destination: "/dev/hugepages", + Type: mountTypeBind, + }, + }, + } + + err := c.allocateMounts(t.Context(), spec) + if err == nil { + t.Fatal("expected error for unsupported hugepage size") + } +} + +// TestAllocateMounts_HugePagesInvalidFormat verifies that an improperly +// formatted hugepages path is rejected. 
+func TestAllocateMounts_HugePagesInvalidFormat(t *testing.T) { + t.Parallel() + + c, _, _ := newMountsTestController(t) + spec := &specs.Spec{ + Mounts: []specs.Mount{ + { + Source: guestpath.HugePagesMountPrefix + "bad", + Destination: "/dev/hugepages", + Type: mountTypeBind, + }, + }, + } + + err := c.allocateMounts(t.Context(), spec) + if err == nil { + t.Fatal("expected error for invalid hugepages format") + } +} + +// TestAllocateMounts_GuestInternalPathsSkipped verifies that sandbox://, sandbox-tmp://, +// and uvm:// prefixed mounts pass through without host-side allocation. +func TestAllocateMounts_GuestInternalPathsSkipped(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + source string + }{ + {name: "sandbox", source: guestpath.SandboxMountPrefix + "/some/path"}, + {name: "sandbox-tmp", source: guestpath.SandboxTmpfsMountPrefix + "/tmp/path"}, + {name: "uvm", source: guestpath.UVMMountPrefix + "/uvm/path"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + c, _, _ := newMountsTestController(t) + spec := &specs.Spec{ + Mounts: []specs.Mount{ + { + Source: tt.source, + Destination: "/mnt/guest", + Type: mountTypeBind, + }, + }, + } + + // No SCSI or Plan9 calls expected. + if err := c.allocateMounts(t.Context(), spec); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Source should remain unchanged. + if spec.Mounts[0].Source != tt.source { + t.Errorf("source = %q, want %q", spec.Mounts[0].Source, tt.source) + } + }) + } +} + +// TestAllocateMounts_UnknownTypesPassThrough verifies that unknown mount types +// (e.g. tmpfs, proc) are passed through without error or host-side allocation. 
+func TestAllocateMounts_UnknownTypesPassThrough(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + mountType string + }{ + {name: "tmpfs", mountType: "tmpfs"}, + {name: "proc", mountType: "proc"}, + {name: "devpts", mountType: "devpts"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + c, _, _ := newMountsTestController(t) + spec := &specs.Spec{ + Mounts: []specs.Mount{ + { + Source: "none", + Destination: "/mnt/" + tt.mountType, + Type: tt.mountType, + }, + }, + } + + if err := c.allocateMounts(t.Context(), spec); err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +// TestAllocateMounts_NoMounts verifies that allocateMounts succeeds when the +// spec contains no mounts. +func TestAllocateMounts_NoMounts(t *testing.T) { + t.Parallel() + + c, _, _ := newMountsTestController(t) + spec := &specs.Spec{} + + if err := c.allocateMounts(t.Context(), spec); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(c.scsiResources) != 0 { + t.Errorf("expected 0 scsi resources, got %d", len(c.scsiResources)) + } + if len(c.plan9Resources) != 0 { + t.Errorf("expected 0 plan9 resources, got %d", len(c.plan9Resources)) + } +} + +// TestAllocateMounts_EmptySourceOrDestination verifies that a mount with an +// empty source or destination returns an error. 
+func TestAllocateMounts_EmptySourceOrDestination(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + src string + dst string + }{ + {name: "empty-source", src: "", dst: "/mnt/data"}, + {name: "empty-destination", src: `C:\disks\data.vhdx`, dst: ""}, + {name: "both-empty", src: "", dst: ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + c, _, _ := newMountsTestController(t) + spec := &specs.Spec{ + Mounts: []specs.Mount{ + { + Source: tt.src, + Destination: tt.dst, + Type: mountTypeVirtualDisk, + }, + }, + } + + if err := c.allocateMounts(t.Context(), spec); err == nil { + t.Fatal("expected error for empty source or destination") + } + }) + } +} + +// TestAllocateMounts_MultipleMixed verifies that allocateMounts correctly +// handles a mix of SCSI, Plan9, and passthrough mounts in a single spec. +func TestAllocateMounts_MultipleMixed(t *testing.T) { + t.Parallel() + stubResolvePath(t) + stubGrantVMAccess(t) + + c, scsiCtrl, plan9Ctrl := newMountsTestController(t) + hostDir := t.TempDir() + + spec := &specs.Spec{ + Mounts: []specs.Mount{ + { + Source: `C:\disks\data.vhdx`, + Destination: "/mnt/data", + Type: mountTypeVirtualDisk, + }, + { + Source: hostDir, + Destination: "/mnt/hostdir", + Type: mountTypeBind, + }, + { + Source: "none", + Destination: "/proc", + Type: "proc", + }, + { + Source: guestpath.SandboxMountPrefix + "/sandbox/dir", + Destination: "/mnt/sandbox", + Type: mountTypeBind, + }, + }, + } + + scsiID, _ := guid.NewV4() + plan9ID, _ := guid.NewV4() + + // SCSI mount for virtual disk. + scsiCtrl.EXPECT(). + Reserve(gomock.Any(), gomock.Any(), gomock.Any()). + Return(scsiID, nil) + scsiCtrl.EXPECT(). + MapToGuest(gomock.Any(), scsiID). + Return("/dev/sda", nil) + + // Plan9 mount for bind directory. + plan9Ctrl.EXPECT(). + Reserve(gomock.Any(), gomock.Any(), gomock.Any()). + Return(plan9ID, nil) + plan9Ctrl.EXPECT(). + MapToGuest(gomock.Any(), plan9ID). 
+ Return("/mnt/plan9/share0", nil) + + if err := c.allocateMounts(t.Context(), spec); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(c.scsiResources) != 1 { + t.Errorf("expected 1 scsi resource, got %d", len(c.scsiResources)) + } + if len(c.plan9Resources) != 1 { + t.Errorf("expected 1 plan9 resource, got %d", len(c.plan9Resources)) + } + // SCSI mount rewritten. + if spec.Mounts[0].Source != "/dev/sda" { + t.Errorf("SCSI mount source = %q, want %q", spec.Mounts[0].Source, "/dev/sda") + } + // Plan9 mount rewritten. + if spec.Mounts[1].Source != "/mnt/plan9/share0" { + t.Errorf("Plan9 mount source = %q, want %q", spec.Mounts[1].Source, "/mnt/plan9/share0") + } + // Proc mount unchanged. + if spec.Mounts[2].Source != "none" { + t.Errorf("proc mount source = %q, want %q", spec.Mounts[2].Source, "none") + } + // Sandbox mount unchanged. + if spec.Mounts[3].Source != guestpath.SandboxMountPrefix+"/sandbox/dir" { + t.Errorf("sandbox mount source changed unexpectedly") + } +} + +// --- Helper function tests --- + +// TestParseExtensibleVirtualDiskPath verifies parsing of EVD paths. 
+func TestParseExtensibleVirtualDiskPath(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + wantType string + wantPath string + wantErr bool + }{ + { + name: "valid", + input: `evd://mytype/C:\disks\data.vhdx`, + wantType: "mytype", + wantPath: `C:\disks\data.vhdx`, + }, + { + name: "valid-nested", + input: "evd://provider/some/nested/path", + wantType: "provider", + wantPath: "some/nested/path", + }, + { + name: "missing-prefix", + input: "/no/evd/prefix", + wantErr: true, + }, + { + name: "no-type-separator", + input: "evd://", + wantErr: true, + }, + { + name: "empty-type", + input: "evd:///path", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + evdType, sourcePath, err := parseExtensibleVirtualDiskPath(tt.input) + + if tt.wantErr { + if err == nil { + t.Fatal("expected error") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if evdType != tt.wantType { + t.Errorf("evdType = %q, want %q", evdType, tt.wantType) + } + if sourcePath != tt.wantPath { + t.Errorf("sourcePath = %q, want %q", sourcePath, tt.wantPath) + } + }) + } +} + +// TestValidateHugePageMount verifies the hugepages mount source validation. 
+func TestValidateHugePageMount(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + source string + wantErr bool + }{ + { + name: "valid-2M", + source: guestpath.HugePagesMountPrefix + "2M/hugepages0", + }, + { + name: "unsupported-1G", + source: guestpath.HugePagesMountPrefix + "1G/hugepages0", + wantErr: true, + }, + { + name: "missing-location", + source: guestpath.HugePagesMountPrefix + "2M", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := validateHugePageMount(tt.source) + if tt.wantErr && err == nil { + t.Fatal("expected error") + } + if !tt.wantErr && err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +// TestIsReadOnlyMount verifies read-only flag detection from mount options. +func TestIsReadOnlyMount(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + options []string + want bool + }{ + {name: "empty-options", options: nil, want: false}, + {name: "rw-only", options: []string{"rw"}, want: false}, + {name: "ro", options: []string{"ro"}, want: true}, + {name: "RO-case-insensitive", options: []string{"RO"}, want: true}, + {name: "ro-among-others", options: []string{"noatime", "ro", "nosuid"}, want: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + mount := &specs.Mount{Options: tt.options} + if got := isReadOnlyMount(mount); got != tt.want { + t.Errorf("isReadOnlyMount = %v, want %v", got, tt.want) + } + }) + } +} + +// TestIsGuestInternalPath verifies detection of guest-internal path prefixes. 
+func TestIsGuestInternalPath(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + path string + want bool + }{ + {name: "sandbox", path: guestpath.SandboxMountPrefix + "/some/path", want: true}, + {name: "sandbox-tmp", path: guestpath.SandboxTmpfsMountPrefix + "/tmp/path", want: true}, + {name: "uvm", path: guestpath.UVMMountPrefix + "/uvm/path", want: true}, + {name: "host-path", path: `C:\some\host\path`, want: false}, + {name: "linux-path", path: "/some/linux/path", want: false}, + {name: "hugepages", path: guestpath.HugePagesMountPrefix + "2M/hp", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := isGuestInternalPath(tt.path); got != tt.want { + t.Errorf("isGuestInternalPath(%q) = %v, want %v", tt.path, got, tt.want) + } + }) + } +} + +// TestAllocateMounts_EVDInvalidPath verifies that an invalid EVD source path +// returns an error before any SCSI reservation. +func TestAllocateMounts_EVDInvalidPath(t *testing.T) { + t.Parallel() + + c, _, _ := newMountsTestController(t) + spec := &specs.Spec{ + Mounts: []specs.Mount{ + { + Source: "not-evd://missing", + Destination: "/mnt/bad", + Type: mountTypeExtensibleVirtualDisk, + }, + }, + } + + err := c.allocateMounts(t.Context(), spec) + if err == nil { + t.Fatal("expected error for invalid EVD path") + } +} diff --git a/internal/controller/linuxcontainer/state.go b/internal/controller/linuxcontainer/state.go new file mode 100644 index 0000000000..6f6ed963eb --- /dev/null +++ b/internal/controller/linuxcontainer/state.go @@ -0,0 +1,59 @@ +//go:build windows && lcow + +package linuxcontainer + +// State represents the current lifecycle state of the container. 
//
// Normal progression:
//
//	StateNotCreated → StateCreated → StateRunning → StateStopped
//
// Full state-transition table:
//
//	Current State    │ Trigger                                          │ Next State
//	─────────────────┼──────────────────────────────────────────────────┼────────────────
//	StateNotCreated  │ Create succeeds                                  │ StateCreated
//	StateNotCreated  │ Create fails during resource allocation or later │ StateInvalid
//	StateCreated     │ Start succeeds                                   │ StateRunning
//	StateCreated     │ Start fails                                      │ StateInvalid
//	StateRunning     │ init process exits                               │ StateStopped
//	StateStopped     │ (terminal — no further transitions)              │ —
//	StateInvalid     │ (terminal — no further transitions)              │ —
//
// The underlying type is int32 and the values are assigned by iota in
// declaration order, so reordering the constants below changes their
// numeric values.
type State int32

const (
	// StateNotCreated indicates the container has not been created yet.
	// It is the first iota value and therefore the zero value of State.
	StateNotCreated State = iota

	// StateCreated indicates the container has been created but not started.
	StateCreated

	// StateRunning indicates the container has been started and is running.
	StateRunning

	// StateStopped indicates the container's init process has exited and
	// all host-side resources have been released.
	StateStopped

	// StateInvalid indicates the container entered an unrecoverable failure
	// during Create or Start.
	StateInvalid
)

// String returns a human-readable representation of the container State.
+func (s State) String() string { + switch s { + case StateNotCreated: + return "NotCreated" + case StateCreated: + return "Created" + case StateRunning: + return "Running" + case StateStopped: + return "Stopped" + case StateInvalid: + return "Invalid" + default: + return "Unknown" + } +} diff --git a/internal/controller/linuxcontainer/types.go b/internal/controller/linuxcontainer/types.go new file mode 100644 index 0000000000..8b320efaa1 --- /dev/null +++ b/internal/controller/linuxcontainer/types.go @@ -0,0 +1,53 @@ +//go:build windows && lcow + +package linuxcontainer + +import ( + "context" + + plan9Mount "github.com/Microsoft/hcsshim/internal/controller/device/plan9/mount" + "github.com/Microsoft/hcsshim/internal/controller/device/plan9/share" + "github.com/Microsoft/hcsshim/internal/controller/device/scsi/disk" + scsiMount "github.com/Microsoft/hcsshim/internal/controller/device/scsi/mount" + "github.com/Microsoft/hcsshim/internal/controller/device/vpci" + "github.com/Microsoft/hcsshim/internal/gcs" + "github.com/Microsoft/hcsshim/internal/protocol/guestresource" + + "github.com/Microsoft/go-winio/pkg/guid" +) + +// CreateContainerOpts holds additional options for container creation. +type CreateContainerOpts struct { + IsScratchEncryptionEnabled bool +} + +// guest abstracts the UVM guest connection for container lifecycle operations. +type guest interface { + Capabilities() gcs.GuestDefinedCapabilities + CreateContainer(ctx context.Context, cid string, config interface{}) (*gcs.Container, error) + DeleteContainerState(ctx context.Context, cid string) error + + AddLCOWCombinedLayers(ctx context.Context, settings guestresource.LCOWCombinedLayers) error + RemoveLCOWCombinedLayers(ctx context.Context, settings guestresource.LCOWCombinedLayers) error +} + +// scsiController abstracts host-side SCSI disk reservation and guest mapping. 
type scsiController interface {
	// Reserve takes the disk and mount configuration and returns a GUID
	// that identifies the reservation in the other methods.
	Reserve(ctx context.Context, diskConfig disk.Config, mountConfig scsiMount.Config) (guid.GUID, error)
	// UnmapFromGuest takes a reservation GUID.
	// NOTE(review): presumably the inverse of MapToGuest; no caller is
	// visible in this file — confirm against the implementation.
	UnmapFromGuest(ctx context.Context, reservation guid.GUID) error
	// MapToGuest takes a reservation GUID and returns a string; the
	// allocateMounts tests expect that string to become the mount's Source.
	MapToGuest(ctx context.Context, id guid.GUID) (string, error)
}

// plan9Controller abstracts host-side Plan9 share reservation and guest mapping.
type plan9Controller interface {
	// Reserve takes the share and mount configuration and returns a GUID
	// that identifies the reservation in the other methods.
	Reserve(ctx context.Context, shareConfig share.Config, mountConfig plan9Mount.Config) (guid.GUID, error)
	// UnmapFromGuest takes a reservation GUID.
	// NOTE(review): presumably the inverse of MapToGuest; no caller is
	// visible in this file — confirm against the implementation.
	UnmapFromGuest(ctx context.Context, reservation guid.GUID) error
	// MapToGuest takes a reservation GUID and returns a string; the
	// allocateMounts tests expect that string to become the mount's Source.
	MapToGuest(ctx context.Context, id guid.GUID) (string, error)
}

// vPCIController abstracts host-side virtual PCI device reservation and VM assignment.
type vPCIController interface {
	// Reserve takes a device description and returns a GUID.
	// NOTE(review): presumably the same GUID that AddToVM and RemoveFromVM
	// accept as vmBusGUID — confirm against the implementation.
	Reserve(ctx context.Context, device vpci.Device) (guid.GUID, error)
	RemoveFromVM(ctx context.Context, vmBusGUID guid.GUID) error
	AddToVM(ctx context.Context, vmBusGUID guid.GUID) error
}
diff --git a/internal/controller/pod/doc.go b/internal/controller/pod/doc.go
new file mode 100644
index 0000000000..770f2d6d25
--- /dev/null
+++ b/internal/controller/pod/doc.go
@@ -0,0 +1,13 @@
//go:build windows && (lcow || wcow)

// Package pod provides a controller for managing a single pod
// running inside a Utility VM (UVM). It owns the network controller and
// tracks all container controllers belonging to the pod.
//
// # Responsibilities
//
//   - Setting up and tearing down the pod-level network namespace via
//     the [network.Controller].
//   - Creating, retrieving, listing, and deleting container controllers
//     within the pod.
package pod
diff --git a/internal/controller/pod/mocks/mock_types.go b/internal/controller/pod/mocks/mock_types.go
new file mode 100644
index 0000000000..870add42d4
--- /dev/null
+++ b/internal/controller/pod/mocks/mock_types.go
@@ -0,0 +1,184 @@
//go:build windows && lcow

// Code generated by MockGen. DO NOT EDIT.
+// Source: types_lcow.go +// +// Generated by this command: +// +// mockgen -build_flags=-tags=lcow -build_constraint=windows && lcow -source types_lcow.go -package mocks -destination mocks/mock_types.go +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + plan9 "github.com/Microsoft/hcsshim/internal/controller/device/plan9" + scsi "github.com/Microsoft/hcsshim/internal/controller/device/scsi" + vpci "github.com/Microsoft/hcsshim/internal/controller/device/vpci" + network "github.com/Microsoft/hcsshim/internal/controller/network" + guestmanager "github.com/Microsoft/hcsshim/internal/vm/guestmanager" + gomock "go.uber.org/mock/gomock" +) + +// MockvmController is a mock of vmController interface. +type MockvmController struct { + ctrl *gomock.Controller + recorder *MockvmControllerMockRecorder + isgomock struct{} +} + +// MockvmControllerMockRecorder is the mock recorder for MockvmController. +type MockvmControllerMockRecorder struct { + mock *MockvmController +} + +// NewMockvmController creates a new mock instance. +func NewMockvmController(ctrl *gomock.Controller) *MockvmController { + mock := &MockvmController{ctrl: ctrl} + mock.recorder = &MockvmControllerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockvmController) EXPECT() *MockvmControllerMockRecorder { + return m.recorder +} + +// Guest mocks base method. +func (m *MockvmController) Guest() *guestmanager.Guest { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Guest") + ret0, _ := ret[0].(*guestmanager.Guest) + return ret0 +} + +// Guest indicates an expected call of Guest. +func (mr *MockvmControllerMockRecorder) Guest() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Guest", reflect.TypeOf((*MockvmController)(nil).Guest)) +} + +// NetworkController mocks base method. 
+func (m *MockvmController) NetworkController() *network.Controller { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NetworkController") + ret0, _ := ret[0].(*network.Controller) + return ret0 +} + +// NetworkController indicates an expected call of NetworkController. +func (mr *MockvmControllerMockRecorder) NetworkController() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NetworkController", reflect.TypeOf((*MockvmController)(nil).NetworkController)) +} + +// Plan9Controller mocks base method. +func (m *MockvmController) Plan9Controller() *plan9.Controller { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Plan9Controller") + ret0, _ := ret[0].(*plan9.Controller) + return ret0 +} + +// Plan9Controller indicates an expected call of Plan9Controller. +func (mr *MockvmControllerMockRecorder) Plan9Controller() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Plan9Controller", reflect.TypeOf((*MockvmController)(nil).Plan9Controller)) +} + +// RuntimeID mocks base method. +func (m *MockvmController) RuntimeID() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RuntimeID") + ret0, _ := ret[0].(string) + return ret0 +} + +// RuntimeID indicates an expected call of RuntimeID. +func (mr *MockvmControllerMockRecorder) RuntimeID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RuntimeID", reflect.TypeOf((*MockvmController)(nil).RuntimeID)) +} + +// SCSIController mocks base method. +func (m *MockvmController) SCSIController() *scsi.Controller { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SCSIController") + ret0, _ := ret[0].(*scsi.Controller) + return ret0 +} + +// SCSIController indicates an expected call of SCSIController. 
+func (mr *MockvmControllerMockRecorder) SCSIController() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SCSIController", reflect.TypeOf((*MockvmController)(nil).SCSIController)) +} + +// VPCIController mocks base method. +func (m *MockvmController) VPCIController() *vpci.Controller { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "VPCIController") + ret0, _ := ret[0].(*vpci.Controller) + return ret0 +} + +// VPCIController indicates an expected call of VPCIController. +func (mr *MockvmControllerMockRecorder) VPCIController() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VPCIController", reflect.TypeOf((*MockvmController)(nil).VPCIController)) +} + +// MocknetworkController is a mock of networkController interface. +type MocknetworkController struct { + ctrl *gomock.Controller + recorder *MocknetworkControllerMockRecorder + isgomock struct{} +} + +// MocknetworkControllerMockRecorder is the mock recorder for MocknetworkController. +type MocknetworkControllerMockRecorder struct { + mock *MocknetworkController +} + +// NewMocknetworkController creates a new mock instance. +func NewMocknetworkController(ctrl *gomock.Controller) *MocknetworkController { + mock := &MocknetworkController{ctrl: ctrl} + mock.recorder = &MocknetworkControllerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MocknetworkController) EXPECT() *MocknetworkControllerMockRecorder { + return m.recorder +} + +// Setup mocks base method. +func (m *MocknetworkController) Setup(ctx context.Context, opts *network.SetupOptions) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Setup", ctx, opts) + ret0, _ := ret[0].(error) + return ret0 +} + +// Setup indicates an expected call of Setup. 
+func (mr *MocknetworkControllerMockRecorder) Setup(ctx, opts any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Setup", reflect.TypeOf((*MocknetworkController)(nil).Setup), ctx, opts) +} + +// Teardown mocks base method. +func (m *MocknetworkController) Teardown(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Teardown", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Teardown indicates an expected call of Teardown. +func (mr *MocknetworkControllerMockRecorder) Teardown(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Teardown", reflect.TypeOf((*MocknetworkController)(nil).Teardown), ctx) +} diff --git a/internal/controller/pod/pod_lcow.go b/internal/controller/pod/pod_lcow.go new file mode 100644 index 0000000000..cb02502890 --- /dev/null +++ b/internal/controller/pod/pod_lcow.go @@ -0,0 +1,130 @@ +//go:build windows && lcow + +package pod + +import ( + "context" + "fmt" + "sync" + + "github.com/Microsoft/hcsshim/internal/controller/linuxcontainer" + "github.com/Microsoft/hcsshim/internal/controller/network" +) + +// Controller manages the lifecycle of a single pod inside a Utility VM. +type Controller struct { + mu sync.RWMutex + + // podID is the containerd facing pod identifier. + podID string + + // gcsPodID is the identifier used when communicating with the GCS. + gcsPodID string + + // vm is the parent Utility VM that hosts this pod. + vm vmController + + // network manages the network namespace and endpoint lifecycle + // for this pod. + network networkController + + // containers maps containerID → [linuxcontainer.Controller] for every + // live container in this pod. Access must be guarded by mu. + containers map[string]*linuxcontainer.Controller +} + +// New creates a ready-to-use [Controller] for the given pod. 
+func New( + podID string, + vm vmController, +) *Controller { + return &Controller{ + podID: podID, + // Same ID is used as the pod. Post migration, we can always change + // the primary ID while GCS continues to use the original one. + gcsPodID: podID, + vm: vm, + network: vm.NetworkController(), + containers: make(map[string]*linuxcontainer.Controller), + } +} + +// SetupNetwork performs network setup for the pod. +func (c *Controller) SetupNetwork(ctx context.Context, opts *network.SetupOptions) error { + if err := c.network.Setup(ctx, opts); err != nil { + return fmt.Errorf("setup network for pod %s: %w", c.podID, err) + } + return nil +} + +// TeardownNetwork performs network teardown for the pod. +func (c *Controller) TeardownNetwork(ctx context.Context) error { + if err := c.network.Teardown(ctx); err != nil { + return fmt.Errorf("teardown network for pod %s: %w", c.podID, err) + } + return nil +} + +// GetContainer returns the container controller for the given containerID. +func (c *Controller) GetContainer(containerID string) (*linuxcontainer.Controller, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + containerCtrl, ok := c.containers[containerID] + if !ok { + return nil, fmt.Errorf("container %q not found in pod %q", containerID, c.podID) + } + + return containerCtrl, nil +} + +// NewContainer creates a new [linuxcontainer.Controller] and registers it +// in this pod. +func (c *Controller) NewContainer(ctx context.Context, containerID string) (*linuxcontainer.Controller, error) { + c.mu.Lock() + defer c.mu.Unlock() + + // Ensure we don't create a duplicate container controller. 
+ if _, ok := c.containers[containerID]; ok { + return nil, fmt.Errorf("container %q already exists in pod %q", containerID, c.podID) + } + + containerCtrl := linuxcontainer.New( + c.vm.RuntimeID(), + c.gcsPodID, + containerID, + c.vm.Guest(), + c.vm.SCSIController(), + c.vm.Plan9Controller(), + c.vm.VPCIController(), + ) + c.containers[containerID] = containerCtrl + return containerCtrl, nil +} + +// ListContainers returns a snapshot of all live container controllers in +// this pod, keyed by container ID. +func (c *Controller) ListContainers() map[string]*linuxcontainer.Controller { + c.mu.RLock() + defer c.mu.RUnlock() + + result := make(map[string]*linuxcontainer.Controller, len(c.containers)) + for containerID, containerCtrl := range c.containers { + result[containerID] = containerCtrl + } + + return result +} + +// DeleteContainer removes a container from the pod's container map. +func (c *Controller) DeleteContainer(ctx context.Context, containerID string) error { + c.mu.Lock() + defer c.mu.Unlock() + + if _, ok := c.containers[containerID]; !ok { + return fmt.Errorf("container %q not found in pod %q", containerID, c.podID) + } + + delete(c.containers, containerID) + return nil +} diff --git a/internal/controller/pod/pod_lcow_test.go b/internal/controller/pod/pod_lcow_test.go new file mode 100644 index 0000000000..5e84715542 --- /dev/null +++ b/internal/controller/pod/pod_lcow_test.go @@ -0,0 +1,322 @@ +//go:build windows && lcow + +package pod + +import ( + "errors" + "testing" + + "go.uber.org/mock/gomock" + + "github.com/Microsoft/hcsshim/internal/controller/linuxcontainer" + "github.com/Microsoft/hcsshim/internal/controller/network" + "github.com/Microsoft/hcsshim/internal/controller/pod/mocks" +) + +const testPodID = "test-pod-1234" + +// errTest is a sentinel used in table-driven tests to verify error propagation. 
+var errTest = errors.New("test error") + +// newSetup creates a gomock controller, vm mock, network mock, and a pod +// Controller wired together. Mock expectations are verified automatically via +// t.Cleanup. +func newSetup(t *testing.T) (*mocks.MockvmController, *mocks.MocknetworkController, *Controller) { + t.Helper() + mc := gomock.NewController(t) + vm := mocks.NewMockvmController(mc) + net := mocks.NewMocknetworkController(mc) + return vm, net, &Controller{ + podID: testPodID, + gcsPodID: testPodID, + vm: vm, + network: net, + containers: make(map[string]*linuxcontainer.Controller), + } +} + +// expectVMCallsForNewContainer sets up the mock expectations on vmController +// that NewContainer triggers when constructing a linuxcontainer.Controller. +func expectVMCallsForNewContainer(vm *mocks.MockvmController) { + vm.EXPECT().RuntimeID().Return("vm-runtime-1") + vm.EXPECT().Guest().Return(nil) + vm.EXPECT().SCSIController().Return(nil) + vm.EXPECT().Plan9Controller().Return(nil) + vm.EXPECT().VPCIController().Return(nil) +} + +// TestNew_InitializesFields verifies that New sets up all fields correctly and +// that the containers map is non-nil and empty. +func TestNew_InitializesFields(t *testing.T) { + mc := gomock.NewController(t) + vm := mocks.NewMockvmController(mc) + vm.EXPECT().NetworkController().Return(nil) + + c := New(testPodID, vm) + + if c.podID != testPodID { + t.Errorf("expected podID=%q, got %q", testPodID, c.podID) + } + if c.gcsPodID != testPodID { + t.Errorf("expected gcsPodID=%q, got %q", testPodID, c.gcsPodID) + } + if c.containers == nil { + t.Fatal("expected non-nil containers map") + } + if len(c.containers) != 0 { + t.Errorf("expected empty containers map, got %d entries", len(c.containers)) + } +} + +// TestSetupNetwork verifies that SetupNetwork delegates to the network controller. 
+func TestSetupNetwork(t *testing.T) { + opts := &network.SetupOptions{NetworkNamespace: "ns-1234"} + + tests := []struct { + name string + retErr error + }{ + {"happy path", nil}, + {"fails", errTest}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, net, c := newSetup(t) + net.EXPECT().Setup(gomock.Any(), opts).Return(tt.retErr) + + err := c.SetupNetwork(t.Context(), opts) + if !errors.Is(err, tt.retErr) { + t.Errorf("SetupNetwork() error = %v, wantErr %v", err, tt.retErr) + } + }) + } +} + +// TestTeardownNetwork verifies that TeardownNetwork delegates to the network controller. +func TestTeardownNetwork(t *testing.T) { + tests := []struct { + name string + retErr error + }{ + {"happy path", nil}, + {"fails", errTest}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, net, c := newSetup(t) + net.EXPECT().Teardown(gomock.Any()).Return(tt.retErr) + + err := c.TeardownNetwork(t.Context()) + if !errors.Is(err, tt.retErr) { + t.Errorf("TeardownNetwork() error = %v, wantErr %v", err, tt.retErr) + } + }) + } +} + +// TestGetContainer verifies retrieval of registered and unknown containers. +func TestGetContainer(t *testing.T) { + t.Run("exists", func(t *testing.T) { + vm, _, c := newSetup(t) + expectVMCallsForNewContainer(vm) + + created, err := c.NewContainer(t.Context(), "container-1") + if err != nil { + t.Fatalf("NewContainer: %v", err) + } + + got, err := c.GetContainer("container-1") + if err != nil { + t.Fatalf("GetContainer: %v", err) + } + if got != created { + t.Error("GetContainer returned a different controller than NewContainer") + } + }) + + t.Run("not found", func(t *testing.T) { + _, _, c := newSetup(t) + if _, err := c.GetContainer("nonexistent"); err == nil { + t.Fatal("expected error for unknown container ID") + } + }) +} + +// TestNewContainer verifies creating containers with new, duplicate, and multiple IDs. 
+func TestNewContainer(t *testing.T) { + t.Run("happy path", func(t *testing.T) { + vm, _, c := newSetup(t) + expectVMCallsForNewContainer(vm) + + containerCtrl, err := c.NewContainer(t.Context(), "container-1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if containerCtrl == nil { + t.Fatal("expected non-nil container controller") + } + if _, ok := c.containers["container-1"]; !ok { + t.Error("container not in map after NewContainer") + } + }) + + t.Run("duplicate", func(t *testing.T) { + vm, _, c := newSetup(t) + expectVMCallsForNewContainer(vm) + + if _, err := c.NewContainer(t.Context(), "container-dup"); err != nil { + t.Fatalf("first NewContainer: %v", err) + } + if _, err := c.NewContainer(t.Context(), "container-dup"); err == nil { + t.Fatal("expected error for duplicate container ID") + } + }) + + t.Run("multiple different", func(t *testing.T) { + vm, _, c := newSetup(t) + ids := []string{"container-a", "container-b", "container-c"} + for _, id := range ids { + expectVMCallsForNewContainer(vm) + if _, err := c.NewContainer(t.Context(), id); err != nil { + t.Fatalf("NewContainer(%q): %v", id, err) + } + } + + if len(c.containers) != len(ids) { + t.Errorf("expected %d containers, got %d", len(ids), len(c.containers)) + } + for _, id := range ids { + if _, ok := c.containers[id]; !ok { + t.Errorf("container %q missing from map", id) + } + } + }) +} + +// TestListContainers verifies snapshots of the live container map. 
+func TestListContainers(t *testing.T) { + t.Run("empty", func(t *testing.T) { + _, _, c := newSetup(t) + if result := c.ListContainers(); len(result) != 0 { + t.Errorf("expected empty map, got %d entries", len(result)) + } + }) + + t.Run("multiple containers", func(t *testing.T) { + vm, _, c := newSetup(t) + ids := []string{"container-x", "container-y"} + for _, id := range ids { + expectVMCallsForNewContainer(vm) + if _, err := c.NewContainer(t.Context(), id); err != nil { + t.Fatalf("NewContainer(%q): %v", id, err) + } + } + + result := c.ListContainers() + if len(result) != len(ids) { + t.Errorf("expected %d containers, got %d", len(ids), len(result)) + } + for _, id := range ids { + if _, ok := result[id]; !ok { + t.Errorf("container %q missing from ListContainers result", id) + } + } + + // Verify that the returned map is a snapshot: mutating it must not + // affect the internal state. + delete(result, ids[0]) + if _, ok := c.containers[ids[0]]; !ok { + t.Error("deleting from ListContainers result should not affect internal map") + } + }) +} + +// TestDeleteContainer verifies the full create → delete lifecycle and error cases. 
+func TestDeleteContainer(t *testing.T) { + t.Run("happy path", func(t *testing.T) { + vm, _, c := newSetup(t) + expectVMCallsForNewContainer(vm) + + if _, err := c.NewContainer(t.Context(), "container-del"); err != nil { + t.Fatalf("NewContainer: %v", err) + } + if err := c.DeleteContainer(t.Context(), "container-del"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := c.containers["container-del"]; ok { + t.Error("container still in map after DeleteContainer") + } + }) + + t.Run("not found", func(t *testing.T) { + _, _, c := newSetup(t) + if err := c.DeleteContainer(t.Context(), "nonexistent"); err == nil { + t.Fatal("expected error for unknown container ID") + } + }) + + t.Run("then get fails", func(t *testing.T) { + vm, _, c := newSetup(t) + expectVMCallsForNewContainer(vm) + + if _, err := c.NewContainer(t.Context(), "container-gone"); err != nil { + t.Fatalf("NewContainer: %v", err) + } + if err := c.DeleteContainer(t.Context(), "container-gone"); err != nil { + t.Fatalf("DeleteContainer: %v", err) + } + if _, err := c.GetContainer("container-gone"); err == nil { + t.Fatal("expected error from GetContainer after deletion") + } + }) + + t.Run("double fails", func(t *testing.T) { + vm, _, c := newSetup(t) + expectVMCallsForNewContainer(vm) + + if _, err := c.NewContainer(t.Context(), "container-double-del"); err != nil { + t.Fatalf("NewContainer: %v", err) + } + if err := c.DeleteContainer(t.Context(), "container-double-del"); err != nil { + t.Fatalf("first DeleteContainer: %v", err) + } + if err := c.DeleteContainer(t.Context(), "container-double-del"); err == nil { + t.Fatal("expected error on second DeleteContainer") + } + }) +} + +// TestNewContainer_AfterDelete verifies that a container can be re-created with +// the same ID after the original has been deleted. +func TestNewContainer_AfterDelete(t *testing.T) { + vm, _, c := newSetup(t) + ctx := t.Context() + + // First lifecycle. 
+ expectVMCallsForNewContainer(vm) + first, err := c.NewContainer(ctx, "container-recreate") + if err != nil { + t.Fatalf("first NewContainer: %v", err) + } + if err := c.DeleteContainer(ctx, "container-recreate"); err != nil { + t.Fatalf("DeleteContainer: %v", err) + } + + // Re-create with the same ID. + expectVMCallsForNewContainer(vm) + second, err := c.NewContainer(ctx, "container-recreate") + if err != nil { + t.Fatalf("second NewContainer: %v", err) + } + if first == second { + t.Error("expected a new controller instance after re-creation") + } + + got, err := c.GetContainer("container-recreate") + if err != nil { + t.Fatalf("GetContainer after re-creation: %v", err) + } + if got != second { + t.Error("GetContainer returned the old controller after re-creation") + } +} diff --git a/internal/controller/pod/types_lcow.go b/internal/controller/pod/types_lcow.go new file mode 100644 index 0000000000..5c63bd1e47 --- /dev/null +++ b/internal/controller/pod/types_lcow.go @@ -0,0 +1,46 @@ +//go:build windows && lcow + +package pod + +import ( + "context" + + "github.com/Microsoft/hcsshim/internal/controller/device/plan9" + "github.com/Microsoft/hcsshim/internal/controller/device/scsi" + "github.com/Microsoft/hcsshim/internal/controller/device/vpci" + "github.com/Microsoft/hcsshim/internal/controller/network" + "github.com/Microsoft/hcsshim/internal/vm/guestmanager" +) + +// vmController exposes the subset of the VM manager that the pod controller +// needs: identity, guest access, device controllers, and the network controller. +// Implemented by the VM controller (e.g., vmmanager.UtilityVM). +type vmController interface { + // RuntimeID returns the unique runtime identifier for the VM. + RuntimeID() string + + // Guest returns the guest manager used for guest-side operations. + Guest() *guestmanager.Guest + + // SCSIController returns the SCSI device controller for the VM. 
+ SCSIController() *scsi.Controller + + // VPCIController returns the vPCI device controller for the VM. + VPCIController() *vpci.Controller + + // Plan9Controller returns the Plan9 share controller for the VM. + Plan9Controller() *plan9.Controller + + // NetworkController returns the network controller for the VM. + NetworkController() *network.Controller +} + +// networkController is the narrow interface used by the pod to set up and +// tear down the network namespace. Implemented by [network.Controller]. +type networkController interface { + // Setup performs network setup for the pod. + Setup(ctx context.Context, opts *network.SetupOptions) error + + // Teardown performs network teardown for the pod. + Teardown(ctx context.Context) error +} diff --git a/internal/controller/process/doc.go b/internal/controller/process/doc.go new file mode 100644 index 0000000000..6ac9c8816e --- /dev/null +++ b/internal/controller/process/doc.go @@ -0,0 +1,53 @@ +//go:build windows && (lcow || wcow) + +// Package process provides a controller for managing individual process +// (exec) instances within a container. It handles the full lifecycle from +// creation through exit, including IO plumbing, signal delivery, and exit +// status reporting. +// +// # Lifecycle +// +// [Controller] drives a single process through a linear state machine: +// +// ┌───────────────────┐ +// │ StateNotCreated │ +// └────────┬──────────┘ +// │ Create +// ▼ +// ┌───────────────────┐ +// │ StateCreated │── Start fails / Kill / Delete──┐ +// └────────┬──────────┘ │ +// │ Start ok │ +// ▼ │ +// ┌───────────────────┐ │ +// │ StateRunning │──── process exits / Kill ──────┤ +// └───────────────────┘ │ +// ▼ +// ┌───────────────────┐ +// │ StateTerminated │ +// └───────────────────┘ +// +// - [Controller.Create] sets up upstream IO connections and stores the +// process spec. The controller transitions from StateNotCreated to +// StateCreated. 
+// - [Controller.Start] launches the process inside the hosting system +// and spawns a background goroutine to monitor exit. The controller +// transitions from StateCreated to StateRunning. +// - [Controller.Kill] delivers a signal to a running process or +// terminates a created-but-not-started process. +// - [Controller.Delete] prepares the process for removal from the +// container's process table. For a created-but-never-started process, +// it transitions to StateTerminated and releases its IO resources. +// - [Controller.Wait] blocks until the process exits or the context +// is cancelled. +// - [Controller.Status] returns the current containerd-compatible state +// of the process. +// +// # Exit Handling +// +// When a process is started, a background goroutine waits for the process +// to exit, records the exit code and timestamp, drains all IO copies, and +// publishes a TaskExit event via the caller-supplied channel. The +// exitedCh channel is closed once all cleanup is complete, unblocking any +// [Controller.Wait] callers. +package process diff --git a/internal/controller/process/mocks/mock_cow.go b/internal/controller/process/mocks/mock_cow.go new file mode 100644 index 0000000000..be3c0c63bd --- /dev/null +++ b/internal/controller/process/mocks/mock_cow.go @@ -0,0 +1,271 @@ +//go:build windows && (lcow || wcow) + +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/Microsoft/hcsshim/internal/cow (interfaces: Process,ProcessHost) +// +// Generated by this command: +// +// mockgen -build_constraint=windows && (lcow || wcow) -package mocks -destination mocks/mock_cow.go github.com/Microsoft/hcsshim/internal/cow Process,ProcessHost +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + io "io" + reflect "reflect" + + cow "github.com/Microsoft/hcsshim/internal/cow" + gomock "go.uber.org/mock/gomock" +) + +// MockProcess is a mock of Process interface. 
type MockProcess struct {
	ctrl     *gomock.Controller       // controller every mocked call is dispatched through
	recorder *MockProcessMockRecorder // handed out by EXPECT to register expectations
	isgomock struct{}                 // zero-size field emitted by mockgen; not referenced in this file
}

// MockProcessMockRecorder is the mock recorder for MockProcess.
type MockProcessMockRecorder struct {
	mock *MockProcess
}

// NewMockProcess creates a new mock instance.
func NewMockProcess(ctrl *gomock.Controller) *MockProcess {
	mock := &MockProcess{ctrl: ctrl}
	mock.recorder = &MockProcessMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockProcess) EXPECT() *MockProcessMockRecorder {
	return m.recorder
}

// Close mocks base method.
func (m *MockProcess) Close() error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Close")
	// The assertion's ok result is discarded, so a configured nil return
	// leaves ret0 as a nil error instead of panicking.
	ret0, _ := ret[0].(error)
	return ret0
}

// Close indicates an expected call of Close.
func (mr *MockProcessMockRecorder) Close() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockProcess)(nil).Close))
}

// CloseStderr mocks base method.
func (m *MockProcess) CloseStderr(ctx context.Context) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "CloseStderr", ctx)
	ret0, _ := ret[0].(error)
	return ret0
}

// CloseStderr indicates an expected call of CloseStderr.
func (mr *MockProcessMockRecorder) CloseStderr(ctx any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStderr", reflect.TypeOf((*MockProcess)(nil).CloseStderr), ctx)
}

// CloseStdin mocks base method.
func (m *MockProcess) CloseStdin(ctx context.Context) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "CloseStdin", ctx)
	ret0, _ := ret[0].(error)
	return ret0
}

// CloseStdin indicates an expected call of CloseStdin.
func (mr *MockProcessMockRecorder) CloseStdin(ctx any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	// ctx is typed any so callers can pass either a concrete value or a
	// gomock matcher.
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStdin", reflect.TypeOf((*MockProcess)(nil).CloseStdin), ctx)
}

// CloseStdout mocks base method.
func (m *MockProcess) CloseStdout(ctx context.Context) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "CloseStdout", ctx)
	ret0, _ := ret[0].(error)
	return ret0
}

// CloseStdout indicates an expected call of CloseStdout.
func (mr *MockProcessMockRecorder) CloseStdout(ctx any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStdout", reflect.TypeOf((*MockProcess)(nil).CloseStdout), ctx)
}

// ExitCode mocks base method.
func (m *MockProcess) ExitCode() (int, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "ExitCode")
	// Failed assertions are ignored, so unset returns fall back to 0 and nil.
	ret0, _ := ret[0].(int)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// ExitCode indicates an expected call of ExitCode.
func (mr *MockProcessMockRecorder) ExitCode() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExitCode", reflect.TypeOf((*MockProcess)(nil).ExitCode))
}

// Kill mocks base method.
func (m *MockProcess) Kill(ctx context.Context) (bool, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Kill", ctx)
	ret0, _ := ret[0].(bool)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// Kill indicates an expected call of Kill.
func (mr *MockProcessMockRecorder) Kill(ctx any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Kill", reflect.TypeOf((*MockProcess)(nil).Kill), ctx)
}

// Pid mocks base method.
func (m *MockProcess) Pid() int {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Pid")
	ret0, _ := ret[0].(int)
	return ret0
}

// Pid indicates an expected call of Pid.
func (mr *MockProcessMockRecorder) Pid() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Pid", reflect.TypeOf((*MockProcess)(nil).Pid))
}

// ResizeConsole mocks base method.
func (m *MockProcess) ResizeConsole(ctx context.Context, width, height uint16) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "ResizeConsole", ctx, width, height)
	ret0, _ := ret[0].(error)
	return ret0
}

// ResizeConsole indicates an expected call of ResizeConsole.
func (mr *MockProcessMockRecorder) ResizeConsole(ctx, width, height any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResizeConsole", reflect.TypeOf((*MockProcess)(nil).ResizeConsole), ctx, width, height)
}

// Signal mocks base method.
func (m *MockProcess) Signal(ctx context.Context, options any) (bool, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Signal", ctx, options)
	ret0, _ := ret[0].(bool)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// Signal indicates an expected call of Signal.
func (mr *MockProcessMockRecorder) Signal(ctx, options any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Signal", reflect.TypeOf((*MockProcess)(nil).Signal), ctx, options)
}

// Stdio mocks base method.
func (m *MockProcess) Stdio() (io.Writer, io.Reader, io.Reader) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Stdio")
	// Any of the three streams left unconfigured comes back as a nil interface.
	ret0, _ := ret[0].(io.Writer)
	ret1, _ := ret[1].(io.Reader)
	ret2, _ := ret[2].(io.Reader)
	return ret0, ret1, ret2
}

// Stdio indicates an expected call of Stdio.
func (mr *MockProcessMockRecorder) Stdio() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stdio", reflect.TypeOf((*MockProcess)(nil).Stdio))
}

// Wait mocks base method.
func (m *MockProcess) Wait() error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Wait")
	ret0, _ := ret[0].(error)
	return ret0
}

// Wait indicates an expected call of Wait.
func (mr *MockProcessMockRecorder) Wait() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Wait", reflect.TypeOf((*MockProcess)(nil).Wait))
}

// MockProcessHost is a mock of ProcessHost interface.
type MockProcessHost struct {
	ctrl     *gomock.Controller           // controller every mocked call is dispatched through
	recorder *MockProcessHostMockRecorder // handed out by EXPECT to register expectations
	isgomock struct{}                     // zero-size field emitted by mockgen; not referenced in this file
}

// MockProcessHostMockRecorder is the mock recorder for MockProcessHost.
type MockProcessHostMockRecorder struct {
	mock *MockProcessHost
}

// NewMockProcessHost creates a new mock instance.
func NewMockProcessHost(ctrl *gomock.Controller) *MockProcessHost {
	mock := &MockProcessHost{ctrl: ctrl}
	mock.recorder = &MockProcessHostMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockProcessHost) EXPECT() *MockProcessHostMockRecorder {
	return m.recorder
}

// CreateProcess mocks base method.
func (m *MockProcessHost) CreateProcess(ctx context.Context, config any) (cow.Process, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "CreateProcess", ctx, config)
	// An unset process return yields a nil cow.Process interface.
	ret0, _ := ret[0].(cow.Process)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// CreateProcess indicates an expected call of CreateProcess.
func (mr *MockProcessHostMockRecorder) CreateProcess(ctx, config any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProcess", reflect.TypeOf((*MockProcessHost)(nil).CreateProcess), ctx, config)
}

// IsOCI mocks base method.
func (m *MockProcessHost) IsOCI() bool {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "IsOCI")
	ret0, _ := ret[0].(bool)
	return ret0
}

// IsOCI indicates an expected call of IsOCI.
func (mr *MockProcessHostMockRecorder) IsOCI() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsOCI", reflect.TypeOf((*MockProcessHost)(nil).IsOCI))
}

// OS mocks base method.
func (m *MockProcessHost) OS() string {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "OS")
	// An unset return yields the empty string.
	ret0, _ := ret[0].(string)
	return ret0
}

// OS indicates an expected call of OS.
func (mr *MockProcessHostMockRecorder) OS() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OS", reflect.TypeOf((*MockProcessHost)(nil).OS))
}
diff --git a/internal/controller/process/mocks/mock_upstream_io.go b/internal/controller/process/mocks/mock_upstream_io.go
new file mode 100644
index 0000000000..2c011bfe96
--- /dev/null
+++ b/internal/controller/process/mocks/mock_upstream_io.go
@@ -0,0 +1,166 @@
//go:build windows && (lcow || wcow)

// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/Microsoft/hcsshim/internal/cmd (interfaces: UpstreamIO)
//
// Generated by this command:
//
//	mockgen -build_constraint=windows && (lcow || wcow) -package mocks -destination mocks/mock_upstream_io.go github.com/Microsoft/hcsshim/internal/cmd UpstreamIO
//

// Package mocks is a generated GoMock package.
package mocks

import (
	context "context"
	io "io"
	reflect "reflect"

	gomock "go.uber.org/mock/gomock"
)

// MockUpstreamIO is a mock of UpstreamIO interface.
type MockUpstreamIO struct {
	ctrl     *gomock.Controller          // controller every mocked call is dispatched through
	recorder *MockUpstreamIOMockRecorder // handed out by EXPECT to register expectations
	isgomock struct{}                    // zero-size field emitted by mockgen; not referenced in this file
}

// MockUpstreamIOMockRecorder is the mock recorder for MockUpstreamIO.
type MockUpstreamIOMockRecorder struct {
	mock *MockUpstreamIO
}

// NewMockUpstreamIO creates a new mock instance.
func NewMockUpstreamIO(ctrl *gomock.Controller) *MockUpstreamIO {
	mock := &MockUpstreamIO{ctrl: ctrl}
	// The recorder points back at the mock so expectations are registered
	// against this instance.
	mock.recorder = &MockUpstreamIOMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockUpstreamIO) EXPECT() *MockUpstreamIOMockRecorder {
	return m.recorder
}

// Close mocks base method.
func (m *MockUpstreamIO) Close(ctx context.Context) {
	m.ctrl.T.Helper()
	// Close has no results, so the recorded call's return slice is dropped.
	m.ctrl.Call(m, "Close", ctx)
}

// Close indicates an expected call of Close.
func (mr *MockUpstreamIOMockRecorder) Close(ctx any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockUpstreamIO)(nil).Close), ctx)
}

// CloseStdin mocks base method.
func (m *MockUpstreamIO) CloseStdin(ctx context.Context) {
	m.ctrl.T.Helper()
	m.ctrl.Call(m, "CloseStdin", ctx)
}

// CloseStdin indicates an expected call of CloseStdin.
func (mr *MockUpstreamIOMockRecorder) CloseStdin(ctx any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStdin", reflect.TypeOf((*MockUpstreamIO)(nil).CloseStdin), ctx)
}

// Stderr mocks base method.
func (m *MockUpstreamIO) Stderr() io.Writer {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Stderr")
	ret0, _ := ret[0].(io.Writer)
	return ret0
}

// Stderr indicates an expected call of Stderr.
func (mr *MockUpstreamIOMockRecorder) Stderr() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stderr", reflect.TypeOf((*MockUpstreamIO)(nil).Stderr))
}

// StderrPath mocks base method.
func (m *MockUpstreamIO) StderrPath() string {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "StderrPath")
	ret0, _ := ret[0].(string)
	return ret0
}

// StderrPath indicates an expected call of StderrPath.
func (mr *MockUpstreamIOMockRecorder) StderrPath() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StderrPath", reflect.TypeOf((*MockUpstreamIO)(nil).StderrPath))
}

// Stdin mocks base method.
func (m *MockUpstreamIO) Stdin() io.Reader {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Stdin")
	// An unset return yields a nil io.Reader.
	ret0, _ := ret[0].(io.Reader)
	return ret0
}

// Stdin indicates an expected call of Stdin.
func (mr *MockUpstreamIOMockRecorder) Stdin() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stdin", reflect.TypeOf((*MockUpstreamIO)(nil).Stdin))
}

// StdinPath mocks base method.
func (m *MockUpstreamIO) StdinPath() string {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "StdinPath")
	ret0, _ := ret[0].(string)
	return ret0
}

// StdinPath indicates an expected call of StdinPath.
func (mr *MockUpstreamIOMockRecorder) StdinPath() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StdinPath", reflect.TypeOf((*MockUpstreamIO)(nil).StdinPath))
}

// Stdout mocks base method.
func (m *MockUpstreamIO) Stdout() io.Writer {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Stdout")
	ret0, _ := ret[0].(io.Writer)
	return ret0
}

// Stdout indicates an expected call of Stdout.
func (mr *MockUpstreamIOMockRecorder) Stdout() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stdout", reflect.TypeOf((*MockUpstreamIO)(nil).Stdout))
}

// StdoutPath mocks base method.
func (m *MockUpstreamIO) StdoutPath() string {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "StdoutPath")
	ret0, _ := ret[0].(string)
	return ret0
}

// StdoutPath indicates an expected call of StdoutPath.
func (mr *MockUpstreamIOMockRecorder) StdoutPath() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StdoutPath", reflect.TypeOf((*MockUpstreamIO)(nil).StdoutPath))
}

// Terminal mocks base method.
func (m *MockUpstreamIO) Terminal() bool {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Terminal")
	ret0, _ := ret[0].(bool)
	return ret0
}

// Terminal indicates an expected call of Terminal.
func (mr *MockUpstreamIOMockRecorder) Terminal() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Terminal", reflect.TypeOf((*MockUpstreamIO)(nil).Terminal))
}
diff --git a/internal/controller/process/process.go b/internal/controller/process/process.go
new file mode 100644
index 0000000000..c6d7657301
--- /dev/null
+++ b/internal/controller/process/process.go
@@ -0,0 +1,374 @@
//go:build windows && (lcow || wcow)

package process

import (
	"context"
	"fmt"
	"sync"
	"time"

	"github.com/Microsoft/hcsshim/internal/cmd"
	"github.com/Microsoft/hcsshim/internal/cow"
	"github.com/Microsoft/hcsshim/internal/hcs"
	"github.com/Microsoft/hcsshim/internal/log"
	"github.com/Microsoft/hcsshim/internal/logfields"
	eventstypes "github.com/containerd/containerd/api/events"

	"github.com/containerd/containerd/api/runtime/task/v2"
	"github.com/containerd/errdefs"
	"github.com/opencontainers/runtime-spec/specs-go"
	"github.com/sirupsen/logrus"
	"google.golang.org/protobuf/types/known/timestamppb"
)

// Controller manages the lifecycle of a single process (init or exec)
// within a container.
type Controller struct {
	// mu is taken by the methods that read or update the lifecycle fields
	// below (state, process, processID, exit status).
	mu sync.RWMutex

	// containerID is the ID of the owning container.
	// This is the client facing containerID.
	containerID string

	// execID is the unique identifier for this exec instance.
	// The init process uses an empty string.
	execID string

	// hostingSystem is the UVM connection that hosts this process.
	// It is passed to cmd.Cmd as the Host when the process is started.
	hostingSystem cow.ProcessHost

	// ioRetryTimeout is the duration to retry IO connection setup.
	// It is forwarded to cmd.NewUpstreamIO during Create.
	ioRetryTimeout time.Duration

	// state is the current lifecycle state.
	state State

	// process is the handle of the launched process, taken from
	// cmd.Cmd.Process once Start succeeds.
	process cow.Process

	// processID is the OS-level PID, set after Start.
	processID int

	// upstreamIO holds the upstream IO connections (stdin/stdout/stderr pipes).
	// Set during Create, consumed during Start.
	upstreamIO cmd.UpstreamIO

	// bundle is the path to the OCI bundle directory.
	bundle string

	// processSpec is the OCI process spec for exec processes.
	processSpec *specs.Process

	// exitedAt is the timestamp when the process exited.
	exitedAt time.Time

	// exitCode is the process exit code.
	// Defaults to 255 for processes that have not exited.
	exitCode uint32

	// exitedCh is closed when the process has exited and all cleanup is done.
	exitedCh chan struct{}
}

// New creates a [Controller] for a process in the given container.
//
// The controller starts in StateNotCreated with exit code 255 and an open
// exitedCh; no IO is set up and nothing is launched until Create and Start
// are called.
func New(containerID string, execID string, hostingSystem cow.ProcessHost, ioRetryTimeout time.Duration) *Controller {
	return &Controller{
		containerID:    containerID,
		execID:         execID,
		hostingSystem:  hostingSystem,
		ioRetryTimeout: ioRetryTimeout,
		state:          StateNotCreated,
		exitCode:       255, // By design for non-exited process status.
		exitedCh:       make(chan struct{}),
	}
}

// Pid returns the OS-level process ID.
// New leaves it at zero; Start records the real PID once the process launches.
func (c *Controller) Pid() int {
	c.mu.RLock()
	defer c.mu.RUnlock()

	return c.processID
}

// Create sets up upstream IO connections and stores the process spec,
// transitioning the controller from StateNotCreated to StateCreated.
+func (c *Controller) Create(ctx context.Context, opts *CreateOptions) error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if c.state != StateNotCreated {
+		return fmt.Errorf("process %q in container %s is in state %s; cannot create: %w", c.execID, c.containerID, c.state, errdefs.ErrFailedPrecondition)
+	}
+
+	if opts.Terminal && opts.Stderr != "" {
+		return fmt.Errorf("process %q in container %s has terminal enabled but stderr is not empty: %w", c.execID, c.containerID, errdefs.ErrFailedPrecondition)
+	}
+
+	// Establish upstream IO connections for stdin, stdout, and stderr.
+	// Note the argument order used here: stdout, stderr, then stdin.
+	upstreamIO, err := cmd.NewUpstreamIO(ctx, c.containerID, opts.Stdout, opts.Stderr, opts.Stdin, opts.Terminal, c.ioRetryTimeout)
+	if err != nil {
+		return fmt.Errorf("create upstream io for process %q in container %s: %w", c.execID, c.containerID, err)
+	}
+
+	c.upstreamIO = upstreamIO
+	c.bundle = opts.Bundle
+	c.processSpec = opts.Spec
+	c.state = StateCreated
+
+	return nil
+}
+
+// Start launches the process inside the hosting system and returns the PID.
+// It is only valid in StateCreated. On any error it returns -1; if the process
+// itself fails to start, the controller is aborted to StateTerminated with
+// exit code 1. On success the controller moves to StateRunning and a
+// background goroutine (handleProcessExit) tracks the process exit.
+func (c *Controller) Start(ctx context.Context, events chan interface{}) (int, error) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if c.state != StateCreated {
+		return -1, fmt.Errorf("process %q in container %s is in state %s; cannot start: %w", c.execID, c.containerID, c.state, errdefs.ErrFailedPrecondition)
+	}
+
+	// Build the command to run inside the container.
+	// An init exec passes the process as part of the config. We only pass
+	// the spec if this is a true exec.
+	execCmd := &cmd.Cmd{
+		Host:   c.hostingSystem,
+		Stdin:  c.upstreamIO.Stdin(),
+		Stdout: c.upstreamIO.Stdout(),
+		Stderr: c.upstreamIO.Stderr(),
+		Log: log.G(ctx).WithFields(logrus.Fields{
+			logfields.ContainerID: c.containerID,
+			logfields.ExecID:      c.execID,
+		}),
+		CopyAfterExitTimeout: time.Second,
+		Spec:                 c.processSpec,
+	}
+
+	// Start the process and abort on failure.
+	if err := execCmd.Start(); err != nil {
+		c.abortInternal(ctx, 1)
+		return -1, err
+	}
+
+	// Track the running process and launch background exit handler.
+	c.process = execCmd.Process
+	c.processID = c.process.Pid()
+	c.state = StateRunning
+
+	go c.handleProcessExit(ctx, execCmd, events)
+
+	return c.processID, nil
+}
+
+// handleProcessExit blocks until the process exits, cleans up IO, and
+// publishes the exit event on the events channel when that channel is non-nil.
+// It runs in its own goroutine, started by Start.
+func (c *Controller) handleProcessExit(ctx context.Context, execCmd *cmd.Cmd, events chan interface{}) {
+	// Detach from the caller's context so upstream cancellation does
+	// not abort the background teardown.
+	ctx = context.WithoutCancel(ctx)
+
+	// Wait for the process to exit, drain all IO copies, and close the
+	// underlying process handle.
+	if err := execCmd.Wait(); err != nil {
+		log.G(ctx).WithError(err).Warn("process exit wait failed")
+	}
+
+	exitCode := execCmd.ExitState.ExitCode()
+
+	// Record the exit status under the lock.
+	c.mu.Lock()
+	if c.state == StateTerminated {
+		// The controller was already terminated elsewhere. Returning here
+		// also avoids closing exitedCh a second time.
+		log.G(ctx).Warnf("process %s is already in terminated state", c.execID)
+		c.mu.Unlock()
+		return
+	}
+	c.state = StateTerminated
+	c.exitCode = uint32(exitCode)
+	c.exitedAt = time.Now()
+	c.mu.Unlock()
+
+	// Release upstream IO connections.
+	c.upstreamIO.Close(ctx)
+
+	// Unblock any waiters.
+	close(c.exitedCh)
+
+	// Publish the exit event after all cleanup is done.
+	// We do not publish exit event for init process wherein
+	// the event is sent after container teardown is complete.
+	if events != nil {
+		status := c.Status(true)
+		events <- &eventstypes.TaskExit{
+			ContainerID: c.containerID,
+			ID:          status.ExecID,
+			Pid:         status.Pid,
+			ExitStatus:  status.ExitStatus,
+			ExitedAt:    status.ExitedAt,
+		}
+	}
+}
+
+// Status returns the current containerd-compatible state of the process.
+// When isDetailed is true and the process has been created, the response also
+// includes bundle, IO paths, terminal flag, exit code, and exit timestamp.
+func (c *Controller) Status(isDetailed bool) *task.StateResponse {
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+
+	resp := &task.StateResponse{
+		ID:     c.containerID,
+		ExecID: c.execID,
+		Pid:    uint32(c.processID),
+		Status: c.state.ContainerdStatus(),
+	}
+
+	// upstreamIO is only assigned by Create, so it must not be dereferenced
+	// while the controller is still in StateNotCreated.
+	if isDetailed && c.state != StateNotCreated {
+		resp.Bundle = c.bundle
+		resp.Stdin = c.upstreamIO.StdinPath()
+		resp.Stdout = c.upstreamIO.StdoutPath()
+		resp.Stderr = c.upstreamIO.StderrPath()
+		resp.Terminal = c.upstreamIO.Terminal()
+		resp.ExitStatus = c.exitCode
+		resp.ExitedAt = timestamppb.New(c.exitedAt)
+	}
+
+	return resp
+}
+
+// State returns the current lifecycle state.
+func (c *Controller) State() State {
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+
+	return c.state
+}
+
+// ResizeConsole resizes the pseudo-TTY of a running process.
+// It fails with ErrFailedPrecondition unless the process is in StateRunning
+// and was created with a terminal. width and height are narrowed to uint16
+// before being passed to the underlying process.
+func (c *Controller) ResizeConsole(ctx context.Context, width, height uint32) error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if c.state != StateRunning {
+		return fmt.Errorf("process %q in container %s is in state %s; cannot resize console: %w", c.execID, c.containerID, c.state, errdefs.ErrFailedPrecondition)
+	}
+
+	if !c.upstreamIO.Terminal() {
+		return fmt.Errorf("process %q in container %s is not a tty: %w", c.execID, c.containerID, errdefs.ErrFailedPrecondition)
+	}
+
+	return c.process.ResizeConsole(ctx, uint16(width), uint16(height))
+}
+
+// CloseIO closes the upstream stdin connection, unblocking any pending
+// IO copy. This is safe to call multiple times.
+// It does nothing while the controller is in StateNotCreated.
+func (c *Controller) CloseIO(ctx context.Context) {
+	// Taking read lock is sufficient as CloseStdin itself is thread safe.
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+
+	if c.state == StateNotCreated {
+		return
+	}
+
+	// If we have any upstream IO we close the upstream connection. This will
+	// unblock the `io.Copy` in the `cmd.Start()` call which will signal
+	// `cmd.CloseStdin()`. This is safe to call multiple times.
+	c.upstreamIO.CloseStdin(ctx)
+}
+
+// Wait blocks until the process has exited or the context is cancelled.
+// It returns nothing, so a caller that needs to tell the two cases apart
+// must inspect ctx or State afterwards.
+func (c *Controller) Wait(ctx context.Context) {
+	select {
+	case <-c.exitedCh:
+	case <-ctx.Done():
+	}
+}
+
+// Kill delivers a signal to the process or terminates it.
+//
+// signalOptions contains the platform-specific signal options (e.g.,
+// SignalProcessOptionsWCOW or SignalProcessOptionsLCOW). The caller is
+// responsible for validating the signal and producing the correct options
+// for the platform. When signalOptions is nil the terminate path is used.
+//
+// In StateCreated the controller is aborted with exit code 1. In StateRunning
+// the request is forwarded to the process and the state is left unchanged
+// here. In StateTerminated Kill is a no-op. Any other state yields
+// ErrFailedPrecondition.
+func (c *Controller) Kill(ctx context.Context, signalOptions interface{}) error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	switch c.state {
+	case StateCreated:
+		// The process was never started. Transition directly to terminated state.
+		c.abortInternal(ctx, 1)
+		return nil
+
+	case StateRunning:
+		var isDelivered bool
+		var err error
+
+		if signalOptions != nil {
+			isDelivered, err = c.process.Signal(ctx, signalOptions)
+		} else {
+			// Legacy path: signals are not supported, issue a direct terminate.
+			isDelivered, err = c.process.Kill(ctx)
+		}
+
+		if err != nil {
+			if hcs.IsAlreadyStopped(err) {
+				// Desired state matches actual state — not an error.
+				return nil
+			}
+			return fmt.Errorf("failed to kill process: %w", err)
+		}
+
+		if !isDelivered {
+			return fmt.Errorf("process %q in container %s was not found: %w", c.execID, c.containerID, errdefs.ErrNotFound)
+		}
+		return nil
+
+	case StateTerminated:
+		// The process already exited — desired state matches actual state.
+		return nil
+
+	default:
+		return fmt.Errorf("process %q in container %s is in an unexpected state %s: %w", c.execID, c.containerID, c.state, errdefs.ErrFailedPrecondition)
+	}
+}
+
+// Delete prepares the process for removal from the container's process table.
+// A terminated process needs no further work, a created-but-unstarted process
+// is aborted with exit code 0, and a running process is rejected with
+// ErrFailedPrecondition.
+func (c *Controller) Delete(ctx context.Context) error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	switch c.state {
+	case StateTerminated:
+		// Expected state — the process has exited and is ready for cleanup.
+		return nil
+
+	case StateCreated:
+		// The process was created but never started. Abort it to release IO
+		// resources and unblock any waiters.
+		c.abortInternal(ctx, 0)
+		return nil
+
+	case StateRunning:
+		// A running process must be explicitly killed before it can be deleted.
+		return fmt.Errorf("process %q in container %s is still running; cannot delete: %w", c.execID, c.containerID, errdefs.ErrFailedPrecondition)
+
+	default:
+		return fmt.Errorf("process %q in container %s is in unexpected state %s for delete: %w", c.execID, c.containerID, c.state, errdefs.ErrFailedPrecondition)
+	}
+}
+
+// abortInternal performs the abort teardown while the caller already holds c.mu.
+// It records exitCode and the current time, moves the controller to
+// StateTerminated, closes the upstream IO, and closes exitedCh. It does not
+// take c.mu itself.
+func (c *Controller) abortInternal(ctx context.Context, exitCode uint32) {
+	// No OS-level process exists — transition directly to terminated state.
+	c.state = StateTerminated
+	c.exitCode = exitCode
+	c.exitedAt = time.Now()
+
+	// Release upstream IO connections that were never used.
+	c.upstreamIO.Close(ctx)
+
+	// Unblock any waiters.
+	close(c.exitedCh)
+}
diff --git a/internal/controller/process/process_test.go b/internal/controller/process/process_test.go
new file mode 100644
index 0000000000..ba91887baa
--- /dev/null
+++ b/internal/controller/process/process_test.go
@@ -0,0 +1,774 @@
+//go:build windows && (lcow || wcow)
+
+package process
+
+import (
+	"context"
+	"errors"
+	"testing"
+	"time"
+
+	containerdtypes "github.com/containerd/containerd/api/types/task"
+	"github.com/containerd/errdefs"
+	"github.com/opencontainers/runtime-spec/specs-go"
+	"go.uber.org/mock/gomock"
+
+	"github.com/Microsoft/hcsshim/internal/controller/process/mocks"
+	"github.com/Microsoft/hcsshim/internal/hcs"
+)
+
+const (
+	testContainerID = "test-container-1234"
+	testExecID      = "test-exec-5678"
+	testPID         = 42
+)
+
+// errTest is a sentinel used in table-driven tests to verify error propagation.
+var errTest = errors.New("test error")
+
+// newSetup returns a gomock controller, a mock process host, a mock upstream
+// IO, and a fresh Controller built with the mock host. The mock IO is not
+// attached to the Controller; tests that need it assign controller.upstreamIO.
+func newSetup(t *testing.T) (*gomock.Controller, *mocks.MockProcessHost, *mocks.MockUpstreamIO, *Controller) {
+	t.Helper()
+	mockCtrl := gomock.NewController(t)
+	mockHost := mocks.NewMockProcessHost(mockCtrl)
+	mockIO := mocks.NewMockUpstreamIO(mockCtrl)
+	return mockCtrl, mockHost, mockIO, New(testContainerID, testExecID, mockHost, time.Second)
+}
+
+// TestNew_InitializesFields verifies that New sets all fields correctly and
+// that the initial state is StateNotCreated with exit code 255.
+func TestNew_InitializesFields(t *testing.T) {
+	t.Parallel()
+	_, _, _, controller := newSetup(t)
+
+	if controller.containerID != testContainerID {
+		t.Errorf("containerID = %q; want %q", controller.containerID, testContainerID)
+	}
+	if controller.execID != testExecID {
+		t.Errorf("execID = %q; want %q", controller.execID, testExecID)
+	}
+	if controller.state != StateNotCreated {
+		t.Errorf("initial state = %s; want StateNotCreated", controller.state)
+	}
+	if controller.exitCode != 255 {
+		t.Errorf("initial exitCode = %d; want 255", controller.exitCode)
+	}
+	if controller.exitedCh == nil {
+		t.Fatal("exitedCh must not be nil after New")
+	}
+	if controller.Pid() != 0 {
+		t.Errorf("initial Pid() = %d; want 0", controller.Pid())
+	}
+}
+
+// TestCreate_WrongState verifies that Create rejects calls made outside StateNotCreated.
+func TestCreate_WrongState(t *testing.T) {
+	t.Parallel()
+	invalidStates := []State{StateCreated, StateRunning, StateTerminated}
+
+	for _, state := range invalidStates {
+		t.Run(state.String(), func(t *testing.T) {
+			t.Parallel()
+			_, _, _, controller := newSetup(t)
+			controller.state = state
+
+			err := controller.Create(t.Context(), &CreateOptions{})
+			if !errors.Is(err, errdefs.ErrFailedPrecondition) {
+				t.Errorf("Create() = %v; want ErrFailedPrecondition", err)
+			}
+		})
+	}
+}
+
+// TestCreate_TerminalWithStderr verifies that Create rejects the combination of
+// terminal=true and a non-empty stderr path.
+func TestCreate_TerminalWithStderr(t *testing.T) {
+	t.Parallel()
+	_, _, _, controller := newSetup(t)
+
+	err := controller.Create(t.Context(), &CreateOptions{
+		Terminal: true,
+		Stderr:   `\\.\pipe\some-stderr`,
+	})
+	if !errors.Is(err, errdefs.ErrFailedPrecondition) {
+		t.Errorf("Create(terminal+stderr) = %v; want ErrFailedPrecondition", err)
+	}
+}
+
+// TestCreate_Succeeds verifies that Create transitions to StateCreated and stores
+// the bundle and process spec. Empty IO paths are used so no real named-pipe
+// connections are attempted.
+func TestCreate_Succeeds(t *testing.T) {
+	t.Parallel()
+	_, _, _, controller := newSetup(t)
+
+	spec := &specs.Process{Args: []string{"/bin/sh"}}
+	opts := &CreateOptions{
+		Bundle: "/test/bundle",
+		Spec:   spec,
+	}
+
+	if err := controller.Create(t.Context(), opts); err != nil {
+		t.Fatalf("Create() = %v; want nil", err)
+	}
+	if controller.state != StateCreated {
+		t.Errorf("state = %s; want StateCreated", controller.state)
+	}
+	if controller.bundle != opts.Bundle {
+		t.Errorf("bundle = %q; want %q", controller.bundle, opts.Bundle)
+	}
+	if controller.processSpec != spec {
+		t.Error("processSpec not stored correctly after Create")
+	}
+	if controller.upstreamIO == nil {
+		t.Error("upstreamIO must be non-nil after Create")
+	}
+}
+
+// TestStart_WrongState verifies that Start rejects calls made outside StateCreated.
+func TestStart_WrongState(t *testing.T) {
+	t.Parallel()
+	invalidStates := []State{StateNotCreated, StateRunning, StateTerminated}
+
+	for _, state := range invalidStates {
+		t.Run(state.String(), func(t *testing.T) {
+			t.Parallel()
+			_, _, _, controller := newSetup(t)
+			controller.state = state
+
+			pid, err := controller.Start(t.Context(), nil)
+			if pid != -1 {
+				t.Errorf("Start() pid = %d; want -1", pid)
+			}
+			if !errors.Is(err, errdefs.ErrFailedPrecondition) {
+				t.Errorf("Start() = %v; want ErrFailedPrecondition", err)
+			}
+		})
+	}
+}
+
+// TestStart_Succeeds verifies the happy path: Start returns the correct PID and
+// transitions to StateRunning; once the mock process exits, the background
+// goroutine transitions to StateTerminated and publishes a TaskExit event.
+func TestStart_Succeeds(t *testing.T) {
+	t.Parallel()
+	mockCtrl, mockHost, mockIO, controller := newSetup(t)
+	controller.upstreamIO = mockIO
+	controller.state = StateCreated
+	mockProc := mocks.NewMockProcess(mockCtrl)
+
+	// cmd.Cmd struct is populated with IO readers/writers before Start is called.
+	mockIO.EXPECT().Stdin().Return(nil)
+	mockIO.EXPECT().Stdout().Return(nil)
+	mockIO.EXPECT().Stderr().Return(nil)
+
+	// cmd.Cmd.Start calls IsOCI and then CreateProcess.
+	mockHost.EXPECT().IsOCI().Return(true)
+	mockHost.EXPECT().CreateProcess(gomock.Any(), gomock.Any()).Return(mockProc, nil)
+
+	// Pid is called once inside cmd.Cmd.Start for log enrichment, and once in
+	// process.Controller.Start to record the OS-level PID.
+	mockProc.EXPECT().Pid().Return(testPID).Times(2)
+
+	// cmd.Cmd.Start always calls Stdio() to retrieve the process IO streams
+	// before deciding whether to start relay goroutines.
+	mockProc.EXPECT().Stdio().Return(nil, nil, nil)
+
+	// cmd.Cmd.Wait calls Process.Wait, ExitCode, and Close once each. Wait is
+	// held on exitGate so StateRunning can be asserted before the exit path runs.
+	exitGate := make(chan struct{})
+	mockProc.EXPECT().Wait().DoAndReturn(func() error { <-exitGate; return nil })
+	mockProc.EXPECT().ExitCode().Return(0, nil)
+	mockProc.EXPECT().Close().Return(nil)
+
+	// upstreamIO cleanup at the end of handleProcessExit.
+	mockIO.EXPECT().Close(gomock.Any())
+
+	// Status(true) is called inside handleProcessExit to populate the exit event.
+	mockIO.EXPECT().StdinPath().Return("")
+	mockIO.EXPECT().StdoutPath().Return("")
+	mockIO.EXPECT().StderrPath().Return("")
+	mockIO.EXPECT().Terminal().Return(false)
+
+	// Use a buffered channel so the goroutine never blocks on send.
+	events := make(chan interface{}, 1)
+
+	pid, err := controller.Start(t.Context(), events)
+	if err != nil {
+		t.Fatalf("Start() = %v; want nil", err)
+	}
+	if pid != testPID {
+		t.Errorf("Start() pid = %d; want %d", pid, testPID)
+	}
+	if controller.State() != StateRunning {
+		t.Errorf("state after Start = %s; want StateRunning", controller.State())
+	}
+
+	// Release the mock process so it "exits", then wait for exitedCh to close.
+	close(exitGate)
+	controller.Wait(t.Context())
+
+	if controller.State() != StateTerminated {
+		t.Errorf("state after exit = %s; want StateTerminated", controller.State())
+	}
+
+	// The event is sent after exitedCh is closed, so wait for it (bounded).
+	select {
+	case event := <-events:
+		if event == nil {
+			t.Error("received nil event; want TaskExit")
+		}
+	case <-time.After(5 * time.Second):
+		t.Error("expected a TaskExit event in events channel; got none")
+	}
+}
+
+// TestStart_HostCreateProcessFails verifies that a CreateProcess error causes
+// Start to abort the controller and transition it to StateTerminated.
+func TestStart_HostCreateProcessFails(t *testing.T) {
+	t.Parallel()
+	_, mockHost, mockIO, controller := newSetup(t)
+	controller.upstreamIO = mockIO
+	controller.state = StateCreated
+
+	// IO readers/writers are read before cmd.Cmd.Start is invoked.
+	mockIO.EXPECT().Stdin().Return(nil)
+	mockIO.EXPECT().Stdout().Return(nil)
+	mockIO.EXPECT().Stderr().Return(nil)
+
+	// CreateProcess fails; cmd.Cmd.Start returns the error.
+	mockHost.EXPECT().IsOCI().Return(true)
+	mockHost.EXPECT().CreateProcess(gomock.Any(), gomock.Any()).Return(nil, errTest)
+
+	// abortInternal is called, which releases upstreamIO.
+	mockIO.EXPECT().Close(gomock.Any())
+
+	pid, err := controller.Start(t.Context(), nil)
+	if pid != -1 {
+		t.Errorf("Start() pid = %d; want -1", pid)
+	}
+	if !errors.Is(err, errTest) {
+		t.Errorf("Start() = %v; want errTest", err)
+	}
+	if controller.State() != StateTerminated {
+		t.Errorf("state = %s; want StateTerminated", controller.State())
+	}
+}
+
+// TestKill_WrongState verifies that Kill returns ErrFailedPrecondition
+// for states that are not part of the valid kill state machine (e.g., StateNotCreated).
+func TestKill_WrongState(t *testing.T) {
+	t.Parallel()
+	_, _, _, controller := newSetup(t)
+	// StateNotCreated is not a valid state for Kill.
+	err := controller.Kill(t.Context(), nil)
+	if !errors.Is(err, errdefs.ErrFailedPrecondition) {
+		t.Errorf("Kill(NotCreated) = %v; want ErrFailedPrecondition", err)
+	}
+}
+
+// TestKill_CreatedState verifies that Kill on a created-but-not-started process
+// triggers abortInternal, transitioning the controller to StateTerminated.
+func TestKill_CreatedState(t *testing.T) {
+	t.Parallel()
+	_, _, mockIO, controller := newSetup(t)
+	controller.upstreamIO = mockIO
+	controller.state = StateCreated
+	mockIO.EXPECT().Close(gomock.Any())
+
+	if err := controller.Kill(t.Context(), nil); err != nil {
+		t.Fatalf("Kill(Created) = %v; want nil", err)
+	}
+	if controller.state != StateTerminated {
+		t.Errorf("state = %s; want StateTerminated", controller.state)
+	}
+	// exitedCh must be closed so any waiters are unblocked.
+	select {
+	case <-controller.exitedCh:
+	default:
+		t.Error("exitedCh should be closed after Kill(Created)")
+	}
+}
+
+// TestKill_TerminatedState verifies that Kill on an already terminated process
+// is a no-op and returns nil.
+func TestKill_TerminatedState(t *testing.T) {
+	t.Parallel()
+	_, _, _, controller := newSetup(t)
+	controller.state = StateTerminated
+
+	if err := controller.Kill(t.Context(), nil); err != nil {
+		t.Errorf("Kill(Terminated) = %v; want nil", err)
+	}
+}
+
+// TestKill_RunningState_Signal verifies all signal-delivery outcomes when Kill
+// is called with non-nil signal options on a running process.
+func TestKill_RunningState_Signal(t *testing.T) {
+	t.Parallel()
+	// Kill only checks signalOptions for nil and forwards it unchanged, so any
+	// non-nil value selects the Signal path.
+	signalOpts := struct{ sig int }{sig: 15}
+
+	tests := []struct {
+		name        string
+		isDelivered bool
+		signalErr   error
+		wantErr     bool
+		wantErrIs   error
+	}{
+		{
+			name:        "signal delivered",
+			isDelivered: true,
+			signalErr:   nil,
+			wantErr:     false,
+		},
+		{
+			name:        "signal not delivered",
+			isDelivered: false,
+			signalErr:   nil,
+			wantErr:     true,
+			wantErrIs:   errdefs.ErrNotFound,
+		},
+		{
+			name:        "signal error propagated",
+			isDelivered: false,
+			signalErr:   errTest,
+			wantErr:     true,
+			wantErrIs:   errTest,
+		},
+		{
+			// Kill returns nil when hcs.IsAlreadyStopped reports true for the error.
+			name:        "already stopped is treated as success",
+			isDelivered: false,
+			signalErr:   hcs.ErrProcessAlreadyStopped,
+			wantErr:     false,
+		},
+	}
+
+	for _, testCase := range tests {
+		t.Run(testCase.name, func(t *testing.T) {
+			t.Parallel()
+			mockCtrl, _, _, controller := newSetup(t)
+			mockProc := mocks.NewMockProcess(mockCtrl)
+
+			controller.state = StateRunning
+			controller.process = mockProc
+
+			mockProc.EXPECT().Signal(gomock.Any(), signalOpts).Return(testCase.isDelivered, testCase.signalErr)
+
+			err := controller.Kill(t.Context(), signalOpts)
+			if (err != nil) != testCase.wantErr {
+				t.Errorf("Kill() error = %v; wantErr = %v", err, testCase.wantErr)
+			}
+			if testCase.wantErrIs != nil && !errors.Is(err, testCase.wantErrIs) {
+				t.Errorf("Kill() = %v; want errors.Is(%v)", err, testCase.wantErrIs)
+			}
+		})
+	}
+}
+
+// TestKill_RunningState_Terminate verifies all terminate-delivery outcomes when
+// Kill is called with nil signal options on a running process.
+func TestKill_RunningState_Terminate(t *testing.T) {
+	t.Parallel()
+	tests := []struct {
+		name        string
+		isDelivered bool
+		killErr     error
+		wantErr     bool
+		wantErrIs   error
+	}{
+		{
+			name:        "terminate delivered",
+			isDelivered: true,
+			killErr:     nil,
+			wantErr:     false,
+		},
+		{
+			name:        "terminate not delivered",
+			isDelivered: false,
+			killErr:     nil,
+			wantErr:     true,
+			wantErrIs:   errdefs.ErrNotFound,
+		},
+		{
+			name:        "terminate error propagated",
+			isDelivered: false,
+			killErr:     errTest,
+			wantErr:     true,
+			wantErrIs:   errTest,
+		},
+		{
+			// Kill returns nil when hcs.IsAlreadyStopped reports true for the error.
+			name:        "already stopped is treated as success",
+			isDelivered: false,
+			killErr:     hcs.ErrVmcomputeAlreadyStopped,
+			wantErr:     false,
+		},
+	}
+
+	for _, testCase := range tests {
+		t.Run(testCase.name, func(t *testing.T) {
+			t.Parallel()
+			mockCtrl, _, _, controller := newSetup(t)
+			mockProc := mocks.NewMockProcess(mockCtrl)
+
+			controller.state = StateRunning
+			controller.process = mockProc
+
+			// nil signalOptions triggers the legacy Kill (terminate) path.
+			mockProc.EXPECT().Kill(gomock.Any()).Return(testCase.isDelivered, testCase.killErr)
+
+			err := controller.Kill(t.Context(), nil)
+			if (err != nil) != testCase.wantErr {
+				t.Errorf("Kill() error = %v; wantErr = %v", err, testCase.wantErr)
+			}
+			if testCase.wantErrIs != nil && !errors.Is(err, testCase.wantErrIs) {
+				t.Errorf("Kill() = %v; want errors.Is(%v)", err, testCase.wantErrIs)
+			}
+		})
+	}
+}
+
+// TestResizeConsole_WrongState verifies that ResizeConsole fails with
+// ErrFailedPrecondition when the process is not in StateRunning.
+func TestResizeConsole_WrongState(t *testing.T) {
+	t.Parallel()
+	invalidStates := []State{StateNotCreated, StateCreated, StateTerminated}
+
+	for _, state := range invalidStates {
+		t.Run(state.String(), func(t *testing.T) {
+			t.Parallel()
+			// upstreamIO and process are left unset: ResizeConsole checks the
+			// state first and must return before touching either.
+			_, _, _, controller := newSetup(t)
+			controller.state = state
+
+			err := controller.ResizeConsole(t.Context(), 80, 24)
+			if !errors.Is(err, errdefs.ErrFailedPrecondition) {
+				t.Errorf("ResizeConsole() = %v; want ErrFailedPrecondition", err)
+			}
+		})
+	}
+}
+
+// TestResizeConsole_NonTerminal verifies that ResizeConsole fails when the
+// process was not started with a pseudo-TTY.
+func TestResizeConsole_NonTerminal(t *testing.T) {
+	t.Parallel()
+	_, _, mockIO, controller := newSetup(t)
+	controller.upstreamIO = mockIO
+	controller.state = StateRunning
+
+	// Terminal() == false is what ResizeConsole treats as "not a tty".
+	mockIO.EXPECT().Terminal().Return(false)
+
+	err := controller.ResizeConsole(t.Context(), 80, 24)
+	if !errors.Is(err, errdefs.ErrFailedPrecondition) {
+		t.Errorf("ResizeConsole(non-TTY) = %v; want ErrFailedPrecondition", err)
+	}
+}
+
+// TestResizeConsole_Result verifies the happy path and error propagation for
+// ResizeConsole when the process is running with a pseudo-TTY.
+func TestResizeConsole_Result(t *testing.T) {
+	t.Parallel()
+	tests := []struct {
+		name      string
+		width     uint32
+		height    uint32
+		resizeErr error
+		wantErr   bool
+	}{
+		{
+			name:      "succeeds",
+			width:     80,
+			height:    24,
+			resizeErr: nil,
+			wantErr:   false,
+		},
+		{
+			name:      "error propagated",
+			width:     100,
+			height:    50,
+			resizeErr: errTest,
+			wantErr:   true,
+		},
+	}
+
+	for _, testCase := range tests {
+		t.Run(testCase.name, func(t *testing.T) {
+			t.Parallel()
+			mockCtrl, _, mockIO, controller := newSetup(t)
+			mockProc := mocks.NewMockProcess(mockCtrl)
+			controller.upstreamIO = mockIO
+			controller.state = StateRunning
+			controller.process = mockProc
+
+			mockIO.EXPECT().Terminal().Return(true)
+			// The controller narrows width and height to uint16 before calling
+			// the process, so the expectation is expressed in uint16 as well.
+			mockProc.EXPECT().ResizeConsole(gomock.Any(), uint16(testCase.width), uint16(testCase.height)).Return(testCase.resizeErr)
+
+			err := controller.ResizeConsole(t.Context(), testCase.width, testCase.height)
+			if (err != nil) != testCase.wantErr {
+				t.Errorf("ResizeConsole() error = %v; wantErr = %v", err, testCase.wantErr)
+			}
+			if testCase.resizeErr != nil && !errors.Is(err, testCase.resizeErr) {
+				t.Errorf("ResizeConsole() = %v; want errors.Is(%v)", err, testCase.resizeErr)
+			}
+		})
+	}
+}
+
+// TestCloseIO verifies that CloseIO is a no-op when the process has not been
+// created (upstreamIO is nil) and forwards to upstreamIO.CloseStdin for all
+// other states.
+func TestCloseIO(t *testing.T) {
+	t.Parallel()
+	tests := []struct {
+		name          string
+		state         State
+		hasUpstreamIO bool
+	}{
+		{"NotCreated", StateNotCreated, false},
+		{"Created", StateCreated, true},
+		{"Running", StateRunning, true},
+		{"Terminated", StateTerminated, true},
+	}
+
+	for _, testCase := range tests {
+		t.Run(testCase.name, func(t *testing.T) {
+			t.Parallel()
+			_, _, mockIO, controller := newSetup(t)
+			controller.state = testCase.state
+
+			// For NotCreated, upstreamIO stays nil and no expectation is
+			// registered: CloseIO must return before touching it.
+			if testCase.hasUpstreamIO {
+				controller.upstreamIO = mockIO
+				mockIO.EXPECT().CloseStdin(gomock.Any())
+			}
+
+			controller.CloseIO(t.Context())
+		})
+	}
+}
+
+// TestWait_ProcessExit verifies that Wait returns promptly once exitedCh is closed.
+func TestWait_ProcessExit(t *testing.T) {
+	t.Parallel()
+	_, _, _, controller := newSetup(t)
+	// Simulate a process that has already exited.
+	close(controller.exitedCh)
+
+	// Run Wait in a goroutine so a hang is reported via the timeout below
+	// instead of blocking the test.
+	finished := make(chan struct{})
+	go func() {
+		controller.Wait(t.Context())
+		close(finished)
+	}()
+
+	select {
+	case <-finished:
+		// success
+	case <-time.After(time.Second):
+		t.Fatal("Wait() did not return after process exited")
+	}
+}
+
+// TestWait_ContextCancelled verifies that Wait returns when the context is
+// cancelled even if exitedCh is never closed.
+func TestWait_ContextCancelled(t *testing.T) {
+	t.Parallel()
+	_, _, _, controller := newSetup(t)
+	// Do not close exitedCh; Wait must return via context cancellation.
+
+	cancelCtx, cancel := context.WithCancel(t.Context())
+
+	finished := make(chan struct{})
+	go func() {
+		controller.Wait(cancelCtx)
+		close(finished)
+	}()
+
+	cancel()
+
+	select {
+	case <-finished:
+		// success
+	case <-time.After(time.Second):
+		t.Fatal("Wait() did not return after context was cancelled")
+	}
+}
+
+// TestDelete_WrongState verifies that Delete rejects calls made outside of valid states.
+func TestDelete_WrongState(t *testing.T) {
+	t.Parallel()
+	// Delete accepts StateCreated and StateTerminated; the remaining two
+	// states must be rejected.
+	invalidStates := []State{StateNotCreated, StateRunning}
+
+	for _, state := range invalidStates {
+		t.Run(state.String(), func(t *testing.T) {
+			t.Parallel()
+			_, _, _, controller := newSetup(t)
+			controller.state = state
+
+			err := controller.Delete(t.Context())
+			if !errors.Is(err, errdefs.ErrFailedPrecondition) {
+				t.Errorf("Delete() = %v; want ErrFailedPrecondition", err)
+			}
+		})
+	}
+}
+
+// TestDelete_CreatedState_Succeeds verifies that Delete on a created-but-never-started
+// process transitions to StateTerminated with exit code 0, releases upstreamIO,
+// and closes exitedCh.
+func TestDelete_CreatedState_Succeeds(t *testing.T) {
+	t.Parallel()
+	_, _, mockIO, controller := newSetup(t)
+	controller.upstreamIO = mockIO
+	controller.state = StateCreated
+	// abortInternal closes the upstream IO as part of the teardown.
+	mockIO.EXPECT().Close(gomock.Any())
+
+	if err := controller.Delete(t.Context()); err != nil {
+		t.Fatalf("Delete() = %v; want nil", err)
+	}
+	if controller.state != StateTerminated {
+		t.Errorf("state = %s; want StateTerminated", controller.state)
+	}
+	if controller.exitCode != 0 {
+		t.Errorf("exitCode = %d; want 0", controller.exitCode)
+	}
+	// exitedCh must be closed so any concurrent Wait calls unblock.
+	select {
+	case <-controller.exitedCh:
+		// success
+	default:
+		t.Error("exitedCh should be closed after Delete")
+	}
+}
+
+// TestDelete_TerminatedState_Succeeds verifies that Delete on an already-terminated
+// process is a no-op and returns nil.
+func TestDelete_TerminatedState_Succeeds(t *testing.T) {
+	t.Parallel()
+	_, _, _, controller := newSetup(t)
+	controller.state = StateTerminated
+
+	if err := controller.Delete(t.Context()); err != nil {
+		t.Fatalf("Delete() = %v; want nil", err)
+	}
+}
+
+// TestStatus_NotCreated verifies that Status returns UNKNOWN status when the
+// controller has not been created yet.
+func TestStatus_NotCreated(t *testing.T) {
+	t.Parallel()
+	_, _, _, controller := newSetup(t)
+
+	status := controller.Status(false)
+
+	if status.ID != testContainerID {
+		t.Errorf("ID = %q; want %q", status.ID, testContainerID)
+	}
+	if status.ExecID != testExecID {
+		t.Errorf("ExecID = %q; want %q", status.ExecID, testExecID)
+	}
+	if status.Status != containerdtypes.Status_UNKNOWN {
+		t.Errorf("Status = %v; want UNKNOWN", status.Status)
+	}
+}
+
+// TestStatus_Created_Detailed verifies that Status in StateCreated with
+// isDetailed=true populates all IO path and bundle fields.
+func TestStatus_Created_Detailed(t *testing.T) {
+	t.Parallel()
+	_, _, mockIO, controller := newSetup(t)
+	controller.upstreamIO = mockIO
+	controller.state = StateCreated
+	controller.bundle = "/test/bundle"
+
+	// Status(true) reads each of these accessors from upstreamIO.
+	mockIO.EXPECT().StdinPath().Return("/pipe/stdin")
+	mockIO.EXPECT().StdoutPath().Return("/pipe/stdout")
+	mockIO.EXPECT().StderrPath().Return("/pipe/stderr")
+	mockIO.EXPECT().Terminal().Return(false)
+
+	status := controller.Status(true)
+
+	if status.Status != containerdtypes.Status_CREATED {
+		t.Errorf("Status = %v; want CREATED", status.Status)
+	}
+	if status.Bundle != "/test/bundle" {
+		t.Errorf("Bundle = %q; want /test/bundle", status.Bundle)
+	}
+	if status.Stdin != "/pipe/stdin" {
+		t.Errorf("Stdin = %q; want /pipe/stdin", status.Stdin)
+	}
+	if status.Stdout != "/pipe/stdout" {
+		t.Errorf("Stdout = %q; want /pipe/stdout", status.Stdout)
+	}
+	if status.Stderr != "/pipe/stderr" {
+		t.Errorf("Stderr = %q; want /pipe/stderr", status.Stderr)
+	}
+	if status.Terminal {
+		t.Error("Terminal = true; want false")
+	}
+}
+
+// TestStatus_Running verifies that Status reflects RUNNING state and the stored PID.
+// Status(false) does not access upstreamIO so no IO mock expectations are needed.
+func TestStatus_Running(t *testing.T) {
+	t.Parallel()
+	_, _, _, controller := newSetup(t)
+	controller.state = StateRunning
+	controller.processID = testPID
+
+	status := controller.Status(false)
+
+	if status.Status != containerdtypes.Status_RUNNING {
+		t.Errorf("Status = %v; want RUNNING", status.Status)
+	}
+	if int(status.Pid) != testPID {
+		t.Errorf("Pid = %d; want %d", status.Pid, testPID)
+	}
+}
+
+// TestStatus_Terminated_Detailed verifies that Status in StateTerminated with
+// isDetailed=true includes the exit code and IO paths.
+func TestStatus_Terminated_Detailed(t *testing.T) {
+	t.Parallel()
+	const wantExitCode = uint32(1)
+
+	_, _, mockIO, controller := newSetup(t)
+	controller.upstreamIO = mockIO
+	controller.state = StateTerminated
+	controller.exitCode = wantExitCode
+
+	// Status(true) still reads the IO accessors in StateTerminated.
+	mockIO.EXPECT().StdinPath().Return("")
+	mockIO.EXPECT().StdoutPath().Return("")
+	mockIO.EXPECT().StderrPath().Return("")
+	mockIO.EXPECT().Terminal().Return(false)
+
+	status := controller.Status(true)
+
+	if status.Status != containerdtypes.Status_STOPPED {
+		t.Errorf("Status = %v; want STOPPED", status.Status)
+	}
+	if status.ExitStatus != wantExitCode {
+		t.Errorf("ExitStatus = %d; want %d", status.ExitStatus, wantExitCode)
+	}
+}
+
+// TestState_Returns verifies that the State accessor reflects every lifecycle state.
+func TestState_Returns(t *testing.T) {
+	t.Parallel()
+	_, _, _, controller := newSetup(t)
+
+	for _, wantState := range []State{StateNotCreated, StateCreated, StateRunning, StateTerminated} {
+		controller.state = wantState
+		if got := controller.State(); got != wantState {
+			t.Errorf("State() = %v; want %v", got, wantState)
+		}
+	}
+}
+
+// TestPid_Returns verifies that Pid returns the stored OS-level process ID.
+// The field is set directly here; in normal use it is recorded by Start.
+func TestPid_Returns(t *testing.T) {
+	t.Parallel()
+	_, _, _, controller := newSetup(t)
+	controller.processID = testPID
+
+	if got := controller.Pid(); got != testPID {
+		t.Errorf("Pid() = %d; want %d", got, testPID)
+	}
+}
diff --git a/internal/controller/process/state.go b/internal/controller/process/state.go
new file mode 100644
index 0000000000..3e991b716c
--- /dev/null
+++ b/internal/controller/process/state.go
@@ -0,0 +1,73 @@
+//go:build windows && (lcow || wcow)
+
+package process
+
+import (
+	containerdtypes "github.com/containerd/containerd/api/types/task"
+)
+
+// State represents the current state of the process lifecycle.
+//
+// The normal progression is:
+//
+//	StateNotCreated → StateCreated → StateRunning → StateTerminated
+//
+// Full state-transition table:
+//
+//	Current State    │ Trigger                              │ Next State
+//	─────────────────┼──────────────────────────────────────┼────────────────
+//	StateNotCreated  │ Create succeeds                      │ StateCreated
+//	StateCreated     │ Start succeeds                       │ StateRunning
+//	StateCreated     │ Start fails / Kill / Delete          │ StateTerminated
+//	StateRunning     │ process exits                        │ StateTerminated
+//	StateRunning     │ Kill (signal or terminate) delivered │ StateRunning until the process exits
+//	StateTerminated  │ (terminal — no further transitions)  │ —
+type State int32
+
+const (
+	// StateNotCreated indicates the process has not been created yet.
+	// This is the initial state set by [New].
+	StateNotCreated State = iota
+
+	// StateCreated indicates the process has been created but not started.
+	// IO connections are established and the process spec is stored.
+	StateCreated
+
+	// StateRunning indicates the process has been started and is executing.
+	StateRunning
+
+	// StateTerminated indicates the process has exited, or was aborted before
+	// it started, and all cleanup is done.
+	// This is a terminal state — no further transitions are possible.
+	StateTerminated
+)
+
+// String returns a human-readable representation of the State.
+// Values outside the declared constants are rendered as "Unknown".
+func (s State) String() string { + switch s { + case StateNotCreated: + return "NotCreated" + case StateCreated: + return "Created" + case StateRunning: + return "Running" + case StateTerminated: + return "Terminated" + default: + return "Unknown" + } +} + +// ContainerdStatus converts the State into the equivalent containerd task Status. +func (s State) ContainerdStatus() containerdtypes.Status { + switch s { + case StateCreated: + return containerdtypes.Status_CREATED + case StateRunning: + return containerdtypes.Status_RUNNING + case StateTerminated: + return containerdtypes.Status_STOPPED + default: + // StateNotCreated has no direct containerd equivalent. + return containerdtypes.Status_UNKNOWN + } +} diff --git a/internal/controller/process/types.go b/internal/controller/process/types.go new file mode 100644 index 0000000000..c95da266c8 --- /dev/null +++ b/internal/controller/process/types.go @@ -0,0 +1,31 @@ +//go:build windows && (lcow || wcow) + +package process + +import ( + "github.com/opencontainers/runtime-spec/specs-go" +) + +// CreateOptions holds the parameters for creating a new process. +type CreateOptions struct { + // Bundle is the path to the OCI bundle directory. + Bundle string + + // Spec is the OCI process spec for exec processes. For init processes + // the spec is passed as part of the container config and this is nil. + Spec *specs.Process + + // Terminal indicates whether the process should allocate a pseudo-TTY. + // When true, Stderr must be empty because the PTY multiplexes both + // stdout and stderr onto a single stream. + Terminal bool + + // Stdin is the named-pipe path for the process's standard input. + Stdin string + + // Stdout is the named-pipe path for the process's standard output. + Stdout string + + // Stderr is the named-pipe path for the process's standard error. 
+ Stderr string +} diff --git a/internal/guestpath/paths.go b/internal/guestpath/paths.go index 1852bc2454..d616efc646 100644 --- a/internal/guestpath/paths.go +++ b/internal/guestpath/paths.go @@ -4,6 +4,10 @@ const ( // LCOWRootPrefixInUVM is the path inside UVM where LCOW container's root // file system will be mounted LCOWRootPrefixInUVM = "/run/gcs/c" + // LCOWV2RootPrefixInVM is the path inside the UVM where LCOW container's root + // file system will be mounted. + // For V2 shims, this will be of format "/run/gcs/pods/<podID>/<containerID>". + LCOWV2RootPrefixInVM = "/run/gcs/pods" // WCOWRootPrefixInUVM is the path inside UVM where WCOW container's root // file system will be mounted WCOWRootPrefixInUVM = `C:\c` diff --git a/internal/logfields/fields.go b/internal/logfields/fields.go index d792c055bc..f36417d7fd 100644 --- a/internal/logfields/fields.go +++ b/internal/logfields/fields.go @@ -9,6 +9,7 @@ const ( ID = "id" ContainerID = "cid" + GCSContainerID = "gcs_container_id" ExecID = "eid" NamespaceID = "namespace-id" PodID = "pod-id"