diff --git a/virtcontainers/container.go b/virtcontainers/container.go index dca9df5f0f..d136be673e 100644 --- a/virtcontainers/container.go +++ b/virtcontainers/container.go @@ -24,6 +24,7 @@ import ( "github.com/kata-containers/runtime/virtcontainers/utils" specs "github.com/opencontainers/runtime-spec/specs-go" opentracing "github.com/opentracing/opentracing-go" + "github.com/pkg/errors" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" @@ -612,6 +613,20 @@ func (c *Container) unmountHostMounts() error { return err } + if m.Type == "bind" { + s, err := os.Stat(m.HostPath) + if err != nil { + return errors.Wrapf(err, "Could not stat host-path %v", m.HostPath) + } + // Remove the empty file or directory + if s.Mode().IsRegular() && s.Size() == 0 { + os.Remove(m.HostPath) + } + if s.Mode().IsDir() { + syscall.Rmdir(m.HostPath) + } + } + span.Finish() } } diff --git a/virtcontainers/container_test.go b/virtcontainers/container_test.go index ff54e009c7..995fccee67 100644 --- a/virtcontainers/container_test.go +++ b/virtcontainers/container_test.go @@ -133,6 +133,120 @@ func TestContainerRemoveDrive(t *testing.T) { assert.Nil(t, err, "remove drive should succeed") } +func TestUnmountHostMountsRemoveBindHostPath(t *testing.T) { + if tc.NotValid(ktu.NeedRoot()) { + t.Skip(testDisabledAsNonRoot) + } + + createFakeMountDir := func(t *testing.T, dir, prefix string) string { + name, err := ioutil.TempDir(dir, "test-mnt-"+prefix+"-") + if err != nil { + t.Fatal(err) + } + return name + } + + createFakeMountFile := func(t *testing.T, dir, prefix string) string { + f, err := ioutil.TempFile(dir, "test-mnt-"+prefix+"-") + if err != nil { + t.Fatal(err) + } + f.Close() + return f.Name() + } + + doUnmountCheck := func(src, dest, hostPath, nonEmptyHostpath, devPath string) { + mounts := []Mount{ + { + Source: src, + Destination: dest, + HostPath: hostPath, + Type: "bind", + }, + { + Source: src, + Destination: dest, + HostPath: nonEmptyHostpath, + Type: "bind", + }, + { + Source: src, + Destination: dest, + HostPath: devPath, + Type: "dev", + }, + } + + c := Container{ + mounts: mounts, + ctx: context.Background(), + } + + if err := bindMount(c.ctx, src, hostPath, false); err != nil { + t.Fatal(err) + } + defer syscall.Unmount(hostPath, 0) + if err := bindMount(c.ctx, src, nonEmptyHostpath, false); err != nil { + t.Fatal(err) + } + defer syscall.Unmount(nonEmptyHostpath, 0) + if err := bindMount(c.ctx, src, devPath, false); err != nil { + t.Fatal(err) + } + defer syscall.Unmount(devPath, 0) + + err := c.unmountHostMounts() + if err != nil { + t.Fatal(err) + } + + for _, path := range [3]string{src, dest, devPath} { + if _, err := os.Stat(path); err != nil { + if os.IsNotExist(err) { + t.Fatalf("path %s should not be removed", path) + } else { + t.Fatal(err) + } + } + } + + if _, err := os.Stat(hostPath); err == nil { + t.Fatal("empty host-path should be removed") + } else if !os.IsNotExist(err) { + t.Fatal(err) + } + + if _, err := os.Stat(nonEmptyHostpath); err != nil { + if os.IsNotExist(err) { + t.Fatal("non-empty host-path should not be removed") + } else { + t.Fatal(err) + } + } + } + + src := createFakeMountDir(t, testDir, "src") + dest := createFakeMountDir(t, testDir, "dest") + hostPath := createFakeMountDir(t, testDir, "host-path") + nonEmptyHostpath := createFakeMountDir(t, testDir, "non-empty-host-path") + devPath := createFakeMountDir(t, testDir, "dev-hostpath") + createFakeMountDir(t, nonEmptyHostpath, "nop") + doUnmountCheck(src, dest, hostPath, nonEmptyHostpath, devPath) + + src = createFakeMountFile(t, testDir, "src") + dest = createFakeMountFile(t, testDir, "dest") + hostPath = createFakeMountFile(t, testDir, "host-path") + nonEmptyHostpath = createFakeMountFile(t, testDir, "non-empty-host-path") + devPath = createFakeMountFile(t, testDir, "dev-host-path") + f, err := os.OpenFile(nonEmptyHostpath, os.O_WRONLY, os.FileMode(0640)) + if err != nil { + t.Fatal(err) + } + f.WriteString("nop\n") + f.Close() + doUnmountCheck(src, dest, hostPath, nonEmptyHostpath, devPath) +} + func testSetupFakeRootfs(t *testing.T) (testRawFile, loopDev, mntDir string, err error) { assert := assert.New(t) if tc.NotValid(ktu.NeedRoot()) {