diff --git a/virtcontainers/persist/manager.go b/virtcontainers/persist/manager.go index 7e46b36bce..228b835953 100644 --- a/virtcontainers/persist/manager.go +++ b/virtcontainers/persist/manager.go @@ -1,4 +1,5 @@ // Copyright (c) 2019 Huawei Corporation +// Copyright (c) 2020 Intel Corporation // // SPDX-License-Identifier: Apache-2.0 // @@ -9,12 +10,18 @@ import ( "fmt" exp "github.com/kata-containers/runtime/virtcontainers/experimental" - "github.com/kata-containers/runtime/virtcontainers/persist/api" + persistapi "github.com/kata-containers/runtime/virtcontainers/persist/api" "github.com/kata-containers/runtime/virtcontainers/persist/fs" + "github.com/kata-containers/runtime/virtcontainers/pkg/rootless" ) type initFunc (func() (persistapi.PersistDriver, error)) +const ( + RootFSName = "fs" + RootlessFSName = "rootlessfs" +) + var ( // NewStoreFeature is an experimental feature NewStoreFeature = exp.Feature{ @@ -25,16 +32,22 @@ var ( expErr error supportedDrivers = map[string]initFunc{ - "fs": fs.Init, + RootFSName: fs.Init, + RootlessFSName: fs.RootlessInit, } + mockTesting = false ) func init() { expErr = exp.Register(NewStoreFeature) } +func EnableMockTesting() { + mockTesting = true +} + // GetDriver returns new PersistDriver according to driver name -func GetDriver(name string) (persistapi.PersistDriver, error) { +func GetDriverByName(name string) (persistapi.PersistDriver, error) { if expErr != nil { return nil, expErr } @@ -45,3 +58,28 @@ func GetDriver(name string) (persistapi.PersistDriver, error) { return nil, fmt.Errorf("failed to get storage driver %q", name) } + +// GetDriver returns new PersistDriver according to current needs. +// For example, a rootless FS driver is returned if the process is running +// as unprivileged process. +func GetDriver() (persistapi.PersistDriver, error) { + if expErr != nil { + return nil, expErr + } + + if mockTesting { + return fs.MockFSInit() + } + + if rootless.IsRootless() { + if f, ok := supportedDrivers[RootlessFSName]; ok { + return f() + } + } + + if f, ok := supportedDrivers[RootFSName]; ok { + return f() + } + + return nil, fmt.Errorf("Could not find a FS driver") +} diff --git a/virtcontainers/persist/manager_test.go b/virtcontainers/persist/manager_test.go index f0d5b0383e..bd74b36657 100644 --- a/virtcontainers/persist/manager_test.go +++ b/virtcontainers/persist/manager_test.go @@ -1,4 +1,5 @@ // Copyright (c) 2019 Huawei Corporation +// Copyright (c) 2020 Intel Corporation // // SPDX-License-Identifier: Apache-2.0 // @@ -6,17 +7,51 @@ package persist import ( + "os" "testing" + persistapi "github.com/kata-containers/runtime/virtcontainers/persist/api" + "github.com/kata-containers/runtime/virtcontainers/persist/fs" "github.com/stretchr/testify/assert" ) -func TestGetDriver(t *testing.T) { - nonexist, err := GetDriver("non-exist") +func TestGetDriverByName(t *testing.T) { + nonexist, err := GetDriverByName("non-exist") assert.NotNil(t, err) assert.Nil(t, nonexist) - fsDriver, err := GetDriver("fs") + fsDriver, err := GetDriverByName("fs") assert.Nil(t, err) assert.NotNil(t, fsDriver) } + +func TestGetDriver(t *testing.T) { + assert := assert.New(t) + orgMockTesting := mockTesting + defer func() { + mockTesting = orgMockTesting + }() + + mockTesting = false + + fsd, err := GetDriver() + assert.NoError(err) + + var expectedFS persistapi.PersistDriver + if os.Getuid() != 0 { + expectedFS, err = fs.RootlessInit() + } else { + expectedFS, err = fs.Init() + } + + assert.NoError(err) + assert.Equal(expectedFS, fsd) + + // Testing mock driver + mockTesting = true + fsd, err = GetDriver() + assert.NoError(err) + expectedFS, err = fs.MockFSInit() + assert.NoError(err) + assert.Equal(expectedFS, fsd) +}