diff --git a/src/runtime/virtcontainers/persist/fs/mockfs.go b/src/runtime/virtcontainers/persist/fs/mockfs.go index 23e09c8c04..4c4f492693 100644 --- a/src/runtime/virtcontainers/persist/fs/mockfs.go +++ b/src/runtime/virtcontainers/persist/fs/mockfs.go @@ -13,11 +13,17 @@ import ( persistapi "github.com/kata-containers/kata-containers/src/runtime/virtcontainers/persist/api" ) +var mockTesting = false + type MockFS struct { // inherit from FS. Overwrite if needed. *FS } +func EnableMockTesting() { + mockTesting = true +} + func MockStorageRootPath() string { return filepath.Join(os.TempDir(), "vc", "mockfs") } @@ -46,3 +52,10 @@ func MockFSInit() (persistapi.PersistDriver, error) { return &MockFS{fsDriver}, nil } + +func MockAutoInit() (persistapi.PersistDriver, error) { + if mockTesting { + return MockFSInit() + } + return nil, nil +} diff --git a/src/runtime/virtcontainers/persist/fs/mockfs_test.go b/src/runtime/virtcontainers/persist/fs/mockfs_test.go new file mode 100644 index 0000000000..99709a78e6 --- /dev/null +++ b/src/runtime/virtcontainers/persist/fs/mockfs_test.go @@ -0,0 +1,34 @@ +// Copyright Red Hat. +// +// SPDX-License-Identifier: Apache-2.0 +// + +package fs + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMockAutoInit(t *testing.T) { + assert := assert.New(t) + orgMockTesting := mockTesting + defer func() { + mockTesting = orgMockTesting + }() + + mockTesting = false + + fsd, err := MockAutoInit() + assert.Nil(fsd) + assert.NoError(err) + + // Testing mock driver + mockTesting = true + fsd, err = MockAutoInit() + assert.NoError(err) + expectedFS, err := MockFSInit() + assert.NoError(err) + assert.Equal(expectedFS, fsd) +} diff --git a/src/runtime/virtcontainers/persist/manager.go b/src/runtime/virtcontainers/persist/manager.go index a104bdcabf..32504c932b 100644 --- a/src/runtime/virtcontainers/persist/manager.go +++ b/src/runtime/virtcontainers/persist/manager.go @@ -28,13 +28,8 @@ var ( RootFSName: fs.Init, RootlessFSName: fs.RootlessInit, } - mockTesting = false ) -func EnableMockTesting() { - mockTesting = true -} - // GetDriver returns new PersistDriver according to driver name func GetDriverByName(name string) (persistapi.PersistDriver, error) { if expErr != nil { @@ -56,8 +51,9 @@ func GetDriver() (persistapi.PersistDriver, error) { return nil, expErr } - if mockTesting { - return fs.MockFSInit() + mock, err := fs.MockAutoInit() + if mock != nil || err != nil { + return mock, err } if rootless.IsRootless() { diff --git a/src/runtime/virtcontainers/persist/manager_test.go b/src/runtime/virtcontainers/persist/manager_test.go index 074ca92665..4347f9adc2 100644 --- a/src/runtime/virtcontainers/persist/manager_test.go +++ b/src/runtime/virtcontainers/persist/manager_test.go @@ -27,12 +27,6 @@ func TestGetDriverByName(t *testing.T) { func TestGetDriver(t *testing.T) { assert := assert.New(t) - orgMockTesting := mockTesting - defer func() { - mockTesting = orgMockTesting - }() - - mockTesting = false fsd, err := GetDriver() assert.NoError(err) @@ -46,12 +40,4 @@ func TestGetDriver(t *testing.T) { assert.NoError(err) assert.Equal(expectedFS, fsd) - - // Testing mock driver - mockTesting = true - fsd, err = GetDriver() - assert.NoError(err) - expectedFS, err = fs.MockFSInit() - assert.NoError(err) - assert.Equal(expectedFS, fsd) } diff --git a/src/runtime/virtcontainers/virtcontainers_test.go b/src/runtime/virtcontainers/virtcontainers_test.go index cb03e2351b..6a3d7fa580 100644 --- a/src/runtime/virtcontainers/virtcontainers_test.go +++ b/src/runtime/virtcontainers/virtcontainers_test.go @@ -15,7 +15,6 @@ import ( "syscall" "testing" - "github.com/kata-containers/kata-containers/src/runtime/virtcontainers/persist" "github.com/kata-containers/kata-containers/src/runtime/virtcontainers/persist/fs" "github.com/kata-containers/kata-containers/src/runtime/virtcontainers/utils" "github.com/sirupsen/logrus" @@ -108,7 +107,7 @@ func setupClh() { func TestMain(m *testing.M) { var err error - persist.EnableMockTesting() + fs.EnableMockTesting() flag.Parse()