diff --git a/pkg/kubelet/cm/deviceplugin/device_plugin_stub.go b/pkg/kubelet/cm/deviceplugin/device_plugin_stub.go index a04389cc192..08dcd5a992f 100644 --- a/pkg/kubelet/cm/deviceplugin/device_plugin_stub.go +++ b/pkg/kubelet/cm/deviceplugin/device_plugin_stub.go @@ -89,7 +89,7 @@ func (m *Stub) Start() error { // Wait till grpc server is ready. for i := 0; i < 10; i++ { services := m.server.GetServiceInfo() - if len(services) > 1 { + if len(services) > 0 { break } time.Sleep(1 * time.Second) @@ -134,16 +134,8 @@ func (m *Stub) Register(kubeletEndpoint, resourceName string) error { // ListAndWatch lists devices and update that list according to the Update call func (m *Stub) ListAndWatch(e *pluginapi.Empty, s pluginapi.DevicePlugin_ListAndWatchServer) error { log.Println("ListAndWatch") - var devs []*pluginapi.Device - for _, d := range m.devs { - devs = append(devs, &pluginapi.Device{ - ID: d.ID, - Health: pluginapi.Healthy, - }) - } - - s.Send(&pluginapi.ListAndWatchResponse{Devices: devs}) + s.Send(&pluginapi.ListAndWatchResponse{Devices: m.devs}) for { select { diff --git a/pkg/kubelet/cm/deviceplugin/endpoint_test.go b/pkg/kubelet/cm/deviceplugin/endpoint_test.go index 226148a6b06..6005310181a 100644 --- a/pkg/kubelet/cm/deviceplugin/endpoint_test.go +++ b/pkg/kubelet/cm/deviceplugin/endpoint_test.go @@ -19,7 +19,6 @@ package deviceplugin import ( "path" "testing" - "time" "github.com/stretchr/testify/require" @@ -54,23 +53,49 @@ func TestRun(t *testing.T) { {ID: "AThirdDeviceId", Health: pluginapi.Healthy}, } - p, e := esetup(t, devs, socket, "mock", func(n string, a, u, r []pluginapi.Device) { - require.Len(t, a, 1) - require.Len(t, u, 1) - require.Len(t, r, 1) + callbackCount := 0 + callbackChan := make(chan int) + callback := func(n string, a, u, r []pluginapi.Device) { + // Should be called twice: + // one for plugin registration, one for plugin update. + if callbackCount > 2 { + t.FailNow() + } - require.Equal(t, a[0].ID, updated[1].ID) + // Check plugin registration + if callbackCount == 0 { + require.Len(t, a, 2) + require.Len(t, u, 0) + require.Len(t, r, 0) + } - require.Equal(t, u[0].ID, updated[0].ID) - require.Equal(t, u[0].Health, updated[0].Health) + // Check plugin update + if callbackCount == 1 { + require.Len(t, a, 1) + require.Len(t, u, 1) + require.Len(t, r, 1) - require.Equal(t, r[0].ID, devs[1].ID) - }) + require.Equal(t, a[0].ID, updated[1].ID) + require.Equal(t, u[0].ID, updated[0].ID) + require.Equal(t, u[0].Health, updated[0].Health) + require.Equal(t, r[0].ID, devs[1].ID) + } + + callbackCount++ + callbackChan <- callbackCount + } + + p, e := esetup(t, devs, socket, "mock", callback) defer ecleanup(t, p, e) go e.run() + // Wait for the first callback to be issued. + <-callbackChan + p.Update(updated) - time.Sleep(time.Second) + + // Wait for the second callback to be issued. + <-callbackChan e.mutex.Lock() defer e.mutex.Unlock() @@ -102,7 +127,7 @@ func esetup(t *testing.T, devs []*pluginapi.Device, socket, resourceName string, err := p.Start() require.NoError(t, err) - e, err := newEndpointImpl(socket, "mock", make(map[string]pluginapi.Device), func(n string, a, u, r []pluginapi.Device) {}) + e, err := newEndpointImpl(socket, resourceName, make(map[string]pluginapi.Device), callback) require.NoError(t, err) return p, e