diff --git a/pkg/kubelet/cm/deviceplugin/endpoint.go b/pkg/kubelet/cm/deviceplugin/endpoint.go
index 8e0b619d7cb..44898dc023b 100644
--- a/pkg/kubelet/cm/deviceplugin/endpoint.go
+++ b/pkg/kubelet/cm/deviceplugin/endpoint.go
@@ -46,7 +46,7 @@ type endpoint struct {
 }
 
 // newEndpoint creates a new endpoint for the given resourceName.
-func newEndpoint(socketPath, resourceName string, callback MonitorCallback) (*endpoint, error) {
+func newEndpoint(socketPath, resourceName string, devices map[string]pluginapi.Device, callback MonitorCallback) (*endpoint, error) {
 	client, c, err := dial(socketPath)
 	if err != nil {
 		glog.Errorf("Can't create new endpoint with path %s err %v", socketPath, err)
@@ -60,7 +60,7 @@ func newEndpoint(socketPath, resourceName string, callback MonitorCallback) (*en
 		socketPath:   socketPath,
 		resourceName: resourceName,
 
-		devices:  nil,
+		devices:  devices,
 		callback: callback,
 	}, nil
 }
@@ -77,45 +77,22 @@ func (e *endpoint) getDevices() []pluginapi.Device {
 	return devs
 }
 
-// list initializes ListAndWatch gRPC call for the device plugin and gets the
-// initial list of the devices. Returns ListAndWatch gRPC stream on success.
-func (e *endpoint) list() (pluginapi.DevicePlugin_ListAndWatchClient, error) {
-	stream, err := e.client.ListAndWatch(context.Background(), &pluginapi.Empty{})
-	if err != nil {
-		glog.Errorf(errListAndWatch, e.resourceName, err)
-		return nil, err
-	}
-
-	devs, err := stream.Recv()
-	if err != nil {
-		glog.Errorf(errListAndWatch, e.resourceName, err)
-		return nil, err
-	}
-
-	devices := make(map[string]pluginapi.Device)
-	var added, updated, deleted []pluginapi.Device
-	for _, d := range devs.Devices {
-		devices[d.ID] = *d
-		added = append(added, *d)
-	}
-
-	e.mutex.Lock()
-	e.devices = devices
-	e.mutex.Unlock()
-
-	e.callback(e.resourceName, added, updated, deleted)
-
-	return stream, nil
-}
-
-// listAndWatch blocks on receiving ListAndWatch gRPC stream updates. Each ListAndWatch
+// run initializes ListAndWatch gRPC call for the device plugin and
+// blocks on receiving ListAndWatch gRPC stream updates. Each ListAndWatch
 // stream update contains a new list of device states. listAndWatch compares the new
 // device states with its cached states to get list of new, updated, and deleted devices.
 // It then issues a callback to pass this information to the device_plugin_handler which
 // will adjust the resource available information accordingly.
-func (e *endpoint) listAndWatch(stream pluginapi.DevicePlugin_ListAndWatchClient) {
+func (e *endpoint) run() {
 	glog.V(3).Infof("Starting ListAndWatch")
 
+	stream, err := e.client.ListAndWatch(context.Background(), &pluginapi.Empty{})
+	if err != nil {
+		glog.Errorf(errListAndWatch, e.resourceName, err)
+
+		return
+	}
+
 	devices := make(map[string]pluginapi.Device)
 
 	e.mutex.Lock()
diff --git a/pkg/kubelet/cm/deviceplugin/endpoint_test.go b/pkg/kubelet/cm/deviceplugin/endpoint_test.go
index a7e458258e5..cb27c89f319 100644
--- a/pkg/kubelet/cm/deviceplugin/endpoint_test.go
+++ b/pkg/kubelet/cm/deviceplugin/endpoint_test.go
@@ -41,32 +41,7 @@ func TestNewEndpoint(t *testing.T) {
 	defer ecleanup(t, p, e)
 }
 
-func TestList(t *testing.T) {
-	socket := path.Join("/tmp", esocketName)
-
-	devs := []*pluginapi.Device{
-		{ID: "ADeviceId", Health: pluginapi.Healthy},
-	}
-
-	p, e := esetup(t, devs, socket, "mock", func(n string, a, u, r []pluginapi.Device) {})
-	defer ecleanup(t, p, e)
-
-	_, err := e.list()
-	require.NoError(t, err)
-
-	e.mutex.Lock()
-	defer e.mutex.Unlock()
-
-	require.Len(t, e.devices, 1)
-
-	d, ok := e.devices[devs[0].ID]
-	require.True(t, ok)
-
-	require.Equal(t, d.ID, devs[0].ID)
-	require.Equal(t, d.Health, devs[0].Health)
-}
-
-func TestListAndWatch(t *testing.T) {
+func TestRun(t *testing.T) {
 	socket := path.Join("/tmp", esocketName)
 
 	devs := []*pluginapi.Device{
@@ -93,10 +68,7 @@ func TestListAndWatch(t *testing.T) {
 	})
 	defer ecleanup(t, p, e)
 
-	s, err := e.list()
-	require.NoError(t, err)
-
-	go e.listAndWatch(s)
+	go e.run()
 
 	p.Update(updated)
 	time.Sleep(time.Second)
@@ -130,7 +102,7 @@ func esetup(t *testing.T, devs []*pluginapi.Device, socket, resourceName string,
 	err := p.Start()
 	require.NoError(t, err)
 
-	e, err := newEndpoint(socket, "mock", func(n string, a, u, r []pluginapi.Device) {})
+	e, err := newEndpoint(socket, "mock", make(map[string]pluginapi.Device), func(n string, a, u, r []pluginapi.Device) {})
 	require.NoError(t, err)
 
 	return p, e
diff --git a/pkg/kubelet/cm/deviceplugin/manager.go b/pkg/kubelet/cm/deviceplugin/manager.go
index e255054af31..2b2c0a333ff 100644
--- a/pkg/kubelet/cm/deviceplugin/manager.go
+++ b/pkg/kubelet/cm/deviceplugin/manager.go
@@ -97,13 +97,13 @@ func (m *ManagerImpl) removeContents(dir string) error {
 }
 
 const (
-	// defaultCheckpoint is the file name of device plugin checkpoint
-	defaultCheckpoint = "kubelet_internal_checkpoint"
+	// kubeletDevicePluginCheckpoint is the file name of device plugin checkpoint
+	kubeletDevicePluginCheckpoint = "kubelet_internal_checkpoint"
 )
 
 // CheckpointFile returns device plugin checkpoint file path.
 func (m *ManagerImpl) CheckpointFile() string {
-	return filepath.Join(m.socketdir, defaultCheckpoint)
+	return filepath.Join(m.socketdir, kubeletDevicePluginCheckpoint)
 }
 
 // Start starts the Device Plugin Manager
@@ -205,34 +205,46 @@ func (m *ManagerImpl) Stop() error {
 }
 
 func (m *ManagerImpl) addEndpoint(r *pluginapi.RegisterRequest) {
+	existingDevs := make(map[string]pluginapi.Device)
+	m.mutex.Lock()
+	old, ok := m.endpoints[r.ResourceName]
+	if ok && old != nil {
+		// Pass devices of previous endpoint into re-registered one,
+		// to avoid potential orphaned devices upon re-registration
+		existingDevs = old.devices
+	}
+	m.mutex.Unlock()
+
 	socketPath := filepath.Join(m.socketdir, r.Endpoint)
-	e, err := newEndpoint(socketPath, r.ResourceName, m.callback)
+	e, err := newEndpoint(socketPath, r.ResourceName, existingDevs, m.callback)
 	if err != nil {
 		glog.Errorf("Failed to dial device plugin with request %v: %v", r, err)
 		return
 	}
 
-	stream, err := e.list()
-	if err != nil {
-		glog.Errorf("Failed to List devices for plugin %v: %v", r.ResourceName, err)
+	m.mutex.Lock()
+	// Check for potential re-registration during the initialization of new endpoint,
+	// and skip updating if re-registration happens.
+	// TODO: simplify the part once we have a better way to handle registered devices
+	ext := m.endpoints[r.ResourceName]
+	if ext != old {
+		glog.Warningf("Some other endpoint %v is added while endpoint %v is initialized", ext, e)
+		m.mutex.Unlock()
 		e.stop()
 		return
 	}
-
 	// Associates the newly created endpoint with the corresponding resource name.
 	// Stops existing endpoint if there is any.
-	m.mutex.Lock()
-	old, ok := m.endpoints[r.ResourceName]
 	m.endpoints[r.ResourceName] = e
 	glog.V(2).Infof("Registered endpoint %v", e)
 	m.mutex.Unlock()
 
-	if ok && old != nil {
+	if old != nil {
 		old.stop()
 	}
 
 	go func() {
-		e.listAndWatch(stream)
+		e.run()
 		e.stop()
 
 		m.mutex.Lock()
diff --git a/pkg/kubelet/cm/deviceplugin/manager_test.go b/pkg/kubelet/cm/deviceplugin/manager_test.go
index a9f05700b4b..0d2178f92c3 100644
--- a/pkg/kubelet/cm/deviceplugin/manager_test.go
+++ b/pkg/kubelet/cm/deviceplugin/manager_test.go
@@ -47,20 +47,23 @@ func TestNewManagerImplStart(t *testing.T) {
 
 // Tests that the device plugin manager correctly handles registration and re-registration by
 // making sure that after registration, devices are correctly updated and if a re-registration
-// happens, we will NOT delete devices.
+// happens, we will NOT delete devices; and no orphaned devices left.
 func TestDevicePluginReRegistration(t *testing.T) {
 	devs := []*pluginapi.Device{
 		{ID: "Dev1", Health: pluginapi.Healthy},
 		{ID: "Dev2", Health: pluginapi.Healthy},
 	}
+	devsForRegistration := []*pluginapi.Device{
+		{ID: "Dev3", Health: pluginapi.Healthy},
+	}
 
 	callbackCount := 0
 	callbackChan := make(chan int)
 	var stopping int32
 	stopping = 0
 	callback := func(n string, a, u, r []pluginapi.Device) {
-		// Should be called twice, one for each plugin registration, till we are stopping.
-		if callbackCount > 1 && atomic.LoadInt32(&stopping) <= 0 {
+		// Should be called three times, one for each plugin registration, till we are stopping.
+		if callbackCount > 2 && atomic.LoadInt32(&stopping) <= 0 {
 			t.FailNow()
 		}
 		callbackCount++
@@ -89,12 +92,25 @@ func TestDevicePluginReRegistration(t *testing.T) {
 
 	devices2 := m.Devices()
 	require.Equal(t, 2, len(devices2[testResourceName]), "Devices shouldn't change.")
+
+	// Test the scenario that a plugin re-registers with different devices.
+	p3 := NewDevicePluginStub(devsForRegistration, pluginSocketName+".third")
+	err = p3.Start()
+	require.NoError(t, err)
+	p3.Register(socketName, testResourceName)
+	// Wait for the second callback to be issued.
+	<-callbackChan
+
+	devices3 := m.Devices()
+	require.Equal(t, 1, len(devices3[testResourceName]), "Devices of plugin previously registered should be removed.")
 
 	// Wait long enough to catch unexpected callbacks.
 	time.Sleep(5 * time.Second)
 	atomic.StoreInt32(&stopping, 1)
-	cleanup(t, m, p1)
 	p2.Stop()
+	p3.Stop()
+	cleanup(t, m, p1)
+
 }
 
 func setup(t *testing.T, devs []*pluginapi.Device, callback MonitorCallback) (Manager, *Stub) {
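
Aside (not part of the patch above): the core of what run() does on each ListAndWatch response is to diff the freshly received device list against the endpoint's cached map and report added, updated, and deleted devices through the MonitorCallback. The standalone sketch below illustrates just that diffing step; the Device type and diffDevices helper are illustrative stand-ins, not the real pluginapi types.

package main

import "fmt"

// Device is a stand-in for pluginapi.Device: an ID plus a health string.
type Device struct {
	ID     string
	Health string
}

// diffDevices compares a freshly received device list against the cached map
// and returns the added, updated, and deleted devices -- the kind of
// bookkeeping the endpoint performs before invoking its callback.
func diffDevices(cached map[string]Device, fresh []Device) (added, updated, deleted []Device) {
	seen := make(map[string]Device, len(fresh))
	for _, d := range fresh {
		seen[d.ID] = d
		old, ok := cached[d.ID]
		if !ok {
			// Device ID not in the cache: newly added.
			added = append(added, d)
		} else if old.Health != d.Health {
			// Known device whose health changed: updated.
			updated = append(updated, d)
		}
	}
	for id, d := range cached {
		if _, ok := seen[id]; !ok {
			// Cached device missing from the fresh list: deleted.
			deleted = append(deleted, d)
		}
	}
	return added, updated, deleted
}

func main() {
	cached := map[string]Device{
		"Dev1": {ID: "Dev1", Health: "Healthy"},
		"Dev2": {ID: "Dev2", Health: "Healthy"},
	}
	fresh := []Device{
		{ID: "Dev2", Health: "Unhealthy"},
		{ID: "Dev3", Health: "Healthy"},
	}
	a, u, d := diffDevices(cached, fresh)
	fmt.Println("added:", a, "updated:", u, "deleted:", d)
}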