diff --git a/virtcontainers/device/drivers/block.go b/virtcontainers/device/drivers/block.go index b977e30e8a..fdb18f4187 100644 --- a/virtcontainers/device/drivers/block.go +++ b/virtcontainers/device/drivers/block.go @@ -51,8 +51,7 @@ func (device *BlockDevice) Attach(devReceiver api.DeviceReceiver) (err error) { defer func() { if err != nil { devReceiver.DecrementSandboxBlockIndex() - } else { - device.AttachCount = 1 + device.bumpAttachCount(false) } }() @@ -122,13 +121,18 @@ func (device *BlockDevice) Detach(devReceiver api.DeviceReceiver) error { return nil } + defer func() { + if err != nil { + device.bumpAttachCount(true) + } + }() + deviceLogger().WithField("device", device.DeviceInfo.HostPath).Info("Unplugging block device") - if err := devReceiver.HotplugRemoveDevice(device, config.DeviceBlock); err != nil { + if err = devReceiver.HotplugRemoveDevice(device, config.DeviceBlock); err != nil { deviceLogger().WithError(err).Error("Failed to unplug block device") return err } - device.AttachCount = 0 return nil } diff --git a/virtcontainers/device/drivers/generic.go b/virtcontainers/device/drivers/generic.go index 10835e3f16..ff60fc3fb4 100644 --- a/virtcontainers/device/drivers/generic.go +++ b/virtcontainers/device/drivers/generic.go @@ -32,28 +32,14 @@ func NewGenericDevice(devInfo *config.DeviceInfo) *GenericDevice { // Attach is standard interface of api.Device func (device *GenericDevice) Attach(devReceiver api.DeviceReceiver) error { - skip, err := device.bumpAttachCount(true) - if err != nil { - return err - } - if skip { - return nil - } - device.AttachCount = 1 - return nil + _, err := device.bumpAttachCount(true) + return err } // Detach is standard interface of api.Device func (device *GenericDevice) Detach(devReceiver api.DeviceReceiver) error { - skip, err := device.bumpAttachCount(false) - if err != nil { - return err - } - if skip { - return nil - } - device.AttachCount = 0 - return nil + _, err := device.bumpAttachCount(false) + return err } // DeviceType is standard interface of api.Device, it returns device type @@ -107,6 +93,7 @@ func (device *GenericDevice) bumpAttachCount(attach bool) (skip bool, err error) switch device.AttachCount { case 0: // do real attach + device.AttachCount++ return false, nil case intMax: return true, fmt.Errorf("device was attached too many times") @@ -120,6 +107,7 @@ func (device *GenericDevice) bumpAttachCount(attach bool) (skip bool, err error) return true, fmt.Errorf("detaching a device that wasn't attached") case 1: // do real work + device.AttachCount-- return false, nil default: device.AttachCount-- diff --git a/virtcontainers/device/drivers/generic_test.go b/virtcontainers/device/drivers/generic_test.go index 5b0c28484f..a6c46cc938 100644 --- a/virtcontainers/device/drivers/generic_test.go +++ b/virtcontainers/device/drivers/generic_test.go @@ -21,11 +21,11 @@ func TestBumpAttachCount(t *testing.T) { } data := []testData{ - {true, 0, 0, false, false}, + {true, 0, 1, false, false}, {true, 1, 2, true, false}, {true, intMax, intMax, true, true}, {false, 0, 0, true, true}, - {false, 1, 1, false, false}, + {false, 1, 0, false, false}, {false, intMax, intMax - 1, true, false}, } diff --git a/virtcontainers/device/drivers/vfio.go b/virtcontainers/device/drivers/vfio.go index a0a80dec72..f717692258 100644 --- a/virtcontainers/device/drivers/vfio.go +++ b/virtcontainers/device/drivers/vfio.go @@ -47,7 +47,7 @@ func NewVFIODevice(devInfo *config.DeviceInfo) *VFIODevice { // Attach is standard interface of api.Device, it's used to add device to some // DeviceReceiver -func (device *VFIODevice) Attach(devReceiver api.DeviceReceiver) error { +func (device *VFIODevice) Attach(devReceiver api.DeviceReceiver) (retErr error) { skip, err := device.bumpAttachCount(true) if err != nil { return err @@ -56,6 +56,12 @@ func (device *VFIODevice) Attach(devReceiver api.DeviceReceiver) error { return nil } + defer func() { + if retErr != nil { + device.bumpAttachCount(false) + } + }() + vfioGroup := filepath.Base(device.DeviceInfo.HostPath) iommuDevicesPath := filepath.Join(config.SysIOMMUPath, vfioGroup, "devices") @@ -90,13 +96,12 @@ func (device *VFIODevice) Attach(devReceiver api.DeviceReceiver) error { "device-group": device.DeviceInfo.HostPath, "device-type": "vfio-passthrough", }).Info("Device group attached") - device.AttachCount = 1 return nil } // Detach is standard interface of api.Device, it's used to remove device from some // DeviceReceiver -func (device *VFIODevice) Detach(devReceiver api.DeviceReceiver) error { +func (device *VFIODevice) Detach(devReceiver api.DeviceReceiver) (retErr error) { skip, err := device.bumpAttachCount(false) if err != nil { return err @@ -105,6 +110,12 @@ func (device *VFIODevice) Detach(devReceiver api.DeviceReceiver) error { return nil } + defer func() { + if retErr != nil { + device.bumpAttachCount(true) + } + }() + // hotplug a VFIO device is actually hotplugging a group of iommu devices if err := devReceiver.HotplugRemoveDevice(device, config.DeviceVFIO); err != nil { deviceLogger().WithError(err).Error("Failed to remove device") @@ -115,7 +126,6 @@ func (device *VFIODevice) Detach(devReceiver api.DeviceReceiver) error { "device-group": device.DeviceInfo.HostPath, "device-type": "vfio-passthrough", }).Info("Device group detached") - device.AttachCount = 0 return nil } diff --git a/virtcontainers/device/drivers/vhost_user_blk.go b/virtcontainers/device/drivers/vhost_user_blk.go index a613abda82..78b2885dc1 100644 --- a/virtcontainers/device/drivers/vhost_user_blk.go +++ b/virtcontainers/device/drivers/vhost_user_blk.go @@ -34,6 +34,11 @@ func (device *VhostUserBlkDevice) Attach(devReceiver api.DeviceReceiver) (err er if skip { return nil } + defer func() { + if err != nil { + device.bumpAttachCount(false) + } + }() // generate a unique ID to be used for hypervisor commandline fields randBytes, err := utils.GenerateRandomBytes(8) @@ -45,27 +50,14 @@ func (device *VhostUserBlkDevice) Attach(devReceiver api.DeviceReceiver) (err er device.DevID = id device.Type = device.DeviceType() - defer func() { - if err == nil { - device.AttachCount = 1 - } - }() return devReceiver.AppendDevice(device) } // Detach is standard interface of api.Device, it's used to remove device from some // DeviceReceiver func (device *VhostUserBlkDevice) Detach(devReceiver api.DeviceReceiver) error { - skip, err := device.bumpAttachCount(false) - if err != nil { - return err - } - if skip { - return nil - } - - device.AttachCount = 0 - return nil + _, err := device.bumpAttachCount(false) + return err } // DeviceType is standard interface of api.Device, it returns device type diff --git a/virtcontainers/device/drivers/vhost_user_net.go b/virtcontainers/device/drivers/vhost_user_net.go index 2ab31dd620..a276bb6fe5 100644 --- a/virtcontainers/device/drivers/vhost_user_net.go +++ b/virtcontainers/device/drivers/vhost_user_net.go @@ -35,6 +35,12 @@ func (device *VhostUserNetDevice) Attach(devReceiver api.DeviceReceiver) (err er return nil } + defer func() { + if err != nil { + device.bumpAttachCount(false) + } + }() + // generate a unique ID to be used for hypervisor commandline fields randBytes, err := utils.GenerateRandomBytes(8) if err != nil { @@ -45,27 +51,14 @@ func (device *VhostUserNetDevice) Attach(devReceiver api.DeviceReceiver) (err er device.DevID = id device.Type = device.DeviceType() - defer func() { - if err == nil { - device.AttachCount = 1 - } - }() return devReceiver.AppendDevice(device) } // Detach is standard interface of api.Device, it's used to remove device from some // DeviceReceiver func (device *VhostUserNetDevice) Detach(devReceiver api.DeviceReceiver) error { - skip, err := device.bumpAttachCount(false) - if err != nil { - return err - } - if skip { - return nil - } - - device.AttachCount = 0 - return nil + _, err := device.bumpAttachCount(false) + return err } // DeviceType is standard interface of api.Device, it returns device type diff --git a/virtcontainers/device/drivers/vhost_user_scsi.go b/virtcontainers/device/drivers/vhost_user_scsi.go index d34c50ec04..e4e2a27afa 100644 --- a/virtcontainers/device/drivers/vhost_user_scsi.go +++ b/virtcontainers/device/drivers/vhost_user_scsi.go @@ -35,6 +35,12 @@ func (device *VhostUserSCSIDevice) Attach(devReceiver api.DeviceReceiver) (err e return nil } + defer func() { + if err != nil { + device.bumpAttachCount(false) + } + }() + // generate a unique ID to be used for hypervisor commandline fields randBytes, err := utils.GenerateRandomBytes(8) if err != nil { @@ -45,26 +51,14 @@ func (device *VhostUserSCSIDevice) Attach(devReceiver api.DeviceReceiver) (err e device.DevID = id device.Type = device.DeviceType() - defer func() { - if err == nil { - device.AttachCount = 1 - } - }() return devReceiver.AppendDevice(device) } // Detach is standard interface of api.Device, it's used to remove device from some // DeviceReceiver func (device *VhostUserSCSIDevice) Detach(devReceiver api.DeviceReceiver) error { - skip, err := device.bumpAttachCount(false) - if err != nil { - return err - } - if skip { - return nil - } - device.AttachCount = 0 - return nil + _, err := device.bumpAttachCount(false) + return err } // DeviceType is standard interface of api.Device, it returns device type