diff --git a/virtcontainers/container_test.go b/virtcontainers/container_test.go index b11e83aff3..5648b9a150 100644 --- a/virtcontainers/container_test.go +++ b/virtcontainers/container_test.go @@ -174,15 +174,15 @@ func TestUnmountHostMountsRemoveBindHostPath(t *testing.T) { ctx: context.Background(), } - if err := bindMount(c.ctx, src, hostPath, false); err != nil { + if err := bindMount(c.ctx, src, hostPath, false, "private"); err != nil { t.Fatal(err) } defer syscall.Unmount(hostPath, 0) - if err := bindMount(c.ctx, src, nonEmptyHostpath, false); err != nil { + if err := bindMount(c.ctx, src, nonEmptyHostpath, false, "private"); err != nil { t.Fatal(err) } defer syscall.Unmount(nonEmptyHostpath, 0) - if err := bindMount(c.ctx, src, devPath, false); err != nil { + if err := bindMount(c.ctx, src, devPath, false, "private"); err != nil { t.Fatal(err) } defer syscall.Unmount(devPath, 0) diff --git a/virtcontainers/mount_test.go b/virtcontainers/mount_test.go index 7fc36a9137..f4d94f15db 100644 --- a/virtcontainers/mount_test.go +++ b/virtcontainers/mount_test.go @@ -189,7 +189,7 @@ func TestGetDeviceForPathBindMount(t *testing.T) { defer os.Remove(dest) - err = bindMount(context.Background(), source, dest, false) + err = bindMount(context.Background(), source, dest, false, "private") assert.NoError(err) defer syscall.Unmount(dest, syscall.MNT_DETACH) @@ -283,6 +283,107 @@ func TestIsEphemeralStorage(t *testing.T) { assert.False(isHostEmptyDir) } +func TestBindMountInvalidSourceSymlink(t *testing.T) { + source := filepath.Join(testDir, "fooFile") + os.Remove(source) + + err := bindMount(context.Background(), source, "", false, "private") + assert.Error(t, err) +} + +func TestBindMountFailingMount(t *testing.T) { + source := filepath.Join(testDir, "fooLink") + fakeSource := filepath.Join(testDir, "fooFile") + os.Remove(source) + os.Remove(fakeSource) + assert := assert.New(t) + + _, err := os.OpenFile(fakeSource, os.O_CREATE, mountPerm) + assert.NoError(err) + + err = os.Symlink(fakeSource, source) + assert.NoError(err) + + err = bindMount(context.Background(), source, "", false, "private") + assert.Error(err) +} + +func TestBindMountSuccessful(t *testing.T) { + assert := assert.New(t) + if tc.NotValid(ktu.NeedRoot()) { + t.Skip(testDisabledAsNonRoot) + } + + source := filepath.Join(testDir, "fooDirSrc") + dest := filepath.Join(testDir, "fooDirDest") + syscall.Unmount(dest, 0) + os.Remove(source) + os.Remove(dest) + + err := os.MkdirAll(source, mountPerm) + assert.NoError(err) + + err = os.MkdirAll(dest, mountPerm) + assert.NoError(err) + + err = bindMount(context.Background(), source, dest, false, "private") + assert.NoError(err) + + syscall.Unmount(dest, 0) +} + +func TestBindMountReadonlySuccessful(t *testing.T) { + assert := assert.New(t) + if tc.NotValid(ktu.NeedRoot()) { + t.Skip(testDisabledAsNonRoot) + } + + source := filepath.Join(testDir, "fooDirSrc") + dest := filepath.Join(testDir, "fooDirDest") + syscall.Unmount(dest, 0) + os.Remove(source) + os.Remove(dest) + + err := os.MkdirAll(source, mountPerm) + assert.NoError(err) + + err = os.MkdirAll(dest, mountPerm) + assert.NoError(err) + + err = bindMount(context.Background(), source, dest, true, "private") + assert.NoError(err) + + defer syscall.Unmount(dest, 0) + + // should not be able to create file in read-only mount + destFile := filepath.Join(dest, "foo") + _, err = os.OpenFile(destFile, os.O_CREATE, mountPerm) + assert.Error(err) +} + +func TestBindMountInvalidPgtypes(t *testing.T) { + assert := assert.New(t) + if tc.NotValid(ktu.NeedRoot()) { + t.Skip(testDisabledAsNonRoot) + } + + source := filepath.Join(testDir, "fooDirSrc") + dest := filepath.Join(testDir, "fooDirDest") + syscall.Unmount(dest, 0) + os.Remove(source) + os.Remove(dest) + + err := os.MkdirAll(source, mountPerm) + assert.NoError(err) + + err = os.MkdirAll(dest, mountPerm) + assert.NoError(err) + + err = bindMount(context.Background(), source, dest, false, "foo") + expectedErr := fmt.Sprintf("Wrong propagation type %s", "foo") + assert.EqualError(err, expectedErr) +} + // TestBindUnmountContainerRootfsENOENTNotError tests that if a file // or directory attempting to be unmounted doesn't exist, then it // is not considered an error diff --git a/virtcontainers/syscall_test.go b/virtcontainers/syscall_test.go index 568e40a2b4..ed321a4de3 100644 --- a/virtcontainers/syscall_test.go +++ b/virtcontainers/syscall_test.go @@ -7,94 +7,13 @@ package virtcontainers import ( - "context" "os" "path/filepath" - "syscall" "testing" - ktu "github.com/kata-containers/runtime/pkg/katatestutils" "github.com/stretchr/testify/assert" ) -func TestBindMountInvalidSourceSymlink(t *testing.T) { - source := filepath.Join(testDir, "fooFile") - os.Remove(source) - - err := bindMount(context.Background(), source, "", false) - assert.Error(t, err) -} - -func TestBindMountFailingMount(t *testing.T) { - source := filepath.Join(testDir, "fooLink") - fakeSource := filepath.Join(testDir, "fooFile") - os.Remove(source) - os.Remove(fakeSource) - assert := assert.New(t) - - _, err := os.OpenFile(fakeSource, os.O_CREATE, mountPerm) - assert.NoError(err) - - err = os.Symlink(fakeSource, source) - assert.NoError(err) - - err = bindMount(context.Background(), source, "", false) - assert.Error(err) -} - -func TestBindMountSuccessful(t *testing.T) { - assert := assert.New(t) - if tc.NotValid(ktu.NeedRoot()) { - t.Skip(testDisabledAsNonRoot) - } - - source := filepath.Join(testDir, "fooDirSrc") - dest := filepath.Join(testDir, "fooDirDest") - syscall.Unmount(dest, 0) - os.Remove(source) - os.Remove(dest) - - err := os.MkdirAll(source, mountPerm) - assert.NoError(err) - - err = os.MkdirAll(dest, mountPerm) - assert.NoError(err) - - err = bindMount(context.Background(), source, dest, false) - assert.NoError(err) - - syscall.Unmount(dest, 0) -} - -func TestBindMountReadonlySuccessful(t *testing.T) { - assert := assert.New(t) - if tc.NotValid(ktu.NeedRoot()) { - t.Skip(testDisabledAsNonRoot) - } - - source := filepath.Join(testDir, "fooDirSrc") - dest := filepath.Join(testDir, "fooDirDest") - syscall.Unmount(dest, 0) - os.Remove(source) - os.Remove(dest) - - err := os.MkdirAll(source, mountPerm) - assert.NoError(err) - - err = os.MkdirAll(dest, mountPerm) - assert.NoError(err) - - err = bindMount(context.Background(), source, dest, true) - assert.NoError(err) - - defer syscall.Unmount(dest, 0) - - // should not be able to create file in read-only mount - destFile := filepath.Join(dest, "foo") - _, err = os.OpenFile(destFile, os.O_CREATE, mountPerm) - assert.Error(err) -} - func TestEnsureDestinationExistsNonExistingSource(t *testing.T) { err := ensureDestinationExists("", "") assert.Error(t, err) @@ -107,8 +26,9 @@ func TestEnsureDestinationExistsWrongParentDir(t *testing.T) { os.Remove(dest) assert := assert.New(t) - _, err := os.OpenFile(source, os.O_CREATE, mountPerm) + file, err := os.OpenFile(source, os.O_CREATE, mountPerm) assert.NoError(err) + defer file.Close() err = ensureDestinationExists(source, dest) assert.Error(err) @@ -123,20 +43,22 @@ func TestEnsureDestinationExistsSuccessfulSrcDir(t *testing.T) { err := os.MkdirAll(source, mountPerm) assert.NoError(err) + defer os.Remove(source) err = ensureDestinationExists(source, dest) assert.NoError(err) } func TestEnsureDestinationExistsSuccessfulSrcFile(t *testing.T) { - source := filepath.Join(testDir, "fooDirSrc") - dest := filepath.Join(testDir, "fooDirDest") + source := filepath.Join(testDir, "fooFileSrc") + dest := filepath.Join(testDir, "fooFileDest") os.Remove(source) os.Remove(dest) assert := assert.New(t) - _, err := os.OpenFile(source, os.O_CREATE, mountPerm) + file, err := os.OpenFile(source, os.O_CREATE, mountPerm) assert.NoError(err) + defer file.Close() err = ensureDestinationExists(source, dest) assert.NoError(err)