diff --git a/src/runtime/virtcontainers/acrn.go b/src/runtime/virtcontainers/acrn.go index c1b2532552..3f99350438 100644 --- a/src/runtime/virtcontainers/acrn.go +++ b/src/runtime/virtcontainers/acrn.go @@ -511,36 +511,9 @@ func (a *Acrn) stopSandbox(ctx context.Context) (err error) { pid := a.state.PID - // Send signal to the VM process to try to stop it properly - if err = syscall.Kill(pid, syscall.SIGINT); err != nil { - if err == syscall.ESRCH { - return nil - } - a.Logger().Info("Sending signal to stop acrn VM failed") - return err - } - - // Wait for the VM process to terminate - tInit := time.Now() - for { - if err = syscall.Kill(pid, syscall.Signal(0)); err != nil { - a.Logger().Info("acrn VM stopped after sending signal") - return nil - } - - if time.Since(tInit).Seconds() >= acrnStopSandboxTimeoutSecs { - a.Logger().Warnf("VM still running after waiting %ds", acrnStopSandboxTimeoutSecs) - break - } - - // Let's avoid to run a too busy loop - time.Sleep(time.Duration(50) * time.Millisecond) - } - - // Let's try with a hammer now, a SIGKILL should get rid of the - // VM process. 
- return syscall.Kill(pid, syscall.SIGKILL) + shutdownSignal := syscall.SIGINT + return utils.WaitLocalProcess(pid, acrnStopSandboxTimeoutSecs, shutdownSignal, a.Logger()) } func (a *Acrn) updateBlockDevice(drive *config.BlockDrive) error { diff --git a/src/runtime/virtcontainers/clh.go b/src/runtime/virtcontainers/clh.go index fcf346808f..2fc74c0760 100644 --- a/src/runtime/virtcontainers/clh.go +++ b/src/runtime/virtcontainers/clh.go @@ -786,31 +786,8 @@ func (clh *cloudHypervisor) terminate(ctx context.Context) (err error) { } } - // At this point the VMM was stop nicely, but need to check if PID is still running - // Wait for the VM process to terminate - tInit := time.Now() - for { - if err = syscall.Kill(pid, syscall.Signal(0)); err != nil { - pidRunning = false - break - } - - if time.Since(tInit).Seconds() >= clhStopSandboxTimeout { - pidRunning = true - clh.Logger().Warnf("VM still running after waiting %ds", clhStopSandboxTimeout) - break - } - - // Let's avoid to run a too busy loop - time.Sleep(time.Duration(50) * time.Millisecond) - } - - // Let's try with a hammer now, a SIGKILL should get rid of the - // VM process. 
- if pidRunning { - if err = syscall.Kill(pid, syscall.SIGKILL); err != nil { - return fmt.Errorf("Fatal, failed to kill hypervisor process, error: %s", err) - } + if err = utils.WaitLocalProcess(pid, clhStopSandboxTimeout, syscall.Signal(0), clh.Logger()); err != nil { + return err } if clh.virtiofsd == nil { diff --git a/src/runtime/virtcontainers/fc.go b/src/runtime/virtcontainers/fc.go index a223d124e0..c0644ed718 100644 --- a/src/runtime/virtcontainers/fc.go +++ b/src/runtime/virtcontainers/fc.go @@ -425,33 +425,10 @@ func (fc *firecracker) fcEnd(ctx context.Context) (err error) { pid := fc.info.PID - // Send a SIGTERM to the VM process to try to stop it properly - if err = syscall.Kill(pid, syscall.SIGTERM); err != nil { - if err == syscall.ESRCH { - return nil - } - return err - } + shutdownSignal := syscall.SIGTERM // Wait for the VM process to terminate - tInit := time.Now() - for { - if err = syscall.Kill(pid, syscall.Signal(0)); err != nil { - return nil - } - - if time.Since(tInit).Seconds() >= fcStopSandboxTimeout { - fc.Logger().Warnf("VM still running after waiting %ds", fcStopSandboxTimeout) - break - } - - // Let's avoid to run a too busy loop - time.Sleep(time.Duration(50) * time.Millisecond) - } - - // Let's try with a hammer now, a SIGKILL should get rid of the - // VM process. 
-	return syscall.Kill(pid, syscall.SIGKILL)
+	return utils.WaitLocalProcess(pid, fcStopSandboxTimeout, shutdownSignal, fc.Logger())
 }
 
 func (fc *firecracker) client(ctx context.Context) *client.Firecracker {
diff --git a/src/runtime/virtcontainers/utils/utils.go b/src/runtime/virtcontainers/utils/utils.go
index d8bfc1fddd..0936c03ad8 100644
--- a/src/runtime/virtcontainers/utils/utils.go
+++ b/src/runtime/virtcontainers/utils/utils.go
@@ -12,7 +12,10 @@ import (
 	"os"
 	"os/exec"
 	"path/filepath"
+	"syscall"
+	"time"
 
+	"github.com/sirupsen/logrus"
 	"github.com/vishvananda/netlink"
 
 	pbTypes "github.com/kata-containers/kata-containers/src/runtime/virtcontainers/pkg/agent/protocols"
@@ -312,3 +315,92 @@ func ConvertNetlinkFamily(netlinkFamily int32) pbTypes.IPFamily {
 		return pbTypes.IPFamily_v4
 	}
 }
+
+// WaitLocalProcess waits for the specified process for up to timeoutSecs seconds.
+//
+// Notes:
+//
+// - If the initial signal is zero, the specified process is assumed to be
+//   attempting to stop itself.
+// - If the initial signal is not zero, it will be sent to the process before
+//   checking if it is running.
+// - If the process has not ended after the timeout value, it will be forcibly killed.
+// - If the process is a child of the caller it is reaped here: an exited but
+//   unreaped child still accepts signal zero, so without reaping it would
+//   appear to be running until the timeout expired. A later wait on the same
+//   pid by the caller (such as exec.Cmd.Wait) will therefore report an error.
+//
+// An error is returned if pid is not a single positive process ID, if the
+// initial signal cannot be delivered for a reason other than the process
+// no longer existing, or if the process cannot be forcibly killed.
+func WaitLocalProcess(pid int, timeoutSecs uint, initialSignal syscall.Signal, logger *logrus.Entry) error {
+	var err error
+
+	// Don't support process groups
+	if pid <= 0 {
+		return errors.New("can only wait for a single process")
+	}
+
+	if initialSignal != syscall.Signal(0) {
+		if err = syscall.Kill(pid, initialSignal); err != nil {
+			if err == syscall.ESRCH {
+				return nil
+			}
+
+			return fmt.Errorf("Failed to send initial signal %v to process %v: %v", initialSignal, pid, err)
+		}
+	}
+
+	pidRunning := true
+
+	// Wait for the VM process to terminate
+	startTime := time.Now()
+
+	for {
+		// Reap the process if it is our child and has already exited.
+		// Wait4 errors (such as ECHILD for a non-child) are ignored
+		// on purpose: the signal probe below covers those cases.
+		if waitedPid, waitErr := syscall.Wait4(pid, nil, syscall.WNOHANG, nil); waitErr == nil && waitedPid == pid {
+			pidRunning = false
+			break
+		}
+
+		if err = syscall.Kill(pid, syscall.Signal(0)); err != nil {
+			pidRunning = false
+			break
+		}
+
+		if time.Since(startTime).Seconds() >= float64(timeoutSecs) {
+			pidRunning = true
+
+			logger.Warnf("process %v still running after waiting %ds", pid, timeoutSecs)
+
+			break
+		}
+
+		// Brief pause to avoid a busy loop
+		time.Sleep(time.Duration(50) * time.Millisecond)
+	}
+
+	if pidRunning {
+		// Force process to die
+		if err = syscall.Kill(pid, syscall.SIGKILL); err != nil {
+			if err == syscall.ESRCH {
+				// The process exited between the last probe and the kill.
+				return nil
+			}
+
+			return fmt.Errorf("Failed to stop process %v: %s", pid, err)
+		}
+
+		// Collect the exit status so a killed child does not linger
+		// as a zombie. For a non-child Wait4 fails immediately.
+		for {
+			if _, err = syscall.Wait4(pid, nil, 0, nil); err != syscall.EINTR {
+				break
+			}
+		}
+	}
+
+	return nil
+}
diff --git a/src/runtime/virtcontainers/utils/utils_test.go b/src/runtime/virtcontainers/utils/utils_test.go
index bc296943ad..da95d73f94 100644
--- a/src/runtime/virtcontainers/utils/utils_test.go
+++ b/src/runtime/virtcontainers/utils/utils_test.go
@@ -9,14 +9,19 @@ import (
 	"fmt"
 	"io/ioutil"
 	"os"
+	"os/exec"
 	"path/filepath"
 	"reflect"
 	"strings"
+	"syscall"
 	"testing"
 
+	"github.com/sirupsen/logrus"
 	"github.com/stretchr/testify/assert"
 )
 
+const waitLocalProcessTimeoutSecs = 3
+
 func TestFileCopySuccessful(t *testing.T) {
 	assert := assert.New(t)
 	fileContent := "testContent"
@@ -375,3 +380,99 @@ func TestToBytes(t *testing.T) {
 	expected := uint64(1073741824)
 	assert.Equal(expected, result)
 }
+
+func TestWaitLocalProcessInvalidSignal(t *testing.T) {
+	assert := assert.New(t)
+
+	const invalidSignal = syscall.Signal(999)
+
+	cmd := exec.Command("sleep", "999")
+	err := cmd.Start()
+	assert.NoError(err)
+
+	pid := cmd.Process.Pid
+
+	logger := logrus.WithField("foo", "bar")
+
+	err = WaitLocalProcess(pid, waitLocalProcessTimeoutSecs, invalidSignal, logger)
+	assert.Error(err)
+
+	err = syscall.Kill(pid, syscall.SIGTERM)
+	assert.NoError(err)
+
+	err = cmd.Wait()
+
+	// This will error because we killed the process without the knowledge
+	// of exec.Command.
+	assert.Error(err)
+}
+
+func TestWaitLocalProcessInvalidPid(t *testing.T) {
+	assert := assert.New(t)
+
+	invalidPids := []int{-999, -173, -3, -2, -1, 0}
+
+	logger := logrus.WithField("foo", "bar")
+
+	for i, pid := range invalidPids {
+		msg := fmt.Sprintf("test[%d]: %v", i, pid)
+
+		err := WaitLocalProcess(pid, waitLocalProcessTimeoutSecs, syscall.Signal(0), logger)
+		assert.Error(err, msg)
+	}
+}
+
+func TestWaitLocalProcessBrief(t *testing.T) {
+	assert := assert.New(t)
+
+	cmd := exec.Command("true")
+	err := cmd.Start()
+	assert.NoError(err)
+
+	pid := cmd.Process.Pid
+
+	logger := logrus.WithField("foo", "bar")
+
+	err = WaitLocalProcess(pid, waitLocalProcessTimeoutSecs, syscall.SIGKILL, logger)
+	assert.NoError(err)
+
+	_ = cmd.Wait()
+}
+
+func TestWaitLocalProcessLongRunningPreKill(t *testing.T) {
+	assert := assert.New(t)
+
+	cmd := exec.Command("sleep", "999")
+	err := cmd.Start()
+	assert.NoError(err)
+
+	pid := cmd.Process.Pid
+
+	logger := logrus.WithField("foo", "bar")
+
+	err = WaitLocalProcess(pid, waitLocalProcessTimeoutSecs, syscall.SIGKILL, logger)
+	assert.NoError(err)
+
+	_ = cmd.Wait()
+}
+
+func TestWaitLocalProcessLongRunning(t *testing.T) {
+	assert := assert.New(t)
+
+	cmd := exec.Command("sleep", "999")
+	err := cmd.Start()
+	assert.NoError(err)
+
+	pid := cmd.Process.Pid
+
+	logger := logrus.WithField("foo", "bar")
+
+	// Don't wait for long as the process isn't actually trying to stop,
+	// so it will have to timeout and then be killed.
+	const timeoutSecs = 1
+
+	err = WaitLocalProcess(pid, timeoutSecs, syscall.Signal(0), logger)
+	assert.NoError(err)
+
+	_ = cmd.Wait()
+}