virtcontainers/persist: update GetDriver to support rootless fs

GetDriver returns new PersistDriver according to current needs, a mock fs
driver is returned when mockTesting is enabled, a rootless fs is returned when
rootless is detected, otherwise a fs driver is used.

Signed-off-by: Julio Montes <julio.montes@intel.com>
This commit is contained in:
Julio Montes 2020-01-31 20:13:14 +00:00
parent dd2762fdad
commit 71f48a3364
2 changed files with 79 additions and 6 deletions

View File

@ -1,4 +1,5 @@
// Copyright (c) 2019 Huawei Corporation // Copyright (c) 2019 Huawei Corporation
// Copyright (c) 2020 Intel Corporation
// //
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
@ -9,12 +10,18 @@ import (
"fmt" "fmt"
exp "github.com/kata-containers/runtime/virtcontainers/experimental" exp "github.com/kata-containers/runtime/virtcontainers/experimental"
"github.com/kata-containers/runtime/virtcontainers/persist/api" persistapi "github.com/kata-containers/runtime/virtcontainers/persist/api"
"github.com/kata-containers/runtime/virtcontainers/persist/fs" "github.com/kata-containers/runtime/virtcontainers/persist/fs"
"github.com/kata-containers/runtime/virtcontainers/pkg/rootless"
) )
type initFunc (func() (persistapi.PersistDriver, error)) type initFunc (func() (persistapi.PersistDriver, error))
const (
RootFSName = "fs"
RootlessFSName = "rootlessfs"
)
var ( var (
// NewStoreFeature is an experimental feature // NewStoreFeature is an experimental feature
NewStoreFeature = exp.Feature{ NewStoreFeature = exp.Feature{
@ -25,16 +32,22 @@ var (
expErr error expErr error
supportedDrivers = map[string]initFunc{ supportedDrivers = map[string]initFunc{
"fs": fs.Init, RootFSName: fs.Init,
RootlessFSName: fs.RootlessInit,
} }
mockTesting = false
) )
func init() { func init() {
expErr = exp.Register(NewStoreFeature) expErr = exp.Register(NewStoreFeature)
} }
func EnableMockTesting() {
mockTesting = true
}
// GetDriver returns new PersistDriver according to driver name // GetDriver returns new PersistDriver according to driver name
func GetDriver(name string) (persistapi.PersistDriver, error) { func GetDriverByName(name string) (persistapi.PersistDriver, error) {
if expErr != nil { if expErr != nil {
return nil, expErr return nil, expErr
} }
@ -45,3 +58,28 @@ func GetDriver(name string) (persistapi.PersistDriver, error) {
return nil, fmt.Errorf("failed to get storage driver %q", name) return nil, fmt.Errorf("failed to get storage driver %q", name)
} }
// GetDriver returns new PersistDriver according to current needs.
// For example, a rootless FS driver is returned if the process is running
// as unprivileged process.
func GetDriver() (persistapi.PersistDriver, error) {
if expErr != nil {
return nil, expErr
}
if mockTesting {
return fs.MockFSInit()
}
if rootless.IsRootless() {
if f, ok := supportedDrivers[RootlessFSName]; ok {
return f()
}
}
if f, ok := supportedDrivers[RootFSName]; ok {
return f()
}
return nil, fmt.Errorf("Could not find a FS driver")
}

View File

@ -1,4 +1,5 @@
// Copyright (c) 2019 Huawei Corporation // Copyright (c) 2019 Huawei Corporation
// Copyright (c) 2020 Intel Corporation
// //
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
@ -6,17 +7,51 @@
package persist package persist
import ( import (
"os"
"testing" "testing"
persistapi "github.com/kata-containers/runtime/virtcontainers/persist/api"
"github.com/kata-containers/runtime/virtcontainers/persist/fs"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestGetDriver(t *testing.T) { func TestGetDriverByName(t *testing.T) {
nonexist, err := GetDriver("non-exist") nonexist, err := GetDriverByName("non-exist")
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Nil(t, nonexist) assert.Nil(t, nonexist)
fsDriver, err := GetDriver("fs") fsDriver, err := GetDriverByName("fs")
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, fsDriver) assert.NotNil(t, fsDriver)
} }
func TestGetDriver(t *testing.T) {
assert := assert.New(t)
orgMockTesting := mockTesting
defer func() {
mockTesting = orgMockTesting
}()
mockTesting = false
fsd, err := GetDriver()
assert.NoError(err)
var expectedFS persistapi.PersistDriver
if os.Getuid() != 0 {
expectedFS, err = fs.RootlessInit()
} else {
expectedFS, err = fs.Init()
}
assert.NoError(err)
assert.Equal(expectedFS, fsd)
// Testing mock driver
mockTesting = true
fsd, err = GetDriver()
assert.NoError(err)
expectedFS, err = fs.MockFSInit()
assert.NoError(err)
assert.Equal(expectedFS, fsd)
}