diff --git a/virtcontainers/sandbox.go b/virtcontainers/sandbox.go index e52be3ead0..9517c6de30 100644 --- a/virtcontainers/sandbox.go +++ b/virtcontainers/sandbox.go @@ -1446,12 +1446,43 @@ func togglePauseSandbox(sandboxID string, pause bool) (*Sandbox, error) { func (s *Sandbox) HotplugAddDevice(device api.Device, devType config.DeviceType) error { switch devType { case config.DeviceVFIO: - vfioDevice, ok := device.(*drivers.VFIODevice) + vfioDevices, ok := device.GetDeviceDrive().([]*config.VFIODrive) if !ok { return fmt.Errorf("device type mismatch, expect device type to be %s", devType) } - _, err := s.hypervisor.hotplugAddDevice(*vfioDevice, vfioDev) - return err + addedDev := []*config.VFIODrive{} + var err error + defer func() { + // if err happens,roll back and remove added device! + if err != nil { + for _, dev := range addedDev { + if _, rollbackErr := s.hypervisor.hotplugRemoveDevice(dev, vfioDev); rollbackErr != nil { + s.Logger(). + WithFields(logrus.Fields{ + "sandboxid": s.id, + "vfio device ID": dev.ID, + "vfio device BDF": dev.BDF, + }).WithError(rollbackErr). + Error("failed to remove vfio device for rolling back") + } + } + } + }() + + // adding a group of VFIO devices + for _, dev := range vfioDevices { + if _, err = s.hypervisor.hotplugAddDevice(dev, vfioDev); err != nil { + s.Logger(). + WithFields(logrus.Fields{ + "sandboxid": s.id, + "vfio device ID": dev.ID, + "vfio device BDF": dev.BDF, + }).WithError(err).Error("failed to hotplug VFIO device") + return err + } + addedDev = append(addedDev, dev) + } + return nil case config.DeviceBlock: blockDevice, ok := device.(*drivers.BlockDevice) if !ok { @@ -1471,18 +1502,49 @@ func (s *Sandbox) HotplugAddDevice(device api.Device, devType config.DeviceType) func (s *Sandbox) HotplugRemoveDevice(device api.Device, devType config.DeviceType) error { switch devType { case config.DeviceVFIO: - vfioDevice, ok := device.(*drivers.VFIODevice) + vfioDevices, ok := device.GetDeviceDrive().([]*config.VFIODrive) if !ok { return fmt.Errorf("device type mismatch, expect device type to be %s", devType) } - _, err := s.hypervisor.hotplugRemoveDevice(*vfioDevice, vfioDev) - return err + removedDev := []*config.VFIODrive{} + var err error + defer func() { + // if err happens,roll back and add the removed devices back! + if err != nil { + for _, dev := range removedDev { + if _, rollbackErr := s.hypervisor.hotplugAddDevice(dev, vfioDev); rollbackErr != nil { + + s.Logger().WithError(rollbackErr). + WithFields(logrus.Fields{ + "sandboxid": s.id, + "vfio device ID": dev.ID, + "vfio device BDF": dev.BDF, + }).Error("failed to add vfio device for rolling back") + } + } + } + }() + + // remove a group of VFIO devices + for _, dev := range vfioDevices { + if _, err = s.hypervisor.hotplugRemoveDevice(dev, vfioDev); err != nil { + s.Logger().WithError(err). + WithFields(logrus.Fields{ + "sandboxid": s.id, + "vfio device ID": dev.ID, + "vfio device BDF": dev.BDF, + }).Error("failed to hot unplug VFIO device") + return err + } + removedDev = append(removedDev, dev) + } + return nil case config.DeviceBlock: - blockDevice, ok := device.(*drivers.BlockDevice) + blockDrive, ok := device.GetDeviceDrive().(*config.BlockDrive) if !ok { return fmt.Errorf("device type mismatch, expect device type to be %s", devType) } - _, err := s.hypervisor.hotplugRemoveDevice(blockDevice.BlockDrive, blockDev) + _, err := s.hypervisor.hotplugRemoveDevice(blockDrive, blockDev) return err case config.DeviceGeneric: // TODO: what?