diff --git a/virtcontainers/device/config/config.go b/virtcontainers/device/config/config.go index 57b4f0a7b9..c054f0df55 100644 --- a/virtcontainers/device/config/config.go +++ b/virtcontainers/device/config/config.go @@ -12,6 +12,7 @@ import ( "os" "path/filepath" "strconv" + "strings" "github.com/go-ini/ini" "golang.org/x/sys/unix" @@ -91,6 +92,8 @@ var SysBusPciDevicesPath = "/sys/bus/pci/devices" // SysBusPciSlotsPath is static string of /sys/bus/pci/slots var SysBusPciSlotsPath = "/sys/bus/pci/slots" +var getSysDevPath = getSysDevPathImpl + // DeviceInfo is an embedded type that contains device data common to all types of devices. type DeviceInfo struct { // Hostpath is device path on host @@ -257,29 +260,14 @@ func GetHostPath(devInfo DeviceInfo, vhostUserStoreEnabled bool, vhostUserStoreP return "", fmt.Errorf("Empty path provided for device") } - var pathComp string - - switch devInfo.DevType { - case "c", "u": - pathComp = "char" - case "b": - pathComp = "block" - default: - // Unsupported device types. Return nil error to ignore devices - // that cannot be handled currently. - return "", nil - } - // Filter out vhost-user storage devices by device Major numbers. if vhostUserStoreEnabled && devInfo.DevType == "b" && (devInfo.Major == VhostUserSCSIMajor || devInfo.Major == VhostUserBlkMajor) { return getVhostUserHostPath(devInfo, vhostUserStorePath) } - format := strconv.FormatInt(devInfo.Major, 10) + ":" + strconv.FormatInt(devInfo.Minor, 10) - sysDevPath := filepath.Join(SysDevPrefix, pathComp, format, "uevent") - - if _, err := os.Stat(sysDevPath); err != nil { + ueventPath := filepath.Join(getSysDevPath(devInfo), "uevent") + if _, err := os.Stat(ueventPath); err != nil { // Some devices(eg. /dev/fuse, /dev/cuse) do not always implement sysfs interface under /sys/dev // These devices are passed by default by docker. // @@ -293,7 +281,7 @@ func GetHostPath(devInfo DeviceInfo, vhostUserStoreEnabled bool, vhostUserStoreP return "", err } - content, err := ini.Load(sysDevPath) + content, err := ini.Load(ueventPath) if err != nil { return "", err } @@ -306,6 +294,35 @@ func GetHostPath(devInfo DeviceInfo, vhostUserStoreEnabled bool, vhostUserStoreP return filepath.Join("/dev", devName.String()), nil } +// getBackingFile is used to fetch the backing file for the device. +func getBackingFile(devInfo DeviceInfo) (string, error) { + backingFilePath := filepath.Join(getSysDevPath(devInfo), "loop", "backing_file") + data, err := ioutil.ReadFile(backingFilePath) + if err != nil { + return "", err + } + + return strings.TrimSpace(string(data)), nil +} + +func getSysDevPathImpl(devInfo DeviceInfo) string { + var pathComp string + + switch devInfo.DevType { + case "c", "u": + pathComp = "char" + case "b": + pathComp = "block" + default: + // Unsupported device types. Return nil error to ignore devices + // that cannot be handled currently. + return "" + } + + format := strconv.FormatInt(devInfo.Major, 10) + ":" + strconv.FormatInt(devInfo.Minor, 10) + return filepath.Join(SysDevPrefix, pathComp, format) +} + // getVhostUserHostPath is used to fetch host path for the vhost-user device. // For vhost-user block device like vhost-user-blk or vhost-user-scsi, its // socket should be under directory "/block/sockets/"; diff --git a/virtcontainers/device/config/config_test.go b/virtcontainers/device/config/config_test.go new file mode 100644 index 0000000000..698f52ecd9 --- /dev/null +++ b/virtcontainers/device/config/config_test.go @@ -0,0 +1,73 @@ +// Copyright (c) 2020 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 +// + +package config + +import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetBackingFile(t *testing.T) { + assert := assert.New(t) + + dir, err := ioutil.TempDir("", "backing") + assert.NoError(err) + defer os.RemoveAll(dir) + + orgGetSysDevPath := getSysDevPath + getSysDevPath = func(info DeviceInfo) string { + return dir + } + defer func() { getSysDevPath = orgGetSysDevPath }() + + info := DeviceInfo{} + path, err := getBackingFile(info) + assert.Error(err) + assert.Empty(path) + + loopDir := filepath.Join(dir, "loop") + err = os.Mkdir(loopDir, os.FileMode(0755)) + assert.NoError(err) + + backingFile := "/fake-img" + + err = ioutil.WriteFile(filepath.Join(loopDir, "backing_file"), []byte(backingFile), os.FileMode(0755)) + assert.NoError(err) + + path, err = getBackingFile(info) + assert.NoError(err) + assert.Equal(backingFile, path) +} + +func TestGetSysDevPathImpl(t *testing.T) { + assert := assert.New(t) + + info := DeviceInfo{ + DevType: "", + Major: 127, + Minor: 0, + } + + path := getSysDevPathImpl(info) + assert.Empty(path) + + expectedFormat := fmt.Sprintf("%d:%d", info.Major, info.Minor) + + info.DevType = "c" + path = getSysDevPathImpl(info) + assert.Contains(path, expectedFormat) + assert.Contains(path, "char") + + info.DevType = "b" + path = getSysDevPathImpl(info) + assert.Contains(path, expectedFormat) + assert.Contains(path, "block") +}