diff --git a/virtcontainers/container.go b/virtcontainers/container.go index cf795720e1..a54e493419 100644 --- a/virtcontainers/container.go +++ b/virtcontainers/container.go @@ -342,28 +342,6 @@ func (c *Container) setStateFstype(fstype string) error { return nil } -func (c *Container) setStateHotpluggedDrive(hotplugged bool) error { - c.state.HotpluggedDrive = hotplugged - - err := c.sandbox.storage.storeContainerResource(c.sandbox.id, c.id, stateFileType, c.state) - if err != nil { - return err - } - - return nil -} - -func (c *Container) setContainerRootfsPCIAddr(addr string) error { - c.state.RootfsPCIAddr = addr - - err := c.sandbox.storage.storeContainerResource(c.sandbox.id, c.id, stateFileType, c.state) - if err != nil { - return err - } - - return nil -} - // GetAnnotations returns container's annotations func (c *Container) GetAnnotations() map[string]string { return c.config.Annotations @@ -457,8 +435,7 @@ func (c *Container) mountSharedDirMounts(hostSharedDir, guestSharedDir string) ( // instead of passing this as a shared mount. if len(m.BlockDeviceID) > 0 { // Attach this block device, all other devices passed in the config have been attached at this point - if err := c.sandbox.devManager.AttachDevice(m.BlockDeviceID, c.sandbox); err != nil && - err != manager.ErrDeviceAttached { + if err := c.sandbox.devManager.AttachDevice(m.BlockDeviceID, c.sandbox); err != nil { return nil, err } @@ -1087,38 +1064,43 @@ func (c *Container) hotplugDrive() error { return err } + devicePath, err = filepath.EvalSymlinks(devicePath) + if err != nil { + return err + } + c.Logger().WithFields(logrus.Fields{ "device-path": devicePath, "fs-type": fsType, }).Info("Block device detected") - driveIndex, err := c.sandbox.getAndSetSandboxBlockIndex() - if err != nil { - return err + var stat unix.Stat_t + if err := unix.Stat(devicePath, &stat); err != nil { + return fmt.Errorf("stat %q failed: %v", devicePath, err) } - // TODO: use general device manager instead of BlockDrive directly - // Add drive with id as container id - devID := utils.MakeNameID("drive", c.id, maxDevIDSize) - drive := config.BlockDrive{ - File: devicePath, - Format: "raw", - ID: devID, - Index: driveIndex, - } + if c.checkBlockDeviceSupport() && stat.Mode&unix.S_IFBLK == unix.S_IFBLK { + b, err := c.sandbox.devManager.NewDevice(config.DeviceInfo{ + HostPath: devicePath, + ContainerPath: filepath.Join(kataGuestSharedDir, c.id), + DevType: "b", + Major: int64(unix.Major(stat.Rdev)), + Minor: int64(unix.Minor(stat.Rdev)), + }) + if err != nil { + return fmt.Errorf("device manager failed to create rootfs device for %q: %v", devicePath, err) + } - if _, err := c.sandbox.hypervisor.hotplugAddDevice(&drive, blockDev); err != nil { - return err - } + c.state.BlockDeviceID = b.DeviceID() - if drive.PCIAddr != "" { - c.setContainerRootfsPCIAddr(drive.PCIAddr) - } + // attach rootfs device + if err := c.sandbox.devManager.AttachDevice(b.DeviceID(), c.sandbox); err != nil { + return err + } - c.setStateHotpluggedDrive(true) - - if err := c.setStateBlockIndex(driveIndex); err != nil { - return err + if err := c.sandbox.storeSandboxDevices(); err != nil { + return err + } } return c.setStateFstype(fsType) @@ -1133,19 +1115,28 @@ func (c *Container) isDriveUsed() bool { } func (c *Container) removeDrive() (err error) { - if c.isDriveUsed() && c.state.HotpluggedDrive { + if c.isDriveUsed() { c.Logger().Info("unplugging block device") - devID := utils.MakeNameID("drive", c.id, maxDevIDSize) - drive := &config.BlockDrive{ - ID: devID, + devID := c.state.BlockDeviceID + err := c.sandbox.devManager.DetachDevice(devID, c.sandbox) + if err != nil && err != manager.ErrDeviceNotAttached { + return err } - l := c.Logger().WithField("device-id", devID) - l.Info("Unplugging block device") + if err = c.sandbox.devManager.RemoveDevice(devID); err != nil { + c.Logger().WithFields(logrus.Fields{ + "container": c.id, + "device-id": devID, + }).WithError(err).Error("remove device failed") - if _, err := c.sandbox.hypervisor.hotplugRemoveDevice(drive, blockDev); err != nil { - l.WithError(err).Info("Failed to unplug block device") + // ignore the device not exist error + if err != manager.ErrDeviceNotExist { + return err + } + } + + if err := c.sandbox.storeSandboxDevices(); err != nil { return err } } @@ -1159,10 +1150,6 @@ func (c *Container) attachDevices() error { // and rollbackFailingContainerCreation could do all the rollbacks for _, dev := range c.devices { if err := c.sandbox.devManager.AttachDevice(dev.ID, c.sandbox); err != nil { - if err == manager.ErrDeviceAttached { - // skip if device is already attached before - continue - } return err } } diff --git a/virtcontainers/container_test.go b/virtcontainers/container_test.go index 68bfd699e1..139a767ea4 100644 --- a/virtcontainers/container_test.go +++ b/virtcontainers/container_test.go @@ -6,6 +6,7 @@ package virtcontainers import ( + "context" "io/ioutil" "os" "os/exec" @@ -15,6 +16,10 @@ import ( "syscall" "testing" + "github.com/kata-containers/runtime/virtcontainers/device/api" + "github.com/kata-containers/runtime/virtcontainers/device/config" + "github.com/kata-containers/runtime/virtcontainers/device/drivers" + "github.com/kata-containers/runtime/virtcontainers/device/manager" vcAnnotations "github.com/kata-containers/runtime/virtcontainers/pkg/annotations" "github.com/stretchr/testify/assert" ) @@ -83,7 +88,11 @@ func TestContainerSandbox(t *testing.T) { } func TestContainerRemoveDrive(t *testing.T) { - sandbox := &Sandbox{} + sandbox := &Sandbox{ + id: "sandbox", + devManager: manager.NewDeviceManager(manager.VirtioSCSI, nil), + storage: &filesystem{}, + } container := Container{ sandbox: sandbox, @@ -95,26 +104,35 @@ func TestContainerRemoveDrive(t *testing.T) { // hotplugRemoveDevice for hypervisor should not be called. // test should pass without a hypervisor created for the container's sandbox. - if err != nil { - t.Fatal("") + assert.Nil(t, err, "remove drive should succeed") + + sandbox.hypervisor = &mockHypervisor{} + path := "/dev/hda" + deviceInfo := config.DeviceInfo{ + HostPath: path, + ContainerPath: path, + DevType: "b", } + devReceiver := &api.MockDeviceReceiver{} + + device, err := sandbox.devManager.NewDevice(deviceInfo) + assert.Nil(t, err) + _, ok := device.(*drivers.BlockDevice) + assert.True(t, ok) + err = device.Attach(devReceiver) + assert.Nil(t, err) + err = sandbox.storage.createAllResources(context.Background(), sandbox) + if err != nil { + t.Fatal(err) + } + + err = sandbox.storeSandboxDevices() + assert.Nil(t, err) container.state.Fstype = "xfs" - container.state.HotpluggedDrive = false + container.state.BlockDeviceID = device.DeviceID() err = container.removeDrive() - - // hotplugRemoveDevice for hypervisor should not be called. - if err != nil { - t.Fatal("") - } - - container.state.HotpluggedDrive = true - sandbox.hypervisor = &mockHypervisor{} - err = container.removeDrive() - - if err != nil { - t.Fatal() - } + assert.Nil(t, err, "remove drive should succeed") } func testSetupFakeRootfs(t *testing.T) (testRawFile, loopDev, mntDir string, err error) { @@ -197,6 +215,7 @@ func TestContainerAddDriveDir(t *testing.T) { fs := &filesystem{} sandbox := &Sandbox{ id: testSandboxID, + devManager: manager.NewDeviceManager(manager.VirtioSCSI, nil), storage: fs, hypervisor: &mockHypervisor{}, agent: &noopAgent{}, @@ -243,14 +262,13 @@ func TestContainerAddDriveDir(t *testing.T) { }() container.state.Fstype = "" - container.state.HotpluggedDrive = false err = container.hotplugDrive() if err != nil { t.Fatalf("Error with hotplugDrive :%v", err) } - if container.state.Fstype == "" || !container.state.HotpluggedDrive { + if container.state.Fstype == "" { t.Fatal() } } diff --git a/virtcontainers/device/api/interface.go b/virtcontainers/device/api/interface.go index 6bdba9b1bc..da15831aef 100644 --- a/virtcontainers/device/api/interface.go +++ b/virtcontainers/device/api/interface.go @@ -49,13 +49,20 @@ type Device interface { // DeviceType indicates which kind of device it is // e.g. block, vfio or vhost user DeviceType() config.DeviceType + // GetMajorMinor returns major and minor numbers + GetMajorMinor() (int64, int64) // GetDeviceInfo returns device specific data used for hotplugging by hypervisor // Caller could cast the return value to device specific struct // e.g. Block device returns *config.BlockDrive and // vfio device returns []*config.VFIODev GetDeviceInfo() interface{} - // IsAttached checks if the device is attached - IsAttached() bool + // GetAttachCount returns how many times the device has been attached + GetAttachCount() uint + + // Reference adds one reference to device then returns final ref count + Reference() uint + // Dereference removes one reference to device then returns final ref count + Dereference() uint } // DeviceManager can be used to create a new device, this can be used as single diff --git a/virtcontainers/device/config/config.go b/virtcontainers/device/config/config.go index d9c3dfe934..0bfb79b571 100644 --- a/virtcontainers/device/config/config.go +++ b/virtcontainers/device/config/config.go @@ -53,7 +53,7 @@ type DeviceInfo struct { HostPath string // ContainerPath is device path inside container - ContainerPath string + ContainerPath string `json:"-"` // Type of device: c, b, u or p // c , u - character(unbuffered) @@ -75,10 +75,6 @@ type DeviceInfo struct { // id of the device group. GID uint32 - // Hotplugged is used to store device state indicating if the - // device was hotplugged. - Hotplugged bool - // ID for the device that is passed to the hypervisor. ID string @@ -123,7 +119,7 @@ type VFIODev struct { // VhostUserDeviceAttrs represents data shared by most vhost-user devices type VhostUserDeviceAttrs struct { - ID string + DevID string SocketPath string Type DeviceType diff --git a/virtcontainers/device/drivers/block.go b/virtcontainers/device/drivers/block.go index 38e5a4b64b..cbbe7bb316 100644 --- a/virtcontainers/device/drivers/block.go +++ b/virtcontainers/device/drivers/block.go @@ -18,23 +18,28 @@ const maxDevIDSize = 31 // BlockDevice refers to a block storage device implementation. type BlockDevice struct { - ID string - DeviceInfo *config.DeviceInfo + *GenericDevice BlockDrive *config.BlockDrive } // NewBlockDevice creates a new block device based on DeviceInfo func NewBlockDevice(devInfo *config.DeviceInfo) *BlockDevice { return &BlockDevice{ - ID: devInfo.ID, - DeviceInfo: devInfo, + GenericDevice: &GenericDevice{ + ID: devInfo.ID, + DeviceInfo: devInfo, + }, } } // Attach is standard interface of api.Device, it's used to add device to some // DeviceReceiver func (device *BlockDevice) Attach(devReceiver api.DeviceReceiver) (err error) { - if device.DeviceInfo.Hotplugged { + skip, err := device.bumpAttachCount(true) + if err != nil { + return err + } + if skip { return nil } @@ -46,6 +51,8 @@ func (device *BlockDevice) Attach(devReceiver api.DeviceReceiver) (err error) { defer func() { if err != nil { devReceiver.DecrementSandboxBlockIndex() + } else { + device.AttachCount = 1 } }() @@ -83,15 +90,17 @@ func (device *BlockDevice) Attach(devReceiver api.DeviceReceiver) (err error) { return err } - device.DeviceInfo.Hotplugged = true - return nil } // Detach is standard interface of api.Device, it's used to remove device from some // DeviceReceiver func (device *BlockDevice) Detach(devReceiver api.DeviceReceiver) error { - if !device.DeviceInfo.Hotplugged { + skip, err := device.bumpAttachCount(false) + if err != nil { + return err + } + if skip { return nil } @@ -101,26 +110,19 @@ func (device *BlockDevice) Detach(devReceiver api.DeviceReceiver) error { deviceLogger().WithError(err).Error("Failed to unplug block device") return err } - device.DeviceInfo.Hotplugged = false + device.AttachCount = 0 return nil } -// IsAttached checks if the device is attached -func (device *BlockDevice) IsAttached() bool { - return device.DeviceInfo.Hotplugged -} - // DeviceType is standard interface of api.Device, it returns device type func (device *BlockDevice) DeviceType() config.DeviceType { return config.DeviceBlock } -// DeviceID returns device ID -func (device *BlockDevice) DeviceID() string { - return device.ID -} - // GetDeviceInfo returns device information used for creating func (device *BlockDevice) GetDeviceInfo() interface{} { return device.BlockDrive } + +// It should implement GetAttachCount() and DeviceID() as api.Device implementation +// here it shares function from *GenericDevice so we don't need duplicate codes diff --git a/virtcontainers/device/drivers/generic.go b/virtcontainers/device/drivers/generic.go index 431d178439..10835e3f16 100644 --- a/virtcontainers/device/drivers/generic.go +++ b/virtcontainers/device/drivers/generic.go @@ -7,6 +7,8 @@ package drivers import ( + "fmt" + "github.com/kata-containers/runtime/virtcontainers/device/api" "github.com/kata-containers/runtime/virtcontainers/device/config" ) @@ -15,6 +17,9 @@ import ( type GenericDevice struct { ID string DeviceInfo *config.DeviceInfo + + RefCount uint + AttachCount uint } // NewGenericDevice creates a new GenericDevice @@ -27,32 +32,30 @@ func NewGenericDevice(devInfo *config.DeviceInfo) *GenericDevice { // Attach is standard interface of api.Device func (device *GenericDevice) Attach(devReceiver api.DeviceReceiver) error { - if device.DeviceInfo.Hotplugged { + skip, err := device.bumpAttachCount(true) + if err != nil { + return err + } + if skip { return nil } - + device.AttachCount = 1 return nil } // Detach is standard interface of api.Device func (device *GenericDevice) Detach(devReceiver api.DeviceReceiver) error { - if !device.DeviceInfo.Hotplugged { + skip, err := device.bumpAttachCount(false) + if err != nil { + return err + } + if skip { return nil } - + device.AttachCount = 0 return nil } -// IsAttached checks if the device is attached -func (device *GenericDevice) IsAttached() bool { - return device.DeviceInfo.Hotplugged -} - -// DeviceID returns device ID -func (device *GenericDevice) DeviceID() string { - return device.ID -} - // DeviceType is standard interface of api.Device, it returns device type func (device *GenericDevice) DeviceType() config.DeviceType { return config.DeviceGeneric @@ -62,3 +65,65 @@ func (device *GenericDevice) DeviceType() config.DeviceType { func (device *GenericDevice) GetDeviceInfo() interface{} { return device.DeviceInfo } + +// GetAttachCount returns how many times the device has been attached +func (device *GenericDevice) GetAttachCount() uint { + return device.AttachCount +} + +// DeviceID returns device ID +func (device *GenericDevice) DeviceID() string { + return device.ID +} + +// GetMajorMinor returns device major and minor numbers +func (device *GenericDevice) GetMajorMinor() (int64, int64) { + return device.DeviceInfo.Major, device.DeviceInfo.Minor +} + +// Reference adds one reference to device +func (device *GenericDevice) Reference() uint { + if device.RefCount != intMax { + device.RefCount++ + } + return device.RefCount +} + +// Dereference remove one reference from device +func (device *GenericDevice) Dereference() uint { + if device.RefCount != 0 { + device.RefCount-- + } + return device.RefCount +} + +// bumpAttachCount is used to add/minus attach count for a device +// * attach bool: true means attach, false means detach +// return values: +// * skip bool: no need to do real attach/detach, skip following actions. +// * err error: error while do attach count bump +func (device *GenericDevice) bumpAttachCount(attach bool) (skip bool, err error) { + if attach { // attach use case + switch device.AttachCount { + case 0: + // do real attach + return false, nil + case intMax: + return true, fmt.Errorf("device was attached too many times") + default: + device.AttachCount++ + return true, nil + } + } else { // detach use case + switch device.AttachCount { + case 0: + return true, fmt.Errorf("detaching a device that wasn't attached") + case 1: + // do real work + return false, nil + default: + device.AttachCount-- + return true, nil + } + } +} diff --git a/virtcontainers/device/drivers/generic_test.go b/virtcontainers/device/drivers/generic_test.go new file mode 100644 index 0000000000..5b0c28484f --- /dev/null +++ b/virtcontainers/device/drivers/generic_test.go @@ -0,0 +1,44 @@ +// Copyright (c) 2018 Huawei Corporation +// +// SPDX-License-Identifier: Apache-2.0 +// + +package drivers + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBumpAttachCount(t *testing.T) { + type testData struct { + attach bool + attachCount uint + expectedAC uint + expectSkip bool + expectErr bool + } + + data := []testData{ + {true, 0, 0, false, false}, + {true, 1, 2, true, false}, + {true, intMax, intMax, true, true}, + {false, 0, 0, true, true}, + {false, 1, 1, false, false}, + {false, intMax, intMax - 1, true, false}, + } + + dev := &GenericDevice{} + for _, d := range data { + dev.AttachCount = d.attachCount + skip, err := dev.bumpAttachCount(d.attach) + assert.Equal(t, skip, d.expectSkip, "") + assert.Equal(t, dev.GetAttachCount(), d.expectedAC, "") + if d.expectErr { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + } + } +} diff --git a/virtcontainers/device/drivers/utils.go b/virtcontainers/device/drivers/utils.go index c60eef0f07..33c18ae8ea 100644 --- a/virtcontainers/device/drivers/utils.go +++ b/virtcontainers/device/drivers/utils.go @@ -12,6 +12,8 @@ import ( "github.com/kata-containers/runtime/virtcontainers/device/api" ) +const intMax uint = ^uint(0) + func deviceLogger() *logrus.Entry { return api.DeviceLogger() } diff --git a/virtcontainers/device/drivers/vfio.go b/virtcontainers/device/drivers/vfio.go index c5d7037452..b29cccd675 100644 --- a/virtcontainers/device/drivers/vfio.go +++ b/virtcontainers/device/drivers/vfio.go @@ -30,23 +30,28 @@ const ( // VFIODevice is a vfio device meant to be passed to the hypervisor // to be used by the Virtual Machine. type VFIODevice struct { - ID string - DeviceInfo *config.DeviceInfo - vfioDevs []*config.VFIODev + *GenericDevice + vfioDevs []*config.VFIODev } // NewVFIODevice create a new VFIO device func NewVFIODevice(devInfo *config.DeviceInfo) *VFIODevice { return &VFIODevice{ - ID: devInfo.ID, - DeviceInfo: devInfo, + GenericDevice: &GenericDevice{ + ID: devInfo.ID, + DeviceInfo: devInfo, + }, } } // Attach is standard interface of api.Device, it's used to add device to some // DeviceReceiver func (device *VFIODevice) Attach(devReceiver api.DeviceReceiver) error { - if device.DeviceInfo.Hotplugged { + skip, err := device.bumpAttachCount(true) + if err != nil { + return err + } + if skip { return nil } @@ -82,14 +87,18 @@ func (device *VFIODevice) Attach(devReceiver api.DeviceReceiver) error { "device-group": device.DeviceInfo.HostPath, "device-type": "vfio-passthrough", }).Info("Device group attached") - device.DeviceInfo.Hotplugged = true + device.AttachCount = 1 return nil } // Detach is standard interface of api.Device, it's used to remove device from some // DeviceReceiver func (device *VFIODevice) Detach(devReceiver api.DeviceReceiver) error { - if !device.DeviceInfo.Hotplugged { + skip, err := device.bumpAttachCount(false) + if err != nil { + return err + } + if skip { return nil } @@ -103,30 +112,23 @@ func (device *VFIODevice) Detach(devReceiver api.DeviceReceiver) error { "device-group": device.DeviceInfo.HostPath, "device-type": "vfio-passthrough", }).Info("Device group detached") - device.DeviceInfo.Hotplugged = false + device.AttachCount = 0 return nil } -// IsAttached checks if the device is attached -func (device *VFIODevice) IsAttached() bool { - return device.DeviceInfo.Hotplugged -} - // DeviceType is standard interface of api.Device, it returns device type func (device *VFIODevice) DeviceType() config.DeviceType { return config.DeviceVFIO } -// DeviceID returns device ID -func (device *VFIODevice) DeviceID() string { - return device.ID -} - // GetDeviceInfo returns device information used for creating func (device *VFIODevice) GetDeviceInfo() interface{} { return device.vfioDevs } +// It should implement GetAttachCount() and DeviceID() as api.Device implementation +// here it shares function from *GenericDevice so we don't need duplicate codes + // getBDF returns the BDF of pci device // Expected input strng format is []:[][].[] eg. 0000:02:10.0 func getBDF(deviceSysStr string) (string, error) { diff --git a/virtcontainers/device/drivers/vhost_user_blk.go b/virtcontainers/device/drivers/vhost_user_blk.go index 6bb629461f..cfc53d4145 100644 --- a/virtcontainers/device/drivers/vhost_user_blk.go +++ b/virtcontainers/device/drivers/vhost_user_blk.go @@ -16,8 +16,8 @@ import ( // VhostUserBlkDevice is a block vhost-user based device type VhostUserBlkDevice struct { + *GenericDevice config.VhostUserDeviceAttrs - DeviceInfo *config.DeviceInfo } // @@ -27,7 +27,11 @@ type VhostUserBlkDevice struct { // Attach is standard interface of api.Device, it's used to add device to some // DeviceReceiver func (device *VhostUserBlkDevice) Attach(devReceiver api.DeviceReceiver) (err error) { - if device.DeviceInfo.Hotplugged { + skip, err := device.bumpAttachCount(true) + if err != nil { + return err + } + if skip { return nil } @@ -38,12 +42,12 @@ func (device *VhostUserBlkDevice) Attach(devReceiver api.DeviceReceiver) (err er } id := hex.EncodeToString(randBytes) - device.ID = id + device.DevID = id device.Type = device.DeviceType() defer func() { if err == nil { - device.DeviceInfo.Hotplugged = true + device.AttachCount = 1 } }() return devReceiver.AppendDevice(device) @@ -52,24 +56,18 @@ func (device *VhostUserBlkDevice) Attach(devReceiver api.DeviceReceiver) (err er // Detach is standard interface of api.Device, it's used to remove device from some // DeviceReceiver func (device *VhostUserBlkDevice) Detach(devReceiver api.DeviceReceiver) error { - if !device.DeviceInfo.Hotplugged { + skip, err := device.bumpAttachCount(true) + if err != nil { + return err + } + if skip { return nil } - device.DeviceInfo.Hotplugged = false + device.AttachCount = 0 return nil } -// IsAttached checks if the device is attached -func (device *VhostUserBlkDevice) IsAttached() bool { - return device.DeviceInfo.Hotplugged -} - -// DeviceID returns device ID -func (device *VhostUserBlkDevice) DeviceID() string { - return device.ID -} - // DeviceType is standard interface of api.Device, it returns device type func (device *VhostUserBlkDevice) DeviceType() config.DeviceType { return config.VhostUserBlk @@ -80,3 +78,6 @@ func (device *VhostUserBlkDevice) GetDeviceInfo() interface{} { device.Type = device.DeviceType() return &device.VhostUserDeviceAttrs } + +// It should implement GetAttachCount() and DeviceID() as api.Device implementation +// here it shares function from *GenericDevice so we don't need duplicate codes diff --git a/virtcontainers/device/drivers/vhost_user_net.go b/virtcontainers/device/drivers/vhost_user_net.go index 6a731542ce..2ab31dd620 100644 --- a/virtcontainers/device/drivers/vhost_user_net.go +++ b/virtcontainers/device/drivers/vhost_user_net.go @@ -16,8 +16,8 @@ import ( // VhostUserNetDevice is a network vhost-user based device type VhostUserNetDevice struct { + *GenericDevice config.VhostUserDeviceAttrs - DeviceInfo *config.DeviceInfo } // @@ -27,7 +27,11 @@ type VhostUserNetDevice struct { // Attach is standard interface of api.Device, it's used to add device to some // DeviceReceiver func (device *VhostUserNetDevice) Attach(devReceiver api.DeviceReceiver) (err error) { - if device.DeviceInfo.Hotplugged { + skip, err := device.bumpAttachCount(true) + if err != nil { + return err + } + if skip { return nil } @@ -38,12 +42,12 @@ func (device *VhostUserNetDevice) Attach(devReceiver api.DeviceReceiver) (err er } id := hex.EncodeToString(randBytes) - device.ID = id + device.DevID = id device.Type = device.DeviceType() defer func() { if err == nil { - device.DeviceInfo.Hotplugged = true + device.AttachCount = 1 } }() return devReceiver.AppendDevice(device) @@ -52,24 +56,18 @@ func (device *VhostUserNetDevice) Attach(devReceiver api.DeviceReceiver) (err er // Detach is standard interface of api.Device, it's used to remove device from some // DeviceReceiver func (device *VhostUserNetDevice) Detach(devReceiver api.DeviceReceiver) error { - if !device.DeviceInfo.Hotplugged { + skip, err := device.bumpAttachCount(false) + if err != nil { + return err + } + if skip { return nil } - device.DeviceInfo.Hotplugged = false + device.AttachCount = 0 return nil } -// IsAttached checks if the device is attached -func (device *VhostUserNetDevice) IsAttached() bool { - return device.DeviceInfo.Hotplugged -} - -// DeviceID returns device ID -func (device *VhostUserNetDevice) DeviceID() string { - return device.ID -} - // DeviceType is standard interface of api.Device, it returns device type func (device *VhostUserNetDevice) DeviceType() config.DeviceType { return config.VhostUserNet @@ -80,3 +78,6 @@ func (device *VhostUserNetDevice) GetDeviceInfo() interface{} { device.Type = device.DeviceType() return &device.VhostUserDeviceAttrs } + +// It should implement GetAttachCount() and DeviceID() as api.Device implementation +// here it shares function from *GenericDevice so we don't need duplicate codes diff --git a/virtcontainers/device/drivers/vhost_user_scsi.go b/virtcontainers/device/drivers/vhost_user_scsi.go index ce5f846cd5..d34c50ec04 100644 --- a/virtcontainers/device/drivers/vhost_user_scsi.go +++ b/virtcontainers/device/drivers/vhost_user_scsi.go @@ -16,8 +16,8 @@ import ( // VhostUserSCSIDevice is a SCSI vhost-user based device type VhostUserSCSIDevice struct { + *GenericDevice config.VhostUserDeviceAttrs - DeviceInfo *config.DeviceInfo } // @@ -27,7 +27,11 @@ type VhostUserSCSIDevice struct { // Attach is standard interface of api.Device, it's used to add device to some // DeviceReceiver func (device *VhostUserSCSIDevice) Attach(devReceiver api.DeviceReceiver) (err error) { - if device.DeviceInfo.Hotplugged { + skip, err := device.bumpAttachCount(true) + if err != nil { + return err + } + if skip { return nil } @@ -38,12 +42,12 @@ func (device *VhostUserSCSIDevice) Attach(devReceiver api.DeviceReceiver) (err e } id := hex.EncodeToString(randBytes) - device.ID = id + device.DevID = id device.Type = device.DeviceType() defer func() { if err == nil { - device.DeviceInfo.Hotplugged = true + device.AttachCount = 1 } }() return devReceiver.AppendDevice(device) @@ -52,24 +56,17 @@ func (device *VhostUserSCSIDevice) Attach(devReceiver api.DeviceReceiver) (err e // Detach is standard interface of api.Device, it's used to remove device from some // DeviceReceiver func (device *VhostUserSCSIDevice) Detach(devReceiver api.DeviceReceiver) error { - if !device.DeviceInfo.Hotplugged { + skip, err := device.bumpAttachCount(false) + if err != nil { + return err + } + if skip { return nil } - - device.DeviceInfo.Hotplugged = false + device.AttachCount = 0 return nil } -// IsAttached checks if the device is attached -func (device *VhostUserSCSIDevice) IsAttached() bool { - return device.DeviceInfo.Hotplugged -} - -// DeviceID returns device ID -func (device *VhostUserSCSIDevice) DeviceID() string { - return device.ID -} - // DeviceType is standard interface of api.Device, it returns device type func (device *VhostUserSCSIDevice) DeviceType() config.DeviceType { return config.VhostUserSCSI @@ -80,3 +77,6 @@ func (device *VhostUserSCSIDevice) GetDeviceInfo() interface{} { device.Type = device.DeviceType() return &device.VhostUserDeviceAttrs } + +// It should implement GetAttachCount() and DeviceID() as api.Device implementation +// here it shares function from *GenericDevice so we don't need duplicate codes diff --git a/virtcontainers/device/manager/manager.go b/virtcontainers/device/manager/manager.go index 62ce0742a6..5841b938da 100644 --- a/virtcontainers/device/manager/manager.go +++ b/virtcontainers/device/manager/manager.go @@ -32,10 +32,11 @@ var ( ErrIDExhausted = errors.New("IDs are exhausted") // ErrDeviceNotExist represents device hasn't been created before ErrDeviceNotExist = errors.New("device with specified ID hasn't been created") - // ErrDeviceAttached represents the device is already attached - ErrDeviceAttached = errors.New("device is already attached") // ErrDeviceNotAttached represents the device isn't attached ErrDeviceNotAttached = errors.New("device isn't attached") + // ErrRemoveAttachedDevice represents the device isn't detached + // so not allow to remove from list + ErrRemoveAttachedDevice = errors.New("can't remove attached device") ) type deviceManager struct { @@ -66,14 +67,34 @@ func NewDeviceManager(blockDriver string, devices []api.Device) api.DeviceManage return dm } +func (dm *deviceManager) findDeviceByMajorMinor(major, minor int64) api.Device { + for _, dev := range dm.devices { + dma, dmi := dev.GetMajorMinor() + if dma == major && dmi == minor { + return dev + } + } + return nil +} + // createDevice creates one device based on DeviceInfo -func (dm *deviceManager) createDevice(devInfo config.DeviceInfo) (api.Device, error) { +func (dm *deviceManager) createDevice(devInfo config.DeviceInfo) (dev api.Device, err error) { path, err := config.GetHostPathFunc(devInfo) if err != nil { return nil, err } devInfo.HostPath = path + defer func() { + if err == nil { + dev.Reference() + } + }() + + if existingDev := dm.findDeviceByMajorMinor(devInfo.Major, devInfo.Minor); existingDev != nil { + return existingDev, nil + } + // device ID must be generated by manager instead of device itself // in case of ID collision if devInfo.ID, err = dm.newDeviceID(); err != nil { @@ -108,10 +129,17 @@ func (dm *deviceManager) NewDevice(devInfo config.DeviceInfo) (api.Device, error func (dm *deviceManager) RemoveDevice(id string) error { dm.Lock() defer dm.Unlock() - if _, ok := dm.devices[id]; !ok { + dev, ok := dm.devices[id] + if !ok { return ErrDeviceNotExist } - delete(dm.devices, id) + + if dev.Dereference() == 0 { + if dev.GetAttachCount() > 0 { + return ErrRemoveAttachedDevice + } + delete(dm.devices, id) + } return nil } @@ -141,10 +169,6 @@ func (dm *deviceManager) AttachDevice(id string, dr api.DeviceReceiver) error { return ErrDeviceNotExist } - if d.IsAttached() { - return ErrDeviceAttached - } - if err := d.Attach(dr); err != nil { return err } @@ -159,7 +183,7 @@ func (dm *deviceManager) DetachDevice(id string, dr api.DeviceReceiver) error { if !ok { return ErrDeviceNotExist } - if !d.IsAttached() { + if d.GetAttachCount() <= 0 { return ErrDeviceNotAttached } @@ -168,6 +192,7 @@ func (dm *deviceManager) DetachDevice(id string, dr api.DeviceReceiver) error { } return nil } + func (dm *deviceManager) GetDeviceByID(id string) api.Device { dm.RLock() defer dm.RUnlock() @@ -194,5 +219,5 @@ func (dm *deviceManager) IsDeviceAttached(id string) bool { if !ok { return false } - return d.IsAttached() + return d.GetAttachCount() > 0 } diff --git a/virtcontainers/device/manager/manager_test.go b/virtcontainers/device/manager/manager_test.go index 2818820109..b99dd992b4 100644 --- a/virtcontainers/device/manager/manager_test.go +++ b/virtcontainers/device/manager/manager_test.go @@ -216,13 +216,18 @@ func TestAttachDetachDevice(t *testing.T) { device, err := dm.NewDevice(deviceInfo) assert.Nil(t, err) + // attach non-exist device + err = dm.AttachDevice("non-exist", devReceiver) + assert.NotNil(t, err) + // attach device err = dm.AttachDevice(device.DeviceID(), devReceiver) assert.Nil(t, err) + assert.Equal(t, device.GetAttachCount(), uint(1), "attach device count should be 1") // attach device again(twice) err = dm.AttachDevice(device.DeviceID(), devReceiver) - assert.NotNil(t, err) - assert.Equal(t, err, ErrDeviceAttached, "attach device twice should report error %q", ErrDeviceAttached) + assert.Nil(t, err) + assert.Equal(t, device.GetAttachCount(), uint(2), "attach device count should be 2") attached := dm.IsDeviceAttached(device.DeviceID()) assert.True(t, attached) @@ -230,12 +235,20 @@ func TestAttachDetachDevice(t *testing.T) { // detach device err = dm.DetachDevice(device.DeviceID(), devReceiver) assert.Nil(t, err) + assert.Equal(t, device.GetAttachCount(), uint(1), "attach device count should be 1") // detach device again(twice) err = dm.DetachDevice(device.DeviceID(), devReceiver) + assert.Nil(t, err) + assert.Equal(t, device.GetAttachCount(), uint(0), "attach device count should be 0") + // detach device again should report error + err = dm.DetachDevice(device.DeviceID(), devReceiver) assert.NotNil(t, err) - assert.Equal(t, err, ErrDeviceNotAttached, "attach device twice should report error %q", ErrDeviceNotAttached) + assert.Equal(t, err, ErrDeviceNotAttached, "") + assert.Equal(t, device.GetAttachCount(), uint(0), "attach device count should be 0") attached = dm.IsDeviceAttached(device.DeviceID()) assert.False(t, attached) + err = dm.RemoveDevice(device.DeviceID()) + assert.Nil(t, err) } diff --git a/virtcontainers/kata_agent.go b/virtcontainers/kata_agent.go index 717acb3449..fc9d10c89b 100644 --- a/virtcontainers/kata_agent.go +++ b/virtcontainers/kata_agent.go @@ -767,7 +767,7 @@ func (k *kataAgent) rollbackFailingContainerCreation(c *Container) { } func (k *kataAgent) buildContainerRootfs(sandbox *Sandbox, c *Container, rootPathParent string) (*grpc.Storage, error) { - if c.state.Fstype != "" { + if c.state.Fstype != "" && c.state.BlockDeviceID != "" { // The rootfs storage volume represents the container rootfs // mount point inside the guest. // It can be a block based device (when using block based container @@ -776,23 +776,25 @@ func (k *kataAgent) buildContainerRootfs(sandbox *Sandbox, c *Container, rootPat rootfs := &grpc.Storage{} // This is a block based device rootfs. - - // Pass a drive name only in case of virtio-blk driver. - // If virtio-scsi driver, the agent will be able to find the - // device based on the provided address. - if sandbox.config.HypervisorConfig.BlockDeviceDriver == VirtioBlock { - rootfs.Driver = kataBlkDevType - rootfs.Source = c.state.RootfsPCIAddr - } else { - scsiAddr, err := utils.GetSCSIAddress(c.state.BlockIndex) - if err != nil { - return nil, err - } - - rootfs.Driver = kataSCSIDevType - rootfs.Source = scsiAddr + device := sandbox.devManager.GetDeviceByID(c.state.BlockDeviceID) + if device == nil { + k.Logger().WithField("device", c.state.BlockDeviceID).Error("failed to find device by id") + return nil, fmt.Errorf("failed to find device by id %q", c.state.BlockDeviceID) } + blockDrive, ok := device.GetDeviceInfo().(*config.BlockDrive) + if !ok || blockDrive == nil { + k.Logger().Error("malformed block drive") + return nil, fmt.Errorf("malformed block drive") + } + + if sandbox.config.HypervisorConfig.BlockDeviceDriver == VirtioBlock { + rootfs.Driver = kataBlkDevType + rootfs.Source = blockDrive.VirtPath + } else { + rootfs.Driver = kataSCSIDevType + rootfs.Source = blockDrive.SCSIAddr + } rootfs.MountPoint = rootPathParent rootfs.Fstype = c.state.Fstype @@ -802,6 +804,7 @@ func (k *kataAgent) buildContainerRootfs(sandbox *Sandbox, c *Container, rootPat return rootfs, nil } + // This is not a block based device rootfs. // We are going to bind mount it into the 9pfs // shared drive between the host and the guest. @@ -970,7 +973,7 @@ func (k *kataAgent) handleBlockVolumes(c *Container) []*grpc.Storage { // Add the block device to the list of container devices, to make sure the // device is detached with detachDevices() for a container. - c.devices = append(c.devices, ContainerDevice{ID: id}) + c.devices = append(c.devices, ContainerDevice{ID: id, ContainerPath: m.Destination}) if err := c.storeDevices(); err != nil { k.Logger().WithField("device", id).WithError(err).Error("store device failed") return nil diff --git a/virtcontainers/kata_agent_test.go b/virtcontainers/kata_agent_test.go index e4892f7fc6..305bfabf81 100644 --- a/virtcontainers/kata_agent_test.go +++ b/virtcontainers/kata_agent_test.go @@ -382,7 +382,9 @@ func TestAppendDevices(t *testing.T) { id := "test-append-block" ctrDevices := []api.Device{ &drivers.BlockDevice{ - ID: id, + GenericDevice: &drivers.GenericDevice{ + ID: id, + }, BlockDrive: &config.BlockDrive{ PCIAddr: testPCIAddr, }, diff --git a/virtcontainers/network.go b/virtcontainers/network.go index 81ed06e798..c623c6cbe6 100644 --- a/virtcontainers/network.go +++ b/virtcontainers/network.go @@ -322,7 +322,7 @@ func (endpoint *VhostUserEndpoint) Attach(h hypervisor) error { id := hex.EncodeToString(randBytes) d := config.VhostUserDeviceAttrs{ - ID: id, + DevID: id, SocketPath: endpoint.SocketPath, MacAddress: endpoint.HardAddr, Type: config.VhostUserNet, diff --git a/virtcontainers/qemu_arch_base.go b/virtcontainers/qemu_arch_base.go index 3bc24e82fd..31e5c54928 100644 --- a/virtcontainers/qemu_arch_base.go +++ b/virtcontainers/qemu_arch_base.go @@ -475,16 +475,16 @@ func (q *qemuArchBase) appendVhostUserDevice(devices []govmmQemu.Device, attr co switch attr.Type { case config.VhostUserNet: - qemuVhostUserDevice.TypeDevID = utils.MakeNameID("net", attr.ID, maxDevIDSize) + qemuVhostUserDevice.TypeDevID = utils.MakeNameID("net", attr.DevID, maxDevIDSize) qemuVhostUserDevice.Address = attr.MacAddress case config.VhostUserSCSI: - qemuVhostUserDevice.TypeDevID = utils.MakeNameID("scsi", attr.ID, maxDevIDSize) + qemuVhostUserDevice.TypeDevID = utils.MakeNameID("scsi", attr.DevID, maxDevIDSize) case config.VhostUserBlk: } qemuVhostUserDevice.VhostUserType = govmmQemu.VhostUserDeviceType(attr.Type) qemuVhostUserDevice.SocketPath = attr.SocketPath - qemuVhostUserDevice.CharDevID = utils.MakeNameID("char", attr.ID, maxDevIDSize) + qemuVhostUserDevice.CharDevID = utils.MakeNameID("char", attr.DevID, maxDevIDSize) devices = append(devices, qemuVhostUserDevice) diff --git a/virtcontainers/qemu_arch_base_test.go b/virtcontainers/qemu_arch_base_test.go index 3465f8cdad..601341fbb4 100644 --- a/virtcontainers/qemu_arch_base_test.go +++ b/virtcontainers/qemu_arch_base_test.go @@ -391,7 +391,7 @@ func TestQemuArchBaseAppendVhostUserDevice(t *testing.T) { Type: config.VhostUserNet, MacAddress: macAddress, } - vhostUserDevice.ID = id + vhostUserDevice.DevID = id vhostUserDevice.SocketPath = socketPath testQemuArchBaseAppend(t, vhostUserDevice, expectedOut) diff --git a/virtcontainers/sandbox.go b/virtcontainers/sandbox.go index 4542e80984..274d832b06 100644 --- a/virtcontainers/sandbox.go +++ b/virtcontainers/sandbox.go @@ -54,18 +54,13 @@ const ( type State struct { State stateString `json:"state"` + BlockDeviceID string // Index of the block device passed to hypervisor. BlockIndex int `json:"blockIndex"` // File system of the rootfs incase it is block device Fstype string `json:"fstype"` - // Bool to indicate if the drive for a container was hotplugged. - HotpluggedDrive bool `json:"hotpluggedDrive"` - - // PCI slot at which the block device backing the container rootfs is attached. - RootfsPCIAddr string `json:"rootfsPCIAddr"` - // Pid is the process id of the sandbox container which is the first // container to be started. Pid int `json:"pid"` diff --git a/virtcontainers/vm.go b/virtcontainers/vm.go index 7ffa47c611..d9b6947f55 100644 --- a/virtcontainers/vm.go +++ b/virtcontainers/vm.go @@ -200,12 +200,12 @@ func (v *VM) ReseedRNG() error { data := make([]byte, 512) f, err := os.OpenFile(urandomDev, os.O_RDONLY, 0) if err != nil { - v.logger().WithError(err).Warn("fail to open %s", urandomDev) + v.logger().WithError(err).Warnf("fail to open %s", urandomDev) return err } defer f.Close() if _, err = f.Read(data); err != nil { - v.logger().WithError(err).Warn("fail to read %s", urandomDev) + v.logger().WithError(err).Warnf("fail to read %s", urandomDev) return err }