diff --git a/src/runtime/virtcontainers/kata_agent.go b/src/runtime/virtcontainers/kata_agent.go index 9c92747138..12756f0b11 100644 --- a/src/runtime/virtcontainers/kata_agent.go +++ b/src/runtime/virtcontainers/kata_agent.go @@ -383,7 +383,7 @@ func (k *kataAgent) internalConfigure(h hypervisor, id string, config KataAgentC return nil } -func setupSandboxBindMounts(sandbox *Sandbox) error { +func (k *kataAgent) setupSandboxBindMounts(sandbox *Sandbox) (err error) { if len(sandbox.config.SandboxBindMounts) == 0 { return nil } @@ -394,6 +394,20 @@ func setupSandboxBindMounts(sandbox *Sandbox) error { if err := os.MkdirAll(sandboxMountDir, DirMode); err != nil { return fmt.Errorf("Creating sandbox shared mount directory: %v: %w", sandboxMountDir, err) } + var mountedList []string + defer func() { + if err != nil { + for _, mnt := range mountedList { + if derr := syscall.Unmount(mnt, syscall.MNT_DETACH|UmountNoFollow); derr != nil { + k.Logger().WithError(derr).Errorf("cleanup: couldn't unmount %s", mnt) + } + } + if derr := os.RemoveAll(sandboxMountDir); derr != nil { + k.Logger().WithError(derr).Errorf("cleanup: failed to remove %s", sandboxMountDir) + } + + } + }() for _, m := range sandbox.config.SandboxBindMounts { mountDest := filepath.Join(sandboxMountDir, filepath.Base(m)) @@ -401,6 +415,7 @@ func setupSandboxBindMounts(sandbox *Sandbox) error { if err := bindMount(context.Background(), m, mountDest, true, "private"); err != nil { return fmt.Errorf("Mounting sandbox directory: %v to %v: %w", m, mountDest, err) } + mountedList = append(mountedList, mountDest) mountDest = filepath.Join(sandboxShareDir, filepath.Base(m)) if err := remountRo(context.Background(), mountDest); err != nil { @@ -412,19 +427,30 @@ func setupSandboxBindMounts(sandbox *Sandbox) error { return nil } -func cleanupSandboxBindMounts(sandbox *Sandbox) error { +func (k *kataAgent) cleanupSandboxBindMounts(sandbox *Sandbox) error { if sandbox.config == nil || len(sandbox.config.SandboxBindMounts) == 0 { return nil } + var retErr error + bindmountShareDir := filepath.Join(getMountPath(sandbox.id), sandboxMountsDir) for _, m := range sandbox.config.SandboxBindMounts { - mountPath := filepath.Join(getMountPath(sandbox.id), sandboxMountsDir, filepath.Base(m)) + mountPath := filepath.Join(bindmountShareDir, filepath.Base(m)) if err := syscall.Unmount(mountPath, syscall.MNT_DETACH|UmountNoFollow); err != nil { - return fmt.Errorf("Unmounting observe directory: %v: %w", mountPath, err) + if retErr == nil { + retErr = err + } + k.Logger().WithError(err).Errorf("Failed to unmount sandbox bindmount: %v", mountPath) } } + if err := os.RemoveAll(bindmountShareDir); err != nil { + if retErr == nil { + retErr = err + } + k.Logger().WithError(err).Errorf("Failed to remove sandbox bindmount directory: %s", bindmountShareDir) + } - return nil + return retErr } func (k *kataAgent) configure(ctx context.Context, h hypervisor, id, sharePath string, config KataAgentConfig) error { @@ -473,7 +499,7 @@ func (k *kataAgent) configureFromGrpc(h hypervisor, id string, config KataAgentC return k.internalConfigure(h, id, config) } -func (k *kataAgent) setupSharedPath(ctx context.Context, sandbox *Sandbox) error { +func (k *kataAgent) setupSharedPath(ctx context.Context, sandbox *Sandbox) (err error) { // create shared path structure sharePath := getSharePath(sandbox.id) mountPath := getMountPath(sandbox.id) @@ -488,9 +514,16 @@ func (k *kataAgent) setupSharedPath(ctx context.Context, sandbox *Sandbox) error if err := bindMount(ctx, mountPath, sharePath, true, "slave"); err != nil { return err } + defer func() { + if err != nil { + if umountErr := syscall.Unmount(sharePath, syscall.MNT_DETACH|UmountNoFollow); umountErr != nil { + k.Logger().WithError(umountErr).Errorf("failed to unmount vm share path %s", sharePath) + } + } + }() // Setup sandbox bindmounts, if specified: - if err := setupSandboxBindMounts(sandbox); err != nil { + if err = k.setupSandboxBindMounts(sandbox); err != nil { return err } @@ -2149,8 +2182,8 @@ func (k *kataAgent) markDead(ctx context.Context) { } func (k *kataAgent) cleanup(ctx context.Context, s *Sandbox) { - if err := cleanupSandboxBindMounts(s); err != nil { - k.Logger().WithError(err).Errorf("failed to cleanup observability logs bindmount") + if err := k.cleanupSandboxBindMounts(s); err != nil { + k.Logger().WithError(err).Errorf("failed to cleanup sandbox bindmounts") } // Unmount shared path diff --git a/src/runtime/virtcontainers/kata_agent_test.go b/src/runtime/virtcontainers/kata_agent_test.go index ef52299423..50235f99cd 100644 --- a/src/runtime/virtcontainers/kata_agent_test.go +++ b/src/runtime/virtcontainers/kata_agent_test.go @@ -1216,3 +1216,116 @@ func TestKataAgentDirs(t *testing.T) { assert.Equal(ephemeralPath(), defaultEphemeralPath) } } + +func TestSandboxBindMount(t *testing.T) { + if os.Getuid() != 0 { + t.Skip("Test disabled as requires root user") + } + + assert := assert.New(t) + // create temporary files to mount: + testMountPath, err := ioutil.TempDir("", "sandbox-test") + assert.NoError(err) + defer os.RemoveAll(testMountPath) + + // create a new shared directory for our test: + kataHostSharedDirSaved := kataHostSharedDir + testHostDir, err := ioutil.TempDir("", "kata-cleanup") + assert.NoError(err) + kataHostSharedDir = func() string { + return testHostDir + } + defer func() { + kataHostSharedDir = kataHostSharedDirSaved + }() + + m1Path := filepath.Join(testMountPath, "foo.txt") + f1, err := os.Create(m1Path) + assert.NoError(err) + defer f1.Close() + + m2Path := filepath.Join(testMountPath, "bar.txt") + f2, err := os.Create(m2Path) + assert.NoError(err) + defer f2.Close() + + // create sandbox for mounting into + sandbox := &Sandbox{ + ctx: context.Background(), + id: "foobar", + config: &SandboxConfig{ + SandboxBindMounts: []string{m1Path, m2Path}, + }, + } + k := &kataAgent{ctx: context.Background()} + + // make the shared directory for our test: + dir := kataHostSharedDir() + err = os.MkdirAll(path.Join(dir, sandbox.id), 0777) + assert.Nil(err) + defer os.RemoveAll(dir) + + sharePath := getSharePath(sandbox.id) + mountPath := getMountPath(sandbox.id) + + err = os.MkdirAll(sharePath, DirMode) + assert.Nil(err) + err = os.MkdirAll(mountPath, DirMode) + assert.Nil(err) + + // setup the expeted slave mount: + err = bindMount(sandbox.ctx, mountPath, sharePath, true, "slave") + assert.Nil(err) + defer syscall.Unmount(sharePath, syscall.MNT_DETACH|UmountNoFollow) + + // Test the function. We expect it to succeed and for the mount to exist + err = k.setupSandboxBindMounts(sandbox) + assert.NoError(err) + + // Test the cleanup function. We expect it to succeed for the mount to be removed. + err = k.cleanupSandboxBindMounts(sandbox) + assert.NoError(err) + + // After successful cleanup, verify there are not any mounts left behind. + stat := syscall.Stat_t{} + mount1CheckPath := filepath.Join(getMountPath(sandbox.id), sandboxMountsDir, filepath.Base(m1Path)) + err = syscall.Stat(mount1CheckPath, &stat) + assert.Error(err) + assert.True(os.IsNotExist(err)) + + mount2CheckPath := filepath.Join(getMountPath(sandbox.id), sandboxMountsDir, filepath.Base(m2Path)) + err = syscall.Stat(mount2CheckPath, &stat) + assert.Error(err) + assert.True(os.IsNotExist(err)) + + // Now, let's setup the cleanup to fail. Setup the sandbox bind mount twice, which will result in + // extra mounts being present that the sandbox description doesn't account for (ie, duplicate mounts). + // We expect cleanup to fail on the first time, since it cannot remove the sandbox-bindmount directory because + // there are leftover mounts. If we run it a second time, however, it should succeed since it'll remove the + // second set of mounts: + err = k.setupSandboxBindMounts(sandbox) + assert.NoError(err) + err = k.setupSandboxBindMounts(sandbox) + assert.NoError(err) + // Test the cleanup function. We expect it to succeed for the mount to be removed. + err = k.cleanupSandboxBindMounts(sandbox) + assert.Error(err) + err = k.cleanupSandboxBindMounts(sandbox) + assert.NoError(err) + + // + // Now, let's setup the sandbox bindmount to fail, and verify that no mounts are left behind + // + sandbox.config.SandboxBindMounts = append(sandbox.config.SandboxBindMounts, "oh-nos") + err = k.setupSandboxBindMounts(sandbox) + assert.Error(err) + // Verify there aren't any mounts left behind + stat = syscall.Stat_t{} + err = syscall.Stat(mount1CheckPath, &stat) + assert.Error(err) + assert.True(os.IsNotExist(err)) + err = syscall.Stat(mount2CheckPath, &stat) + assert.Error(err) + assert.True(os.IsNotExist(err)) + +}