diff --git a/src/runtime/virtcontainers/acrn.go b/src/runtime/virtcontainers/acrn.go index c1b2532552..93359e16ad 100644 --- a/src/runtime/virtcontainers/acrn.go +++ b/src/runtime/virtcontainers/acrn.go @@ -479,7 +479,7 @@ func (a *Acrn) waitSandbox(ctx context.Context, timeoutSecs int) error { } // stopSandbox will stop the Sandbox's VM. -func (a *Acrn) stopSandbox(ctx context.Context) (err error) { +func (a *Acrn) stopSandbox(ctx context.Context, waitOnly bool) (err error) { span, _ := a.trace(ctx, "stopSandbox") defer span.End() @@ -511,36 +511,14 @@ func (a *Acrn) stopSandbox(ctx context.Context) (err error) { pid := a.state.PID - // Send signal to the VM process to try to stop it properly - if err = syscall.Kill(pid, syscall.SIGINT); err != nil { - if err == syscall.ESRCH { - return nil - } - a.Logger().Info("Sending signal to stop acrn VM failed") - return err + shutdownSignal := syscall.SIGINT + + if waitOnly { + // NOP + shutdownSignal = syscall.Signal(0) } - // Wait for the VM process to terminate - tInit := time.Now() - for { - if err = syscall.Kill(pid, syscall.Signal(0)); err != nil { - a.Logger().Info("acrn VM stopped after sending signal") - return nil - } - - if time.Since(tInit).Seconds() >= acrnStopSandboxTimeoutSecs { - a.Logger().Warnf("VM still running after waiting %ds", acrnStopSandboxTimeoutSecs) - break - } - - // Let's avoid to run a too busy loop - time.Sleep(time.Duration(50) * time.Millisecond) - } - - // Let's try with a hammer now, a SIGKILL should get rid of the - // VM process. 
- return syscall.Kill(pid, syscall.SIGKILL) - + return utils.WaitLocalProcess(pid, acrnStopSandboxTimeoutSecs, shutdownSignal, a.Logger()) } func (a *Acrn) updateBlockDevice(drive *config.BlockDrive) error { diff --git a/src/runtime/virtcontainers/clh.go b/src/runtime/virtcontainers/clh.go index 482f3c70c2..d0b9b97dd0 100644 --- a/src/runtime/virtcontainers/clh.go +++ b/src/runtime/virtcontainers/clh.go @@ -673,11 +673,11 @@ func (clh *cloudHypervisor) resumeSandbox(ctx context.Context) error { } // stopSandbox will stop the Sandbox's VM. -func (clh *cloudHypervisor) stopSandbox(ctx context.Context) (err error) { +func (clh *cloudHypervisor) stopSandbox(ctx context.Context, waitOnly bool) (err error) { span, _ := clh.trace(ctx, "stopSandbox") defer span.End() clh.Logger().WithField("function", "stopSandbox").Info("Stop Sandbox") - return clh.terminate(ctx) + return clh.terminate(ctx, waitOnly) } func (clh *cloudHypervisor) fromGrpc(ctx context.Context, hypervisorConfig *HypervisorConfig, j []byte) error { @@ -777,7 +777,7 @@ func (clh *cloudHypervisor) trace(parent context.Context, name string) (otelTrac return span, ctx } -func (clh *cloudHypervisor) terminate(ctx context.Context) (err error) { +func (clh *cloudHypervisor) terminate(ctx context.Context, waitOnly bool) (err error) { span, _ := clh.trace(ctx, "terminate") defer span.End() @@ -796,7 +796,7 @@ func (clh *cloudHypervisor) terminate(ctx context.Context) (err error) { clh.Logger().Debug("Stopping Cloud Hypervisor") - if pidRunning { + if pidRunning && !waitOnly { clhRunning, _ := clh.isClhRunning(clhStopSandboxTimeout) if clhRunning { ctx, cancel := context.WithTimeout(context.Background(), clhStopSandboxTimeout*time.Second) @@ -807,31 +807,8 @@ func (clh *cloudHypervisor) terminate(ctx context.Context) (err error) { } } - // At this point the VMM was stop nicely, but need to check if PID is still running - // Wait for the VM process to terminate - tInit := time.Now() - for { - if err = syscall.Kill(pid, 
syscall.Signal(0)); err != nil { - pidRunning = false - break - } - - if time.Since(tInit).Seconds() >= clhStopSandboxTimeout { - pidRunning = true - clh.Logger().Warnf("VM still running after waiting %ds", clhStopSandboxTimeout) - break - } - - // Let's avoid to run a too busy loop - time.Sleep(time.Duration(50) * time.Millisecond) - } - - // Let's try with a hammer now, a SIGKILL should get rid of the - // VM process. - if pidRunning { - if err = syscall.Kill(pid, syscall.SIGKILL); err != nil { - return fmt.Errorf("Fatal, failed to kill hypervisor process, error: %s", err) - } + if err = utils.WaitLocalProcess(pid, clhStopSandboxTimeout, syscall.Signal(0), clh.Logger()); err != nil { + return err } if clh.virtiofsd == nil { diff --git a/src/runtime/virtcontainers/fc.go b/src/runtime/virtcontainers/fc.go index a223d124e0..7dac0e78a0 100644 --- a/src/runtime/virtcontainers/fc.go +++ b/src/runtime/virtcontainers/fc.go @@ -409,7 +409,7 @@ func (fc *firecracker) fcInit(ctx context.Context, timeout int) error { return nil } -func (fc *firecracker) fcEnd(ctx context.Context) (err error) { +func (fc *firecracker) fcEnd(ctx context.Context, waitOnly bool) (err error) { span, _ := fc.trace(ctx, "fcEnd") defer span.End() @@ -425,33 +425,15 @@ func (fc *firecracker) fcEnd(ctx context.Context) (err error) { pid := fc.info.PID - // Send a SIGTERM to the VM process to try to stop it properly - if err = syscall.Kill(pid, syscall.SIGTERM); err != nil { - if err == syscall.ESRCH { - return nil - } - return err + shutdownSignal := syscall.SIGTERM + + if waitOnly { + // NOP + shutdownSignal = syscall.Signal(0) } // Wait for the VM process to terminate - tInit := time.Now() - for { - if err = syscall.Kill(pid, syscall.Signal(0)); err != nil { - return nil - } - - if time.Since(tInit).Seconds() >= fcStopSandboxTimeout { - fc.Logger().Warnf("VM still running after waiting %ds", fcStopSandboxTimeout) - break - } - - // Let's avoid to run a too busy loop - time.Sleep(time.Duration(50) * 
time.Millisecond) - } - - // Let's try with a hammer now, a SIGKILL should get rid of the - // VM process. - return syscall.Kill(pid, syscall.SIGKILL) + return utils.WaitLocalProcess(pid, fcStopSandboxTimeout, shutdownSignal, fc.Logger()) } func (fc *firecracker) client(ctx context.Context) *client.Firecracker { @@ -783,7 +765,7 @@ func (fc *firecracker) startSandbox(ctx context.Context, timeout int) error { var err error defer func() { if err != nil { - fc.fcEnd(ctx) + fc.fcEnd(ctx, false) } }() @@ -876,11 +858,11 @@ func (fc *firecracker) cleanupJail(ctx context.Context) { } // stopSandbox will stop the Sandbox's VM. -func (fc *firecracker) stopSandbox(ctx context.Context) (err error) { +func (fc *firecracker) stopSandbox(ctx context.Context, waitOnly bool) (err error) { span, _ := fc.trace(ctx, "stopSandbox") defer span.End() - return fc.fcEnd(ctx) + return fc.fcEnd(ctx, waitOnly) } func (fc *firecracker) pauseSandbox(ctx context.Context) error { diff --git a/src/runtime/virtcontainers/hypervisor.go b/src/runtime/virtcontainers/hypervisor.go index a7fb093cd9..dc6fedf063 100644 --- a/src/runtime/virtcontainers/hypervisor.go +++ b/src/runtime/virtcontainers/hypervisor.go @@ -787,7 +787,9 @@ func generateVMSocket(id string, vmStogarePath string) (interface{}, error) { type hypervisor interface { createSandbox(ctx context.Context, id string, networkNS NetworkNamespace, hypervisorConfig *HypervisorConfig) error startSandbox(ctx context.Context, timeout int) error - stopSandbox(ctx context.Context) error + // If waitOnly is set, don't actively stop the sandbox: + // just perform cleanup.
+ stopSandbox(ctx context.Context, waitOnly bool) error pauseSandbox(ctx context.Context) error saveSandbox() error resumeSandbox(ctx context.Context) error diff --git a/src/runtime/virtcontainers/mock_hypervisor.go b/src/runtime/virtcontainers/mock_hypervisor.go index fd2a6161b7..af19130f40 100644 --- a/src/runtime/virtcontainers/mock_hypervisor.go +++ b/src/runtime/virtcontainers/mock_hypervisor.go @@ -41,7 +41,7 @@ func (m *mockHypervisor) startSandbox(ctx context.Context, timeout int) error { return nil } -func (m *mockHypervisor) stopSandbox(ctx context.Context) error { +func (m *mockHypervisor) stopSandbox(ctx context.Context, waitOnly bool) error { return nil } diff --git a/src/runtime/virtcontainers/mock_hypervisor_test.go b/src/runtime/virtcontainers/mock_hypervisor_test.go index dece251652..c557f86580 100644 --- a/src/runtime/virtcontainers/mock_hypervisor_test.go +++ b/src/runtime/virtcontainers/mock_hypervisor_test.go @@ -53,7 +53,7 @@ func TestMockHypervisorStartSandbox(t *testing.T) { func TestMockHypervisorStopSandbox(t *testing.T) { var m *mockHypervisor - assert.NoError(t, m.stopSandbox(context.Background())) + assert.NoError(t, m.stopSandbox(context.Background(), false)) } func TestMockHypervisorAddDevice(t *testing.T) { diff --git a/src/runtime/virtcontainers/qemu.go b/src/runtime/virtcontainers/qemu.go index 46a6bf9b48..3c8ff0249f 100644 --- a/src/runtime/virtcontainers/qemu.go +++ b/src/runtime/virtcontainers/qemu.go @@ -122,6 +122,8 @@ const ( scsiControllerID = "scsi0" rngID = "rng0" fallbackFileBackedMemDir = "/dev/shm" + + qemuStopSandboxTimeoutSecs = 15 ) // agnostic list of kernel parameters @@ -703,7 +705,7 @@ func (q *qemu) setupVirtiofsd(ctx context.Context) (err error) { q.Logger().Info("virtiofsd quits") // Wait to release resources of virtiofsd process cmd.Process.Wait() - q.stopSandbox(ctx) + q.stopSandbox(ctx, false) }() return err } @@ -933,7 +935,7 @@ func (q *qemu) waitSandbox(ctx context.Context, timeout int) error { } // 
stopSandbox will stop the Sandbox's VM. -func (q *qemu) stopSandbox(ctx context.Context) error { +func (q *qemu) stopSandbox(ctx context.Context, waitOnly bool) error { span, _ := q.trace(ctx, "stopSandbox") defer span.End() @@ -965,10 +967,24 @@ func (q *qemu) stopSandbox(ctx context.Context) error { return err } - err := q.qmpMonitorCh.qmp.ExecuteQuit(q.qmpMonitorCh.ctx) - if err != nil { - q.Logger().WithError(err).Error("Fail to execute qmp QUIT") - return err + if waitOnly { + pids := q.getPids() + if len(pids) == 0 { + return errors.New("cannot determine QEMU PID") + } + + pid := pids[0] + + err := utils.WaitLocalProcess(pid, qemuStopSandboxTimeoutSecs, syscall.Signal(0), q.Logger()) + if err != nil { + return err + } + } else { + err := q.qmpMonitorCh.qmp.ExecuteQuit(q.qmpMonitorCh.ctx) + if err != nil { + q.Logger().WithError(err).Error("Fail to execute qmp QUIT") + return err + } } return nil diff --git a/src/runtime/virtcontainers/sandbox.go b/src/runtime/virtcontainers/sandbox.go index e758071fdb..a9370e0d58 100644 --- a/src/runtime/virtcontainers/sandbox.go +++ b/src/runtime/virtcontainers/sandbox.go @@ -1027,7 +1027,7 @@ func (s *Sandbox) startVM(ctx context.Context) (err error) { defer func() { if err != nil { - s.hypervisor.stopSandbox(ctx) + s.hypervisor.stopSandbox(ctx, false) } }() @@ -1081,14 +1081,9 @@ func (s *Sandbox) stopVM(ctx context.Context) error { s.Logger().WithError(err).WithField("sandboxid", s.id).Warning("Agent did not stop sandbox") } - if s.disableVMShutdown { - // Do not kill the VM - allow the agent to shut it down - // (only used to support static agent tracing). 
- return nil - } - s.Logger().Info("Stopping VM") - return s.hypervisor.stopSandbox(ctx) + + return s.hypervisor.stopSandbox(ctx, s.disableVMShutdown) } func (s *Sandbox) addContainer(c *Container) error { diff --git a/src/runtime/virtcontainers/utils/utils.go b/src/runtime/virtcontainers/utils/utils.go index d8bfc1fddd..e49f55a101 100644 --- a/src/runtime/virtcontainers/utils/utils.go +++ b/src/runtime/virtcontainers/utils/utils.go @@ -12,7 +12,10 @@ import ( "os" "os/exec" "path/filepath" + "syscall" + "time" + "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" pbTypes "github.com/kata-containers/kata-containers/src/runtime/virtcontainers/pkg/agent/protocols" @@ -312,3 +315,78 @@ func ConvertNetlinkFamily(netlinkFamily int32) pbTypes.IPFamily { return pbTypes.IPFamily_v4 } } + +// WaitLocalProcess waits for the specified process for up to timeoutSecs seconds. +// +// Notes: +// +// - If the initial signal is zero, the specified process is assumed to be +// attempting to stop itself. +// - If the initial signal is not zero, it will be sent to the process before +// checking if it is running. +// - If the process has not ended after the timeout value, it will be forcibly killed. 
+func WaitLocalProcess(pid int, timeoutSecs uint, initialSignal syscall.Signal, logger *logrus.Entry) error { + var err error + + // Don't support process groups + if pid <= 0 { + return errors.New("can only wait for a single process") + } + + if initialSignal != syscall.Signal(0) { + if err = syscall.Kill(pid, initialSignal); err != nil { + if err == syscall.ESRCH { + return nil + } + + return fmt.Errorf("Failed to send initial signal %v to process %v: %v", initialSignal, pid, err) + } + } + + pidRunning := true + + secs := time.Duration(timeoutSecs) + timeout := time.After(secs * time.Second) + + // Wait for the VM process to terminate +outer: + for { + select { + case <-time.After(50 * time.Millisecond): + // Check if the process is running periodically to avoid a busy loop + + var _status syscall.WaitStatus + var _rusage syscall.Rusage + var waitedPid int + + // "A watched pot never boils" and an unwaited-for process never appears to die! + waitedPid, err = syscall.Wait4(pid, &_status, syscall.WNOHANG, &_rusage) + + if waitedPid == pid && err == nil { + pidRunning = false + break outer + } + + if err = syscall.Kill(pid, syscall.Signal(0)); err != nil { + pidRunning = false + break outer + } + + break + + case <-timeout: + logger.Warnf("process %v still running after waiting %ds", pid, timeoutSecs) + + break outer + } + } + + if pidRunning { + // Force process to die + if err = syscall.Kill(pid, syscall.SIGKILL); err != nil { + return fmt.Errorf("Failed to stop process %v: %s", pid, err) + } + } + + return nil +} diff --git a/src/runtime/virtcontainers/utils/utils_test.go b/src/runtime/virtcontainers/utils/utils_test.go index bc296943ad..da95d73f94 100644 --- a/src/runtime/virtcontainers/utils/utils_test.go +++ b/src/runtime/virtcontainers/utils/utils_test.go @@ -9,14 +9,19 @@ import ( "fmt" "io/ioutil" "os" + "os/exec" "path/filepath" "reflect" "strings" + "syscall" "testing" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" ) +const 
waitLocalProcessTimeoutSecs = 3 + func TestFileCopySuccessful(t *testing.T) { assert := assert.New(t) fileContent := "testContent" @@ -375,3 +380,99 @@ func TestToBytes(t *testing.T) { expected := uint64(1073741824) assert.Equal(expected, result) } + +func TestWaitLocalProcessInvalidSignal(t *testing.T) { + assert := assert.New(t) + + const invalidSignal = syscall.Signal(999) + + cmd := exec.Command("sleep", "999") + err := cmd.Start() + assert.NoError(err) + + pid := cmd.Process.Pid + + logger := logrus.WithField("foo", "bar") + + err = WaitLocalProcess(pid, waitLocalProcessTimeoutSecs, invalidSignal, logger) + assert.Error(err) + + err = syscall.Kill(pid, syscall.SIGTERM) + assert.NoError(err) + + err = cmd.Wait() + + // This will error because we killed the process without the knowledge + // of exec.Command. + assert.Error(err) +} + +func TestWaitLocalProcessInvalidPid(t *testing.T) { + assert := assert.New(t) + + invalidPids := []int{-999, -173, -3, -2, -1, 0} + + logger := logrus.WithField("foo", "bar") + + for i, pid := range invalidPids { + msg := fmt.Sprintf("test[%d]: %v", i, pid) + + err := WaitLocalProcess(pid, waitLocalProcessTimeoutSecs, syscall.Signal(0), logger) + assert.Error(err, msg) + } +} + +func TestWaitLocalProcessBrief(t *testing.T) { + assert := assert.New(t) + + cmd := exec.Command("true") + err := cmd.Start() + assert.NoError(err) + + pid := cmd.Process.Pid + + logger := logrus.WithField("foo", "bar") + + err = WaitLocalProcess(pid, waitLocalProcessTimeoutSecs, syscall.SIGKILL, logger) + assert.NoError(err) + + _ = cmd.Wait() +} + +func TestWaitLocalProcessLongRunningPreKill(t *testing.T) { + assert := assert.New(t) + + cmd := exec.Command("sleep", "999") + err := cmd.Start() + assert.NoError(err) + + pid := cmd.Process.Pid + + logger := logrus.WithField("foo", "bar") + + err = WaitLocalProcess(pid, waitLocalProcessTimeoutSecs, syscall.SIGKILL, logger) + assert.NoError(err) + + _ = cmd.Wait() +} + +func TestWaitLocalProcessLongRunning(t 
*testing.T) { + assert := assert.New(t) + + cmd := exec.Command("sleep", "999") + err := cmd.Start() + assert.NoError(err) + + pid := cmd.Process.Pid + + logger := logrus.WithField("foo", "bar") + + // Don't wait for long as the process isn't actually trying to stop, + // so it will have to time out and then be killed. + const timeoutSecs = 1 + + err = WaitLocalProcess(pid, timeoutSecs, syscall.Signal(0), logger) + assert.NoError(err) + + _ = cmd.Wait() +} diff --git a/src/runtime/virtcontainers/vm.go b/src/runtime/virtcontainers/vm.go index c94a26fc5b..e6f02b6e07 100644 --- a/src/runtime/virtcontainers/vm.go +++ b/src/runtime/virtcontainers/vm.go @@ -137,7 +137,7 @@ func NewVM(ctx context.Context, config VMConfig) (*VM, error) { defer func() { if err != nil { virtLog.WithField("vm", id).WithError(err).Info("clean up vm") - hypervisor.stopSandbox(ctx) + hypervisor.stopSandbox(ctx, false) } }() @@ -251,7 +251,7 @@ func (v *VM) Disconnect(ctx context.Context) error { func (v *VM) Stop(ctx context.Context) error { v.logger().Info("stop vm") - if err := v.hypervisor.stopSandbox(ctx); err != nil { + if err := v.hypervisor.stopSandbox(ctx, false); err != nil { return err }