diff --git a/pkg/kubelet/cm/container_manager.go b/pkg/kubelet/cm/container_manager.go index 8d467bf2054..64728a7d6b1 100644 --- a/pkg/kubelet/cm/container_manager.go +++ b/pkg/kubelet/cm/container_manager.go @@ -95,7 +95,11 @@ type ContainerManager interface { // GetPodCgroupRoot returns the cgroup which contains all pods. GetPodCgroupRoot() string - GetPluginRegistrationHandlerCallback() pluginwatcher.RegisterCallbackFn + + // GetPluginRegistrationHandler returns a plugin registration handler + // The pluginwatcher's Handlers allow to have a single module for handling + // registration. + GetPluginRegistrationHandler() pluginwatcher.PluginHandler } type NodeConfig struct { diff --git a/pkg/kubelet/cm/container_manager_linux.go b/pkg/kubelet/cm/container_manager_linux.go index 2f460ba8cba..132e2de9abe 100644 --- a/pkg/kubelet/cm/container_manager_linux.go +++ b/pkg/kubelet/cm/container_manager_linux.go @@ -605,8 +605,8 @@ func (cm *containerManagerImpl) Start(node *v1.Node, return nil } -func (cm *containerManagerImpl) GetPluginRegistrationHandlerCallback() pluginwatcher.RegisterCallbackFn { - return cm.deviceManager.GetWatcherCallback() +func (cm *containerManagerImpl) GetPluginRegistrationHandler() pluginwatcher.PluginHandler { + return cm.deviceManager.GetWatcherHandler() } // TODO: move the GetResources logic to PodContainerManager. diff --git a/pkg/kubelet/cm/container_manager_stub.go b/pkg/kubelet/cm/container_manager_stub.go index ed219808492..8f948c64d2a 100644 --- a/pkg/kubelet/cm/container_manager_stub.go +++ b/pkg/kubelet/cm/container_manager_stub.go @@ -77,10 +77,8 @@ func (cm *containerManagerStub) GetCapacity() v1.ResourceList { return c } -func (cm *containerManagerStub) GetPluginRegistrationHandlerCallback() pluginwatcher.RegisterCallbackFn { - return func(name string, endpoint string, versions []string, sockPath string) (chan bool, error) { - return nil, nil - } +func (cm *containerManagerStub) GetPluginRegistrationHandler() pluginwatcher.PluginHandler { + return nil } func (cm *containerManagerStub) GetDevicePluginResourceCapacity() (v1.ResourceList, v1.ResourceList, []string) { diff --git a/pkg/kubelet/cm/devicemanager/manager.go b/pkg/kubelet/cm/devicemanager/manager.go index 9aacd08af15..8064b572b39 100644 --- a/pkg/kubelet/cm/devicemanager/manager.go +++ b/pkg/kubelet/cm/devicemanager/manager.go @@ -56,7 +56,7 @@ type ManagerImpl struct { socketname string socketdir string - endpoints map[string]endpoint // Key is ResourceName + endpoints map[string]endpointInfo // Key is ResourceName mutex sync.Mutex server *grpc.Server @@ -86,10 +86,14 @@ type ManagerImpl struct { // podDevices contains pod to allocated device mapping. podDevices podDevices - pluginOpts map[string]*pluginapi.DevicePluginOptions checkpointManager checkpointmanager.CheckpointManager } +type endpointInfo struct { + e endpoint + opts *pluginapi.DevicePluginOptions +} + type sourcesReadyStub struct{} func (s *sourcesReadyStub) AddSource(source string) {} @@ -109,13 +113,13 @@ func newManagerImpl(socketPath string) (*ManagerImpl, error) { dir, file := filepath.Split(socketPath) manager := &ManagerImpl{ - endpoints: make(map[string]endpoint), + endpoints: make(map[string]endpointInfo), + socketname: file, socketdir: dir, healthyDevices: make(map[string]sets.String), unhealthyDevices: make(map[string]sets.String), allocatedDevices: make(map[string]sets.String), - pluginOpts: make(map[string]*pluginapi.DevicePluginOptions), podDevices: make(podDevices), } manager.callback = manager.genericDeviceUpdateCallback @@ -228,8 +232,8 @@ func (m *ManagerImpl) Start(activePods ActivePodsFunc, sourcesReady config.Sourc return nil } -// GetWatcherCallback returns callback function to be registered with plugin watcher -func (m *ManagerImpl) GetWatcherCallback() watcher.RegisterCallbackFn { +// GetWatcherHandler returns the plugin handler +func (m *ManagerImpl) GetWatcherHandler() watcher.PluginHandler { if f, err := os.Create(m.socketdir + "DEPRECATION"); err != nil { glog.Errorf("Failed to create deprecation file at %s", m.socketdir) } else { @@ -237,16 +241,57 @@ func (m *ManagerImpl) GetWatcherCallback() watcher.RegisterCallbackFn { glog.V(4).Infof("created deprecation file %s", f.Name()) } - return func(name string, endpoint string, versions []string, sockPath string) (chan bool, error) { - if !m.isVersionCompatibleWithPlugin(versions) { - return nil, fmt.Errorf("manager version, %s, is not among plugin supported versions %v", pluginapi.Version, versions) - } + return watcher.PluginHandler(m) +} - if !v1helper.IsExtendedResourceName(v1.ResourceName(name)) { - return nil, fmt.Errorf("invalid name of device plugin socket: %s", fmt.Sprintf(errInvalidResourceName, name)) - } +// ValidatePlugin validates a plugin if the version is correct and the name has the format of an extended resource +func (m *ManagerImpl) ValidatePlugin(pluginName string, endpoint string, versions []string) error { + glog.V(2).Infof("Got Plugin %s at endpoint %s with versions %v", pluginName, endpoint, versions) - return m.addEndpointProbeMode(name, sockPath) + if !m.isVersionCompatibleWithPlugin(versions) { + return fmt.Errorf("manager version, %s, is not among plugin supported versions %v", pluginapi.Version, versions) + } + + if !v1helper.IsExtendedResourceName(v1.ResourceName(pluginName)) { + return fmt.Errorf("invalid name of device plugin socket: %s", fmt.Sprintf(errInvalidResourceName, pluginName)) + } + + return nil +} + +// RegisterPlugin starts the endpoint and registers it +// TODO: Start the endpoint and wait for the First ListAndWatch call +// before registering the plugin +func (m *ManagerImpl) RegisterPlugin(pluginName string, endpoint string) error { + glog.V(2).Infof("Registering Plugin %s at endpoint %s", pluginName, endpoint) + + e, err := newEndpointImpl(endpoint, pluginName, m.callback) + if err != nil { + return fmt.Errorf("Failed to dial device plugin with socketPath %s: %v", endpoint, err) + } + + options, err := e.client.GetDevicePluginOptions(context.Background(), &pluginapi.Empty{}) + if err != nil { + return fmt.Errorf("Failed to get device plugin options: %v", err) + } + + m.registerEndpoint(pluginName, options, e) + go m.runEndpoint(pluginName, e) + + return nil +} + +// DeRegisterPlugin deregisters the plugin +// TODO work on the behavior for deregistering plugins +// e.g: Should we delete the resource +func (m *ManagerImpl) DeRegisterPlugin(pluginName string) { + m.mutex.Lock() + defer m.mutex.Unlock() + + // Note: This will mark the resource unhealthy as per the behavior + // in runEndpoint + if eI, ok := m.endpoints[pluginName]; ok { + eI.e.stop() } } @@ -333,8 +378,8 @@ func (m *ManagerImpl) Register(ctx context.Context, r *pluginapi.RegisterRequest func (m *ManagerImpl) Stop() error { m.mutex.Lock() defer m.mutex.Unlock() - for _, e := range m.endpoints { - e.stop() + for _, eI := range m.endpoints { + eI.e.stop() } if m.server == nil { @@ -346,51 +391,26 @@ func (m *ManagerImpl) Stop() error { return nil } -func (m *ManagerImpl) addEndpointProbeMode(resourceName string, socketPath string) (chan bool, error) { - chanForAckOfNotification := make(chan bool) - - new, err := newEndpointImpl(socketPath, resourceName, m.callback) - if err != nil { - glog.Errorf("Failed to dial device plugin with socketPath %s: %v", socketPath, err) - return nil, fmt.Errorf("Failed to dial device plugin with socketPath %s: %v", socketPath, err) - } - - options, err := new.client.GetDevicePluginOptions(context.Background(), &pluginapi.Empty{}) - if err != nil { - glog.Errorf("Failed to get device plugin options: %v", err) - return nil, fmt.Errorf("Failed to get device plugin options: %v", err) - } - m.registerEndpoint(resourceName, options, new) - - go func() { - select { - case <-chanForAckOfNotification: - close(chanForAckOfNotification) - m.runEndpoint(resourceName, new) - case <-time.After(time.Second): - glog.Errorf("Timed out while waiting for notification ack from plugin") - } - }() - return chanForAckOfNotification, nil -} - -func (m *ManagerImpl) registerEndpoint(resourceName string, options *pluginapi.DevicePluginOptions, e *endpointImpl) { +func (m *ManagerImpl) registerEndpoint(resourceName string, options *pluginapi.DevicePluginOptions, e endpoint) { m.mutex.Lock() defer m.mutex.Unlock() - m.pluginOpts[resourceName] = options - m.endpoints[resourceName] = e + + m.endpoints[resourceName] = endpointInfo{e: e, opts: options} glog.V(2).Infof("Registered endpoint %v", e) } -func (m *ManagerImpl) runEndpoint(resourceName string, e *endpointImpl) { +func (m *ManagerImpl) runEndpoint(resourceName string, e endpoint) { e.run() e.stop() + m.mutex.Lock() defer m.mutex.Unlock() - if old, ok := m.endpoints[resourceName]; ok && old == e { + + if old, ok := m.endpoints[resourceName]; ok && old.e == e { m.markResourceUnhealthy(resourceName) } - glog.V(2).Infof("Unregistered endpoint %v", e) + + glog.V(2).Infof("Endpoint (%s, %v) became unhealthy", resourceName, e) } func (m *ManagerImpl) addEndpoint(r *pluginapi.RegisterRequest) { @@ -437,8 +457,8 @@ func (m *ManagerImpl) GetCapacity() (v1.ResourceList, v1.ResourceList, []string) deletedResources := sets.NewString() m.mutex.Lock() for resourceName, devices := range m.healthyDevices { - e, ok := m.endpoints[resourceName] - if (ok && e.stopGracePeriodExpired()) || !ok { + eI, ok := m.endpoints[resourceName] + if (ok && eI.e.stopGracePeriodExpired()) || !ok { // The resources contained in endpoints and (un)healthyDevices // should always be consistent. Otherwise, we run with the risk // of failing to garbage collect non-existing resources or devices. @@ -455,8 +475,8 @@ func (m *ManagerImpl) GetCapacity() (v1.ResourceList, v1.ResourceList, []string) } } for resourceName, devices := range m.unhealthyDevices { - e, ok := m.endpoints[resourceName] - if (ok && e.stopGracePeriodExpired()) || !ok { + eI, ok := m.endpoints[resourceName] + if (ok && eI.e.stopGracePeriodExpired()) || !ok { if !ok { glog.Errorf("unexpected: unhealthyDevices and endpoints are out of sync") } @@ -519,7 +539,7 @@ func (m *ManagerImpl) readCheckpoint() error { // will stay zero till the corresponding device plugin re-registers. m.healthyDevices[resource] = sets.NewString() m.unhealthyDevices[resource] = sets.NewString() - m.endpoints[resource] = newStoppedEndpointImpl(resource) + m.endpoints[resource] = endpointInfo{e: newStoppedEndpointImpl(resource), opts: nil} } return nil } @@ -652,7 +672,7 @@ func (m *ManagerImpl) allocateContainerResources(pod *v1.Pod, container *v1.Cont // plugin Allocate grpc calls if it becomes common that a container may require // resources from multiple device plugins. m.mutex.Lock() - e, ok := m.endpoints[resource] + eI, ok := m.endpoints[resource] m.mutex.Unlock() if !ok { m.mutex.Lock() @@ -665,7 +685,7 @@ func (m *ManagerImpl) allocateContainerResources(pod *v1.Pod, container *v1.Cont // TODO: refactor this part of code to just append a ContainerAllocationRequest // in a passed in AllocateRequest pointer, and issues a single Allocate call per pod. glog.V(3).Infof("Making allocation request for devices %v for device plugin %s", devs, resource) - resp, err := e.allocate(devs) + resp, err := eI.e.allocate(devs) metrics.DevicePluginAllocationLatency.WithLabelValues(resource).Observe(metrics.SinceInMicroseconds(startRPCTime)) if err != nil { // In case of allocation failure, we want to restore m.allocatedDevices @@ -715,11 +735,13 @@ func (m *ManagerImpl) GetDeviceRunContainerOptions(pod *v1.Pod, container *v1.Co // with PreStartRequired option set. func (m *ManagerImpl) callPreStartContainerIfNeeded(podUID, contName, resource string) error { m.mutex.Lock() - opts, ok := m.pluginOpts[resource] + eI, ok := m.endpoints[resource] if !ok { m.mutex.Unlock() - return fmt.Errorf("Plugin options not found in cache for resource: %s", resource) - } else if opts == nil || !opts.PreStartRequired { + return fmt.Errorf("endpoint not found in cache for a registered resource: %s", resource) + } + + if eI.opts == nil || !eI.opts.PreStartRequired { m.mutex.Unlock() glog.V(4).Infof("Plugin options indicate to skip PreStartContainer for resource: %s", resource) return nil @@ -731,16 +753,10 @@ func (m *ManagerImpl) callPreStartContainerIfNeeded(podUID, contName, resource s return fmt.Errorf("no devices found allocated in local cache for pod %s, container %s, resource %s", podUID, contName, resource) } - e, ok := m.endpoints[resource] - if !ok { - m.mutex.Unlock() - return fmt.Errorf("endpoint not found in cache for a registered resource: %s", resource) - } - m.mutex.Unlock() devs := devices.UnsortedList() glog.V(4).Infof("Issuing an PreStartContainer call for container, %s, of pod %s", contName, podUID) - _, err := e.preStartContainer(devs) + _, err := eI.e.preStartContainer(devs) if err != nil { return fmt.Errorf("device plugin PreStartContainer rpc failed with err: %v", err) } diff --git a/pkg/kubelet/cm/devicemanager/manager_stub.go b/pkg/kubelet/cm/devicemanager/manager_stub.go index 66f8d1004cd..1008daca3b7 100644 --- a/pkg/kubelet/cm/devicemanager/manager_stub.go +++ b/pkg/kubelet/cm/devicemanager/manager_stub.go @@ -57,9 +57,7 @@ func (h *ManagerStub) GetCapacity() (v1.ResourceList, v1.ResourceList, []string) return nil, nil, []string{} } -// GetWatcherCallback returns plugin watcher callback -func (h *ManagerStub) GetWatcherCallback() pluginwatcher.RegisterCallbackFn { - return func(name string, endpoint string, versions []string, sockPath string) (chan bool, error) { - return nil, nil - } +// GetWatcherHandler returns plugin watcher interface +func (h *ManagerStub) GetWatcherHandler() pluginwatcher.PluginHandler { + return nil } diff --git a/pkg/kubelet/cm/devicemanager/manager_test.go b/pkg/kubelet/cm/devicemanager/manager_test.go index b6ddb46506a..7168f2342c2 100644 --- a/pkg/kubelet/cm/devicemanager/manager_test.go +++ b/pkg/kubelet/cm/devicemanager/manager_test.go @@ -249,9 +249,10 @@ func setupDevicePlugin(t *testing.T, devs []*pluginapi.Device, pluginSocketName func setupPluginWatcher(pluginSocketName string, m Manager) *pluginwatcher.Watcher { w := pluginwatcher.NewWatcher(filepath.Dir(pluginSocketName)) - w.AddHandler(watcherapi.DevicePlugin, m.GetWatcherCallback()) + w.AddHandler(watcherapi.DevicePlugin, m.GetWatcherHandler()) w.Start() - return &w + + return w } func setup(t *testing.T, devs []*pluginapi.Device, callback monitorCallback, socketName string, pluginSocketName string) (Manager, <-chan interface{}, *Stub) { @@ -295,7 +296,7 @@ func TestUpdateCapacityAllocatable(t *testing.T) { // Expects capacity for resource1 to be 2. resourceName1 := "domain1.com/resource1" e1 := &endpointImpl{} - testManager.endpoints[resourceName1] = e1 + testManager.endpoints[resourceName1] = endpointInfo{e: e1, opts: nil} callback(resourceName1, devs) capacity, allocatable, removedResources := testManager.GetCapacity() resource1Capacity, ok := capacity[v1.ResourceName(resourceName1)] @@ -345,7 +346,7 @@ func TestUpdateCapacityAllocatable(t *testing.T) { // Tests adding another resource. resourceName2 := "resource2" e2 := &endpointImpl{} - testManager.endpoints[resourceName2] = e2 + testManager.endpoints[resourceName2] = endpointInfo{e: e2, opts: nil} callback(resourceName2, devs) capacity, allocatable, removedResources = testManager.GetCapacity() as.Equal(2, len(capacity)) @@ -456,7 +457,7 @@ func TestCheckpoint(t *testing.T) { ckm, err := checkpointmanager.NewCheckpointManager(tmpDir) as.Nil(err) testManager := &ManagerImpl{ - endpoints: make(map[string]endpoint), + endpoints: make(map[string]endpointInfo), healthyDevices: make(map[string]sets.String), unhealthyDevices: make(map[string]sets.String), allocatedDevices: make(map[string]sets.String), @@ -577,7 +578,7 @@ func makePod(limits v1.ResourceList) *v1.Pod { } } -func getTestManager(tmpDir string, activePods ActivePodsFunc, testRes []TestResource, opts map[string]*pluginapi.DevicePluginOptions) (*ManagerImpl, error) { +func getTestManager(tmpDir string, activePods ActivePodsFunc, testRes []TestResource) (*ManagerImpl, error) { monitorCallback := func(resourceName string, devices []pluginapi.Device) {} ckm, err := checkpointmanager.NewCheckpointManager(tmpDir) if err != nil { @@ -589,41 +590,45 @@ func getTestManager(tmpDir string, activePods ActivePodsFunc, testRes []TestReso healthyDevices: make(map[string]sets.String), unhealthyDevices: make(map[string]sets.String), allocatedDevices: make(map[string]sets.String), - endpoints: make(map[string]endpoint), - pluginOpts: opts, + endpoints: make(map[string]endpointInfo), podDevices: make(podDevices), activePods: activePods, sourcesReady: &sourcesReadyStub{}, checkpointManager: ckm, } + for _, res := range testRes { testManager.healthyDevices[res.resourceName] = sets.NewString() for _, dev := range res.devs { testManager.healthyDevices[res.resourceName].Insert(dev) } if res.resourceName == "domain1.com/resource1" { - testManager.endpoints[res.resourceName] = &MockEndpoint{ - allocateFunc: allocateStubFunc(), + testManager.endpoints[res.resourceName] = endpointInfo{ + e: &MockEndpoint{allocateFunc: allocateStubFunc()}, + opts: nil, } } if res.resourceName == "domain2.com/resource2" { - testManager.endpoints[res.resourceName] = &MockEndpoint{ - allocateFunc: func(devs []string) (*pluginapi.AllocateResponse, error) { - resp := new(pluginapi.ContainerAllocateResponse) - resp.Envs = make(map[string]string) - for _, dev := range devs { - switch dev { - case "dev3": - resp.Envs["key2"] = "val2" + testManager.endpoints[res.resourceName] = endpointInfo{ + e: &MockEndpoint{ + allocateFunc: func(devs []string) (*pluginapi.AllocateResponse, error) { + resp := new(pluginapi.ContainerAllocateResponse) + resp.Envs = make(map[string]string) + for _, dev := range devs { + switch dev { + case "dev3": + resp.Envs["key2"] = "val2" - case "dev4": - resp.Envs["key2"] = "val3" + case "dev4": + resp.Envs["key2"] = "val3" + } } - } - resps := new(pluginapi.AllocateResponse) - resps.ContainerResponses = append(resps.ContainerResponses, resp) - return resps, nil + resps := new(pluginapi.AllocateResponse) + resps.ContainerResponses = append(resps.ContainerResponses, resp) + return resps, nil + }, }, + opts: nil, } } } @@ -669,10 +674,7 @@ func TestPodContainerDeviceAllocation(t *testing.T) { as.Nil(err) defer os.RemoveAll(tmpDir) nodeInfo := getTestNodeInfo(v1.ResourceList{}) - pluginOpts := make(map[string]*pluginapi.DevicePluginOptions) - pluginOpts[res1.resourceName] = nil - pluginOpts[res2.resourceName] = nil - testManager, err := getTestManager(tmpDir, podsStub.getActivePods, testResources, pluginOpts) + testManager, err := getTestManager(tmpDir, podsStub.getActivePods, testResources) as.Nil(err) testPods := []*v1.Pod{ @@ -767,10 +769,8 @@ func TestInitContainerDeviceAllocation(t *testing.T) { tmpDir, err := ioutil.TempDir("", "checkpoint") as.Nil(err) defer os.RemoveAll(tmpDir) - pluginOpts := make(map[string]*pluginapi.DevicePluginOptions) - pluginOpts[res1.resourceName] = nil - pluginOpts[res2.resourceName] = nil - testManager, err := getTestManager(tmpDir, podsStub.getActivePods, testResources, pluginOpts) + + testManager, err := getTestManager(tmpDir, podsStub.getActivePods, testResources) as.Nil(err) podWithPluginResourcesInInitContainers := &v1.Pod{ @@ -904,18 +904,18 @@ func TestDevicePreStartContainer(t *testing.T) { as.Nil(err) defer os.RemoveAll(tmpDir) nodeInfo := getTestNodeInfo(v1.ResourceList{}) - pluginOpts := make(map[string]*pluginapi.DevicePluginOptions) - pluginOpts[res1.resourceName] = &pluginapi.DevicePluginOptions{PreStartRequired: true} - testManager, err := getTestManager(tmpDir, podsStub.getActivePods, []TestResource{res1}, pluginOpts) + testManager, err := getTestManager(tmpDir, podsStub.getActivePods, []TestResource{res1}) as.Nil(err) ch := make(chan []string, 1) - testManager.endpoints[res1.resourceName] = &MockEndpoint{ - initChan: ch, - allocateFunc: allocateStubFunc(), + testManager.endpoints[res1.resourceName] = endpointInfo{ + e: &MockEndpoint{ + initChan: ch, + allocateFunc: allocateStubFunc(), + }, + opts: &pluginapi.DevicePluginOptions{PreStartRequired: true}, } - pod := makePod(v1.ResourceList{ v1.ResourceName(res1.resourceName): res1.resourceQuantity}) activePods := []*v1.Pod{} diff --git a/pkg/kubelet/cm/devicemanager/types.go b/pkg/kubelet/cm/devicemanager/types.go index 52176dec71a..35923b00d12 100644 --- a/pkg/kubelet/cm/devicemanager/types.go +++ b/pkg/kubelet/cm/devicemanager/types.go @@ -53,7 +53,7 @@ type Manager interface { // GetCapacity returns the amount of available device plugin resource capacity, resource allocatable // and inactive device plugin resources previously registered on the node. GetCapacity() (v1.ResourceList, v1.ResourceList, []string) - GetWatcherCallback() watcher.RegisterCallbackFn + GetWatcherHandler() watcher.PluginHandler } // DeviceRunContainerOptions contains the combined container runtime settings to consume its allocated devices. diff --git a/pkg/kubelet/kubelet.go b/pkg/kubelet/kubelet.go index 506f409f727..6d31ace48d4 100644 --- a/pkg/kubelet/kubelet.go +++ b/pkg/kubelet/kubelet.go @@ -1194,7 +1194,7 @@ type Kubelet struct { // pluginwatcher is a utility for Kubelet to register different types of node-level plugins // such as device plugins or CSI plugins. It discovers plugins by monitoring inotify events under the // directory returned by kubelet.getPluginsDir() - pluginWatcher pluginwatcher.Watcher + pluginWatcher *pluginwatcher.Watcher // This flag sets a maximum number of images to report in the node status. nodeStatusMaxImages int32 @@ -1365,9 +1365,9 @@ func (kl *Kubelet) initializeRuntimeDependentModules() { kl.containerLogManager.Start() if kl.enablePluginsWatcher { // Adding Registration Callback function for CSI Driver - kl.pluginWatcher.AddHandler("CSIPlugin", csi.RegistrationCallback) + kl.pluginWatcher.AddHandler("CSIPlugin", pluginwatcher.PluginHandler(csi.PluginHandler)) // Adding Registration Callback function for Device Manager - kl.pluginWatcher.AddHandler(pluginwatcherapi.DevicePlugin, kl.containerManager.GetPluginRegistrationHandlerCallback()) + kl.pluginWatcher.AddHandler(pluginwatcherapi.DevicePlugin, kl.containerManager.GetPluginRegistrationHandler()) // Start the plugin watcher glog.V(4).Infof("starting watcher") if err := kl.pluginWatcher.Start(); err != nil { diff --git a/pkg/kubelet/util/pluginwatcher/BUILD b/pkg/kubelet/util/pluginwatcher/BUILD index 7b887b444f5..0301c6b95b3 100644 --- a/pkg/kubelet/util/pluginwatcher/BUILD +++ b/pkg/kubelet/util/pluginwatcher/BUILD @@ -1,10 +1,4 @@ -package(default_visibility = ["//visibility:public"]) - -load( - "@io_bazel_rules_go//go:def.bzl", - "go_library", - "go_test", -) +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "go_default_library", @@ -12,8 +6,10 @@ go_library( "example_handler.go", "example_plugin.go", "plugin_watcher.go", + "types.go", ], importpath = "k8s.io/kubernetes/pkg/kubelet/util/pluginwatcher", + visibility = ["//visibility:public"], deps = [ "//pkg/kubelet/apis/pluginregistration/v1alpha1:go_default_library", "//pkg/kubelet/util/pluginwatcher/example_plugin_apis/v1beta1:go_default_library", @@ -27,6 +23,16 @@ go_library( ], ) +go_test( + name = "go_default_test", + srcs = ["plugin_watcher_test.go"], + embed = [":go_default_library"], + deps = [ + "//pkg/kubelet/apis/pluginregistration/v1alpha1:go_default_library", + "//vendor/github.com/stretchr/testify/require:go_default_library", + ], +) + filegroup( name = "package-srcs", srcs = glob(["**"]), @@ -44,14 +50,3 @@ filegroup( tags = ["automanaged"], visibility = ["//visibility:public"], ) - -go_test( - name = "go_default_test", - srcs = ["plugin_watcher_test.go"], - embed = [":go_default_library"], - deps = [ - "//pkg/kubelet/apis/pluginregistration/v1alpha1:go_default_library", - "//staging/src/k8s.io/apimachinery/pkg/util/sets:go_default_library", - "//vendor/github.com/stretchr/testify/require:go_default_library", - ], -) diff --git a/pkg/kubelet/util/pluginwatcher/example_handler.go b/pkg/kubelet/util/pluginwatcher/example_handler.go index 4eae4188d69..8f9cac5d9bd 100644 --- a/pkg/kubelet/util/pluginwatcher/example_handler.go +++ b/pkg/kubelet/util/pluginwatcher/example_handler.go @@ -23,6 +23,7 @@ import ( "sync" "time" + "github.com/golang/glog" "golang.org/x/net/context" v1beta1 "k8s.io/kubernetes/pkg/kubelet/util/pluginwatcher/example_plugin_apis/v1beta1" @@ -30,41 +31,61 @@ import ( ) type exampleHandler struct { - registeredPlugins map[string]struct{} - mutex sync.Mutex - chanForHandlerAckErrors chan error // for testing + SupportedVersions []string + ExpectedNames map[string]int + + eventChans map[string]chan examplePluginEvent // map[pluginName]eventChan + + m sync.Mutex + count int } +type examplePluginEvent int + +const ( + exampleEventValidate examplePluginEvent = 0 + exampleEventRegister examplePluginEvent = 1 + exampleEventDeRegister examplePluginEvent = 2 + exampleEventError examplePluginEvent = 3 +) + // NewExampleHandler provide a example handler -func NewExampleHandler() *exampleHandler { +func NewExampleHandler(supportedVersions []string) *exampleHandler { return &exampleHandler{ - chanForHandlerAckErrors: make(chan error), - registeredPlugins: make(map[string]struct{}), + SupportedVersions: supportedVersions, + ExpectedNames: make(map[string]int), + + eventChans: make(map[string]chan examplePluginEvent), } } -func (h *exampleHandler) Cleanup() error { - h.mutex.Lock() - defer h.mutex.Unlock() - h.registeredPlugins = make(map[string]struct{}) - return nil -} +func (p *exampleHandler) ValidatePlugin(pluginName string, endpoint string, versions []string) error { + p.SendEvent(pluginName, exampleEventValidate) -func (h *exampleHandler) Handler(pluginName string, endpoint string, versions []string, sockPath string) (chan bool, error) { + n, ok := p.DecreasePluginCount(pluginName) + if !ok && n > 0 { + return fmt.Errorf("pluginName('%s') wasn't expected (count is %d)", pluginName, n) + } - // check for supported versions - if !reflect.DeepEqual([]string{"v1beta1", "v1beta2"}, versions) { - return nil, fmt.Errorf("not the supported versions: %s", versions) + if !reflect.DeepEqual(versions, p.SupportedVersions) { + return fmt.Errorf("versions('%v') != supported versions('%v')", versions, p.SupportedVersions) } // this handler expects non-empty endpoint as an example if len(endpoint) == 0 { - return nil, errors.New("expecting non empty endpoint") + return errors.New("expecting non empty endpoint") } - _, conn, err := dial(sockPath) + return nil +} + +func (p *exampleHandler) RegisterPlugin(pluginName, endpoint string) error { + p.SendEvent(pluginName, exampleEventRegister) + + // Verifies the grpcServer is ready to serve services. + _, conn, err := dial(endpoint, time.Second) if err != nil { - return nil, err + return fmt.Errorf("Failed dialing endpoint (%s): %v", endpoint, err) } defer conn.Close() @@ -73,33 +94,54 @@ func (h *exampleHandler) Handler(pluginName string, endpoint string, versions [] v1beta2Client := v1beta2.NewExampleClient(conn) // Tests v1beta1 GetExampleInfo - if _, err = v1beta1Client.GetExampleInfo(context.Background(), &v1beta1.ExampleRequest{}); err != nil { - return nil, err + _, err = v1beta1Client.GetExampleInfo(context.Background(), &v1beta1.ExampleRequest{}) + if err != nil { + return fmt.Errorf("Failed GetExampleInfo for v1beta2Client(%s): %v", endpoint, err) } - // Tests v1beta2 GetExampleInfo - if _, err = v1beta2Client.GetExampleInfo(context.Background(), &v1beta2.ExampleRequest{}); err != nil { - return nil, err + // Tests v1beta1 GetExampleInfo + _, err = v1beta2Client.GetExampleInfo(context.Background(), &v1beta2.ExampleRequest{}) + if err != nil { + return fmt.Errorf("Failed GetExampleInfo for v1beta2Client(%s): %v", endpoint, err) } - // handle registered plugin - h.mutex.Lock() - if _, exist := h.registeredPlugins[pluginName]; exist { - h.mutex.Unlock() - return nil, fmt.Errorf("plugin %s already registered", pluginName) - } - h.registeredPlugins[pluginName] = struct{}{} - h.mutex.Unlock() - - chanForAckOfNotification := make(chan bool) - go func() { - select { - case <-chanForAckOfNotification: - // TODO: handle the negative scenario - close(chanForAckOfNotification) - case <-time.After(time.Second): - h.chanForHandlerAckErrors <- errors.New("Timed out while waiting for notification ack") - } - }() - return chanForAckOfNotification, nil + return nil +} + +func (p *exampleHandler) DeRegisterPlugin(pluginName string) { + p.SendEvent(pluginName, exampleEventDeRegister) +} + +func (p *exampleHandler) EventChan(pluginName string) chan examplePluginEvent { + return p.eventChans[pluginName] +} + +func (p *exampleHandler) SendEvent(pluginName string, event examplePluginEvent) { + glog.V(2).Infof("Sending %v for plugin %s over chan %v", event, pluginName, p.eventChans[pluginName]) + p.eventChans[pluginName] <- event +} + +func (p *exampleHandler) AddPluginName(pluginName string) { + p.m.Lock() + defer p.m.Unlock() + + v, ok := p.ExpectedNames[pluginName] + if !ok { + p.eventChans[pluginName] = make(chan examplePluginEvent) + v = 1 + } + + p.ExpectedNames[pluginName] = v +} + +func (p *exampleHandler) DecreasePluginCount(pluginName string) (old int, ok bool) { + p.m.Lock() + defer p.m.Unlock() + + v, ok := p.ExpectedNames[pluginName] + if !ok { + v = -1 + } + + return v, ok } diff --git a/pkg/kubelet/util/pluginwatcher/example_plugin.go b/pkg/kubelet/util/pluginwatcher/example_plugin.go index 5c2dd966ba4..694b3661202 100644 --- a/pkg/kubelet/util/pluginwatcher/example_plugin.go +++ b/pkg/kubelet/util/pluginwatcher/example_plugin.go @@ -18,7 +18,9 @@ package pluginwatcher import ( "errors" + "fmt" "net" + "os" "sync" "time" @@ -39,6 +41,7 @@ type examplePlugin struct { endpoint string // for testing pluginName string pluginType string + versions []string } type pluginServiceV1Beta1 struct { @@ -73,12 +76,13 @@ func NewExamplePlugin() *examplePlugin { } // NewTestExamplePlugin returns an initialized examplePlugin instance for testing -func NewTestExamplePlugin(pluginName string, pluginType string, endpoint string) *examplePlugin { +func NewTestExamplePlugin(pluginName string, pluginType string, endpoint string, advertisedVersions ...string) *examplePlugin { return &examplePlugin{ pluginName: pluginName, pluginType: pluginType, - registrationStatus: make(chan registerapi.RegistrationStatus), endpoint: endpoint, + versions: advertisedVersions, + registrationStatus: make(chan registerapi.RegistrationStatus), } } @@ -88,36 +92,48 @@ func (e *examplePlugin) GetInfo(ctx context.Context, req *registerapi.InfoReques Type: e.pluginType, Name: e.pluginName, Endpoint: e.endpoint, - SupportedVersions: []string{"v1beta1", "v1beta2"}, + SupportedVersions: e.versions, }, nil } func (e *examplePlugin) NotifyRegistrationStatus(ctx context.Context, status *registerapi.RegistrationStatus) (*registerapi.RegistrationStatusResponse, error) { + glog.Errorf("Registration is: %v\n", status) + if e.registrationStatus != nil { e.registrationStatus <- *status } - if !status.PluginRegistered { - glog.Errorf("Registration failed: %s\n", status.Error) - } + return ®isterapi.RegistrationStatusResponse{}, nil } -// Serve starts example plugin grpc server -func (e *examplePlugin) Serve(socketPath string) error { - glog.Infof("starting example server at: %s\n", socketPath) - lis, err := net.Listen("unix", socketPath) +// Serve starts a pluginwatcher server and one or more of the plugin services +func (e *examplePlugin) Serve(services ...string) error { + glog.Infof("starting example server at: %s\n", e.endpoint) + lis, err := net.Listen("unix", e.endpoint) if err != nil { return err } - glog.Infof("example server started at: %s\n", socketPath) + + glog.Infof("example server started at: %s\n", e.endpoint) e.grpcServer = grpc.NewServer() + // Registers kubelet plugin watcher api. registerapi.RegisterRegistrationServer(e.grpcServer, e) - // Registers services for both v1beta1 and v1beta2 versions. - v1beta1 := &pluginServiceV1Beta1{server: e} - v1beta1.RegisterService() - v1beta2 := &pluginServiceV1Beta2{server: e} - v1beta2.RegisterService() + + for _, service := range services { + switch service { + case "v1beta1": + v1beta1 := &pluginServiceV1Beta1{server: e} + v1beta1.RegisterService() + break + case "v1beta2": + v1beta2 := &pluginServiceV1Beta2{server: e} + v1beta2.RegisterService() + break + default: + return fmt.Errorf("Unsupported service: '%s'", service) + } + } // Starts service e.wg.Add(1) @@ -128,22 +144,30 @@ func (e *examplePlugin) Serve(socketPath string) error { glog.Errorf("example server stopped serving: %v", err) } }() + return nil } func (e *examplePlugin) Stop() error { - glog.Infof("Stopping example server\n") + glog.Infof("Stopping example server at: %s\n", e.endpoint) + e.grpcServer.Stop() c := make(chan struct{}) go func() { defer close(c) e.wg.Wait() }() + select { case <-c: - return nil + break case <-time.After(time.Second): - glog.Errorf("Timed out on waiting for stop completion") return errors.New("Timed out on waiting for stop completion") } + + if err := os.Remove(e.endpoint); err != nil && !os.IsNotExist(err) { + return err + } + + return nil } diff --git a/pkg/kubelet/util/pluginwatcher/plugin_watcher.go b/pkg/kubelet/util/pluginwatcher/plugin_watcher.go index 6db743dd4fb..cbc33e47444 100644 --- a/pkg/kubelet/util/pluginwatcher/plugin_watcher.go +++ b/pkg/kubelet/util/pluginwatcher/plugin_watcher.go @@ -20,6 +20,7 @@ import ( "fmt" "net" "os" + "strings" "sync" "time" @@ -28,43 +29,144 @@ import ( "github.com/pkg/errors" "golang.org/x/net/context" "google.golang.org/grpc" + registerapi "k8s.io/kubernetes/pkg/kubelet/apis/pluginregistration/v1alpha1" utilfs "k8s.io/kubernetes/pkg/util/filesystem" ) -// RegisterCallbackFn is the type of the callback function that handlers will provide -type RegisterCallbackFn func(pluginName string, endpoint string, versions []string, socketPath string) (chan bool, error) - // Watcher is the plugin watcher type Watcher struct { path string - handlers map[string]RegisterCallbackFn stopCh chan interface{} fs utilfs.Filesystem fsWatcher *fsnotify.Watcher wg sync.WaitGroup - mutex sync.Mutex + + mutex sync.Mutex + handlers map[string]PluginHandler + plugins map[string]pathInfo + pluginsPool map[string]map[string]*sync.Mutex // map[pluginType][pluginName] +} + +type pathInfo struct { + pluginType string + pluginName string } // NewWatcher provides a new watcher -func NewWatcher(sockDir string) Watcher { - return Watcher{ - path: sockDir, - handlers: make(map[string]RegisterCallbackFn), - fs: &utilfs.DefaultFs{}, +func NewWatcher(sockDir string) *Watcher { + return &Watcher{ + path: sockDir, + fs: &utilfs.DefaultFs{}, + + handlers: make(map[string]PluginHandler), + plugins: make(map[string]pathInfo), + pluginsPool: make(map[string]map[string]*sync.Mutex), } } -// AddHandler registers a callback to be invoked for a particular type of plugin -func (w *Watcher) AddHandler(pluginType string, handlerCbkFn RegisterCallbackFn) { +func (w *Watcher) AddHandler(pluginType string, handler PluginHandler) { w.mutex.Lock() defer w.mutex.Unlock() - w.handlers[pluginType] = handlerCbkFn + + w.handlers[pluginType] = handler } -// Creates the plugin directory, if it doesn't already exist. -func (w *Watcher) createPluginDir() error { +func (w *Watcher) getHandler(pluginType string) (PluginHandler, bool) { + w.mutex.Lock() + defer w.mutex.Unlock() + + h, ok := w.handlers[pluginType] + return h, ok +} + +// Start watches for the creation of plugin sockets at the path +func (w *Watcher) Start() error { + glog.V(2).Infof("Plugin Watcher Start at %s", w.path) + w.stopCh = make(chan interface{}) + + // Creating the directory to be watched if it doesn't exist yet, + // and walks through the directory to discover the existing plugins. + if err := w.init(); err != nil { + return err + } + + fsWatcher, err := fsnotify.NewWatcher() + if err != nil { + return fmt.Errorf("failed to start plugin fsWatcher, err: %v", err) + } + w.fsWatcher = fsWatcher + + w.wg.Add(1) + go func(fsWatcher *fsnotify.Watcher) { + defer w.wg.Done() + for { + select { + case event := <-fsWatcher.Events: + //TODO: Handle errors by taking corrective measures + + w.wg.Add(1) + go func() { + defer w.wg.Done() + + if event.Op&fsnotify.Create == fsnotify.Create { + err := w.handleCreateEvent(event) + if err != nil { + glog.Errorf("error %v when handling create event: %s", err, event) + } + } else if event.Op&fsnotify.Remove == fsnotify.Remove { + err := w.handleDeleteEvent(event) + if err != nil { + glog.Errorf("error %v when handling delete event: %s", err, event) + } + } + return + }() + continue + case err := <-fsWatcher.Errors: + if err != nil { + glog.Errorf("fsWatcher received error: %v", err) + } + continue + case <-w.stopCh: + return + } + } + }(fsWatcher) + + // Traverse plugin dir after starting the plugin processing goroutine + if err := w.traversePluginDir(w.path); err != nil { + w.Stop() + return fmt.Errorf("failed to traverse plugin socket path, err: %v", err) + } + + return nil +} + +// Stop stops probing the creation of plugin sockets at the path +func (w *Watcher) Stop() error { + close(w.stopCh) + + c := make(chan struct{}) + go func() { + defer close(c) + w.wg.Wait() + }() + + select { + case <-c: + case <-time.After(11 * time.Second): + return fmt.Errorf("timeout on stopping watcher") + } + + w.fsWatcher.Close() + + return nil +} + +func (w *Watcher) init() error { glog.V(4).Infof("Ensuring Plugin directory at %s ", w.path) + if err := w.fs.MkdirAll(w.path, 0755); err != nil { return fmt.Errorf("error (re-)creating root %s: %v", w.path, err) } @@ -91,22 +193,38 @@ func (w *Watcher) traversePluginDir(dir string) error { Op: fsnotify.Create, } }() + default: + glog.V(5).Infof("Ignoring file %s with mode %v", path, mode) } return nil }) } -func (w *Watcher) init() error { - if err := w.createPluginDir(); err != nil { - return err +// Handle filesystem notify event. +func (w *Watcher) handleCreateEvent(event fsnotify.Event) error { + glog.V(6).Infof("Handling create event: %v", event) + + fi, err := os.Stat(event.Name) + if err != nil { + return fmt.Errorf("stat file %s failed: %v", event.Name, err) } - return nil + + if strings.HasPrefix(fi.Name(), ".") { + glog.Errorf("Ignoring file: %s", fi.Name()) + return nil + } + + if !fi.IsDir() { + return w.handlePluginRegistration(event.Name) + } + + return w.traversePluginDir(event.Name) } -func (w *Watcher) registerPlugin(socketPath string) error { +func (w *Watcher) handlePluginRegistration(socketPath string) error { //TODO: Implement rate limiting to mitigate any DOS kind of attacks. - client, conn, err := dial(socketPath) + client, conn, err := dial(socketPath, 10*time.Second) if err != nil { return fmt.Errorf("dial failed at socket %s, err: %v", socketPath, err) } @@ -114,154 +232,161 @@ func (w *Watcher) registerPlugin(socketPath string) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() + infoResp, err := client.GetInfo(ctx, ®isterapi.InfoRequest{}) if err != nil { return fmt.Errorf("failed to get plugin info using RPC GetInfo at socket %s, err: %v", socketPath, err) } - return w.invokeRegistrationCallbackAtHandler(ctx, client, infoResp, socketPath) -} - -func (w *Watcher) invokeRegistrationCallbackAtHandler(ctx context.Context, client registerapi.RegistrationClient, infoResp *registerapi.PluginInfo, socketPath string) error { - var handlerCbkFn RegisterCallbackFn - var ok bool - handlerCbkFn, ok = w.handlers[infoResp.Type] + handler, ok := w.handlers[infoResp.Type] if !ok { - errStr := fmt.Sprintf("no handler registered for plugin type: %s at socket %s", infoResp.Type, socketPath) - if _, err := client.NotifyRegistrationStatus(ctx, ®isterapi.RegistrationStatus{ - PluginRegistered: false, - Error: errStr, - }); err != nil { - return errors.Wrap(err, errStr) - } - return errors.New(errStr) + return w.notifyPlugin(client, false, fmt.Sprintf("no handler registered for plugin type: %s at socket %s", infoResp.Type, socketPath)) } - var versions []string - for _, version := range infoResp.SupportedVersions { - versions = append(versions, version) + // ReRegistration: We want to handle multiple plugins registering at the same time with the same name sequentially. + // See the state machine for more information. + // This is done by using a Lock for each plugin with the same name and type + pool := w.getPluginPool(infoResp.Type, infoResp.Name) + + pool.Lock() + defer pool.Unlock() + + if infoResp.Endpoint == "" { + infoResp.Endpoint = socketPath } + // calls handler callback to verify registration request - chanForAckOfNotification, err := handlerCbkFn(infoResp.Name, infoResp.Endpoint, versions, socketPath) - if err != nil { - errStr := fmt.Sprintf("plugin registration failed with err: %v", err) - if _, err := client.NotifyRegistrationStatus(ctx, ®isterapi.RegistrationStatus{ - PluginRegistered: false, - Error: errStr, - }); err != nil { - return errors.Wrap(err, errStr) - } - return errors.New(errStr) + if err := handler.ValidatePlugin(infoResp.Name, infoResp.Endpoint, infoResp.SupportedVersions); err != nil { + return w.notifyPlugin(client, false, fmt.Sprintf("plugin validation failed with err: %v", err)) } - if _, err := client.NotifyRegistrationStatus(ctx, ®isterapi.RegistrationStatus{ - PluginRegistered: true, - }); err != nil { - chanForAckOfNotification <- false + // We add the plugin to the pluginwatcher's map before calling a plugin consumer's Register handle + // so that if we receive a delete event during Register Plugin, we can process it as a DeRegister call. + w.registerPlugin(socketPath, infoResp.Type, infoResp.Name) + + if err := handler.RegisterPlugin(infoResp.Name, infoResp.Endpoint); err != nil { + return w.notifyPlugin(client, false, fmt.Sprintf("plugin registration failed with err: %v", err)) + } + + // Notify is called after register to guarantee that even if notify throws an error Register will always be called after validate + if err := w.notifyPlugin(client, true, ""); err != nil { return fmt.Errorf("failed to send registration status at socket %s, err: %v", socketPath, err) } - chanForAckOfNotification <- true return nil } -// Handle filesystem notify event. -func (w *Watcher) handleFsNotifyEvent(event fsnotify.Event) error { - if event.Op&fsnotify.Create != fsnotify.Create { +func (w *Watcher) handleDeleteEvent(event fsnotify.Event) error { + glog.V(6).Infof("Handling delete event: %v", event) + + plugin, ok := w.getPlugin(event.Name) + if !ok { + return fmt.Errorf("could not find plugin for deleted file %s", event.Name) + } + + // You should not get a Deregister call while registering a plugin + pool := w.getPluginPool(plugin.pluginType, plugin.pluginName) + + pool.Lock() + defer pool.Unlock() + + // ReRegisteration: When waiting for the lock a plugin with the same name (not socketPath) could have registered + // In that case, we don't want to issue a DeRegister call for that plugin + // When ReRegistering, the new plugin will have removed the current mapping (map[socketPath] = plugin) and replaced + // it with it's own socketPath. + if _, ok = w.getPlugin(event.Name); !ok { + glog.V(2).Infof("A newer plugin watcher has been registered for plugin %v, dropping DeRegister call", plugin) return nil } - fi, err := os.Stat(event.Name) - if err != nil { - return fmt.Errorf("stat file %s failed: %v", event.Name, err) + h, ok := w.getHandler(plugin.pluginType) + if !ok { + return fmt.Errorf("could not find handler %s for plugin %s at path %s", plugin.pluginType, plugin.pluginName, event.Name) } - if !fi.IsDir() { - return w.registerPlugin(event.Name) - } - - if err := w.traversePluginDir(event.Name); err != nil { - return fmt.Errorf("failed to traverse plugin path %s, err: %v", event.Name, err) - } + glog.V(2).Infof("DeRegistering plugin %v at path %s", plugin, event.Name) + w.deRegisterPlugin(event.Name, plugin.pluginType, plugin.pluginName) + h.DeRegisterPlugin(plugin.pluginName) return nil } -// Start watches for the creation of plugin sockets at the path -func (w *Watcher) Start() error { - glog.V(2).Infof("Plugin Watcher Start at %s", w.path) - w.stopCh = make(chan interface{}) +func (w *Watcher) registerPlugin(socketPath, pluginType, pluginName string) { + w.mutex.Lock() + defer w.mutex.Unlock() - // Creating the directory to be watched if it doesn't exist yet, - // and walks through the directory to discover the existing plugins. - if err := w.init(); err != nil { - return err - } - - fsWatcher, err := fsnotify.NewWatcher() - if err != nil { - return fmt.Errorf("failed to start plugin fsWatcher, err: %v", err) - } - w.fsWatcher = fsWatcher - - if err := w.traversePluginDir(w.path); err != nil { - fsWatcher.Close() - return fmt.Errorf("failed to traverse plugin socket path, err: %v", err) - } - - w.wg.Add(1) - go func(fsWatcher *fsnotify.Watcher) { - defer w.wg.Done() - for { - select { - case event := <-fsWatcher.Events: - //TODO: Handle errors by taking corrective measures - go func() { - err := w.handleFsNotifyEvent(event) - if err != nil { - glog.Errorf("error %v when handle event: %s", err, event) - } - }() - continue - case err := <-fsWatcher.Errors: - if err != nil { - glog.Errorf("fsWatcher received error: %v", err) - } - continue - case <-w.stopCh: - fsWatcher.Close() - return - } + // Reregistration case, if this plugin is already in the map, remove it + // This will prevent handleDeleteEvent to issue a DeRegister call + for path, info := range w.plugins { + if info.pluginType != pluginType || info.pluginName != pluginName { + continue } - }(fsWatcher) - return nil -} -// Stop stops probing the creation of plugin sockets at the path -func (w *Watcher) Stop() error { - close(w.stopCh) - c := make(chan struct{}) - go func() { - defer close(c) - w.wg.Wait() - }() - select { - case <-c: - case <-time.After(10 * time.Second): - return fmt.Errorf("timeout on stopping watcher") + delete(w.plugins, path) + break + } + + w.plugins[socketPath] = pathInfo{ + pluginType: pluginType, + pluginName: pluginName, } - return nil } -// Cleanup cleans the path by removing sockets -func (w *Watcher) Cleanup() error { - return os.RemoveAll(w.path) +func (w *Watcher) deRegisterPlugin(socketPath, pluginType, pluginName string) { + w.mutex.Lock() + defer w.mutex.Unlock() + + delete(w.plugins, socketPath) + delete(w.pluginsPool[pluginType], pluginName) +} + +func (w *Watcher) getPlugin(socketPath string) (pathInfo, bool) { + w.mutex.Lock() + defer w.mutex.Unlock() + + plugin, ok := w.plugins[socketPath] + return plugin, ok +} + +func (w *Watcher) getPluginPool(pluginType, pluginName string) *sync.Mutex { + w.mutex.Lock() + defer w.mutex.Unlock() + + if _, ok := w.pluginsPool[pluginType]; !ok { + w.pluginsPool[pluginType] = make(map[string]*sync.Mutex) + } + + if _, ok := w.pluginsPool[pluginType][pluginName]; !ok { + w.pluginsPool[pluginType][pluginName] = &sync.Mutex{} + } + + return w.pluginsPool[pluginType][pluginName] +} + +func (w *Watcher) notifyPlugin(client registerapi.RegistrationClient, registered bool, errStr string) error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + status := ®isterapi.RegistrationStatus{ + PluginRegistered: registered, + Error: errStr, + } + + if _, err := client.NotifyRegistrationStatus(ctx, status); err != nil { + return errors.Wrap(err, errStr) + } + + if errStr != "" { + return errors.New(errStr) + } + + return nil } // Dial establishes the gRPC communication with the picked up plugin socket. https://godoc.org/google.golang.org/grpc#Dial -func dial(unixSocketPath string) (registerapi.RegistrationClient, *grpc.ClientConn, error) { +func dial(unixSocketPath string, timeout time.Duration) (registerapi.RegistrationClient, *grpc.ClientConn, error) { c, err := grpc.Dial(unixSocketPath, grpc.WithInsecure(), grpc.WithBlock(), - grpc.WithTimeout(10*time.Second), + grpc.WithTimeout(timeout), grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) { return net.DialTimeout("unix", addr, timeout) }), diff --git a/pkg/kubelet/util/pluginwatcher/plugin_watcher_test.go b/pkg/kubelet/util/pluginwatcher/plugin_watcher_test.go index 5bfb49568e6..fdcb8b705bc 100644 --- a/pkg/kubelet/util/pluginwatcher/plugin_watcher_test.go +++ b/pkg/kubelet/util/pluginwatcher/plugin_watcher_test.go @@ -17,192 +17,222 @@ limitations under the License. package pluginwatcher import ( - "errors" + "flag" + "fmt" "io/ioutil" - "path/filepath" - "strconv" + "os" "sync" "testing" "time" "github.com/stretchr/testify/require" - "k8s.io/apimachinery/pkg/util/sets" registerapi "k8s.io/kubernetes/pkg/kubelet/apis/pluginregistration/v1alpha1" ) -// helper function -func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { +var ( + socketDir string + + supportedVersions = []string{"v1beta1", "v1beta2"} +) + +func init() { + var logLevel string + + flag.Set("alsologtostderr", fmt.Sprintf("%t", true)) + flag.StringVar(&logLevel, "logLevel", "6", "test") + flag.Lookup("v").Value.Set(logLevel) + + d, err := ioutil.TempDir("", "plugin_test") + if err != nil { + panic(fmt.Sprintf("Could not create a temp directory: %s", d)) + } + + socketDir = d +} + +func cleanup(t *testing.T) { + require.NoError(t, os.RemoveAll(socketDir)) + os.MkdirAll(socketDir, 0755) +} + +func TestPluginRegistration(t *testing.T) { + defer cleanup(t) + + hdlr := NewExampleHandler(supportedVersions) + w := newWatcherWithHandler(t, hdlr) + defer func() { require.NoError(t, w.Stop()) }() + + for i := 0; i < 10; i++ { + socketPath := fmt.Sprintf("%s/plugin-%d.sock", socketDir, i) + pluginName := fmt.Sprintf("example-plugin-%d", i) + + hdlr.AddPluginName(pluginName) + + p := NewTestExamplePlugin(pluginName, registerapi.DevicePlugin, socketPath, supportedVersions...) + require.NoError(t, p.Serve("v1beta1", "v1beta2")) + + require.True(t, waitForEvent(t, exampleEventValidate, hdlr.EventChan(p.pluginName))) + require.True(t, waitForEvent(t, exampleEventRegister, hdlr.EventChan(p.pluginName))) + + require.True(t, waitForPluginRegistrationStatus(t, p.registrationStatus)) + + require.NoError(t, p.Stop()) + require.True(t, waitForEvent(t, exampleEventDeRegister, hdlr.EventChan(p.pluginName))) + } +} + +func TestPluginReRegistration(t *testing.T) { + defer cleanup(t) + + pluginName := fmt.Sprintf("example-plugin") + hdlr := NewExampleHandler(supportedVersions) + + w := newWatcherWithHandler(t, hdlr) + defer func() { require.NoError(t, w.Stop()) }() + + plugins := make([]*examplePlugin, 10) + + for i := 0; i < 10; i++ { + socketPath := fmt.Sprintf("%s/plugin-%d.sock", socketDir, i) + hdlr.AddPluginName(pluginName) + + p := NewTestExamplePlugin(pluginName, registerapi.DevicePlugin, socketPath, supportedVersions...) + require.NoError(t, p.Serve("v1beta1", "v1beta2")) + + require.True(t, waitForEvent(t, exampleEventValidate, hdlr.EventChan(p.pluginName))) + require.True(t, waitForEvent(t, exampleEventRegister, hdlr.EventChan(p.pluginName))) + + require.True(t, waitForPluginRegistrationStatus(t, p.registrationStatus)) + + plugins[i] = p + } + + plugins[len(plugins)-1].Stop() + require.True(t, waitForEvent(t, exampleEventDeRegister, hdlr.EventChan(pluginName))) + + close(hdlr.EventChan(pluginName)) + for i := 0; i < len(plugins)-1; i++ { + plugins[i].Stop() + } +} + +func TestPluginRegistrationAtKubeletStart(t *testing.T) { + defer cleanup(t) + + hdlr := NewExampleHandler(supportedVersions) + plugins := make([]*examplePlugin, 10) + + for i := 0; i < len(plugins); i++ { + socketPath := fmt.Sprintf("%s/plugin-%d.sock", socketDir, i) + pluginName := fmt.Sprintf("example-plugin-%d", i) + hdlr.AddPluginName(pluginName) + + p := NewTestExamplePlugin(pluginName, registerapi.DevicePlugin, socketPath, supportedVersions...) + require.NoError(t, p.Serve("v1beta1", "v1beta2")) + defer func(p *examplePlugin) { require.NoError(t, p.Stop()) }(p) + + plugins[i] = p + } + + w := newWatcherWithHandler(t, hdlr) + defer func() { require.NoError(t, w.Stop()) }() + + var wg sync.WaitGroup + for i := 0; i < len(plugins); i++ { + wg.Add(1) + go func(p *examplePlugin) { + defer wg.Done() + + require.True(t, waitForEvent(t, exampleEventValidate, hdlr.EventChan(p.pluginName))) + require.True(t, waitForEvent(t, exampleEventRegister, hdlr.EventChan(p.pluginName))) + + require.True(t, waitForPluginRegistrationStatus(t, p.registrationStatus)) + }(plugins[i]) + } + c := make(chan struct{}) go func() { defer close(c) wg.Wait() }() + select { case <-c: - return false // completed normally - case <-time.After(timeout): - return true // timed out - } -} - -func TestExamplePlugin(t *testing.T) { - rootDir, err := ioutil.TempDir("", "plugin_test") - require.NoError(t, err) - w := NewWatcher(rootDir) - h := NewExampleHandler() - w.AddHandler(registerapi.DevicePlugin, h.Handler) - - require.NoError(t, w.Start()) - - socketPath := filepath.Join(rootDir, "plugin.sock") - PluginName := "example-plugin" - - // handler expecting plugin has a non-empty endpoint - p := NewTestExamplePlugin(PluginName, registerapi.DevicePlugin, "") - require.NoError(t, p.Serve(socketPath)) - require.False(t, waitForPluginRegistrationStatus(t, p.registrationStatus)) - require.NoError(t, p.Stop()) - - p = NewTestExamplePlugin(PluginName, registerapi.DevicePlugin, "dummyEndpoint") - require.NoError(t, p.Serve(socketPath)) - require.True(t, waitForPluginRegistrationStatus(t, p.registrationStatus)) - - // Trying to start a plugin service at the same socket path should fail - // with "bind: address already in use" - require.NotNil(t, p.Serve(socketPath)) - - // grpcServer.Stop() will remove the socket and starting plugin service - // at the same path again should succeeds and trigger another callback. - require.NoError(t, p.Stop()) - require.Nil(t, p.Serve(socketPath)) - require.False(t, waitForPluginRegistrationStatus(t, p.registrationStatus)) - - // Starting another plugin with the same name got verification error. - p2 := NewTestExamplePlugin(PluginName, registerapi.DevicePlugin, "dummyEndpoint") - socketPath2 := filepath.Join(rootDir, "plugin2.sock") - require.NoError(t, p2.Serve(socketPath2)) - require.False(t, waitForPluginRegistrationStatus(t, p2.registrationStatus)) - - // Restarts plugin watcher should traverse the socket directory and issues a - // callback for every existing socket. - require.NoError(t, w.Stop()) - require.NoError(t, h.Cleanup()) - require.NoError(t, w.Start()) - - var wg sync.WaitGroup - wg.Add(2) - var pStatus string - var p2Status string - go func() { - pStatus = strconv.FormatBool(waitForPluginRegistrationStatus(t, p.registrationStatus)) - wg.Done() - }() - go func() { - p2Status = strconv.FormatBool(waitForPluginRegistrationStatus(t, p2.registrationStatus)) - wg.Done() - }() - - if waitTimeout(&wg, 2*time.Second) { - t.Fatalf("Timed out waiting for wait group") - } - - expectedSet := sets.NewString() - expectedSet.Insert("true", "false") - actualSet := sets.NewString() - actualSet.Insert(pStatus, p2Status) - - require.Equal(t, expectedSet, actualSet) - - select { - case err := <-h.chanForHandlerAckErrors: - t.Fatalf("%v", err) + return case <-time.After(2 * time.Second): + t.Fatalf("Timeout while waiting for the plugin registration status") } - - require.NoError(t, w.Stop()) - require.NoError(t, w.Cleanup()) } -func TestPluginWithSubDir(t *testing.T) { - rootDir, err := ioutil.TempDir("", "plugin_test") - require.NoError(t, err) +func TestPluginRegistrationFailureWithUnsupportedVersion(t *testing.T) { + defer cleanup(t) - w := NewWatcher(rootDir) - hcsi := NewExampleHandler() - hdp := NewExampleHandler() + pluginName := fmt.Sprintf("example-plugin") + socketPath := socketDir + "/plugin.sock" - w.AddHandler(registerapi.CSIPlugin, hcsi.Handler) - w.AddHandler(registerapi.DevicePlugin, hdp.Handler) + hdlr := NewExampleHandler(supportedVersions) + hdlr.AddPluginName(pluginName) - err = w.fs.MkdirAll(filepath.Join(rootDir, registerapi.DevicePlugin), 0755) - require.NoError(t, err) - err = w.fs.MkdirAll(filepath.Join(rootDir, registerapi.CSIPlugin), 0755) - require.NoError(t, err) + w := newWatcherWithHandler(t, hdlr) + defer func() { require.NoError(t, w.Stop()) }() - dpSocketPath := filepath.Join(rootDir, registerapi.DevicePlugin, "plugin.sock") - csiSocketPath := filepath.Join(rootDir, registerapi.CSIPlugin, "plugin.sock") + // Advertise v1beta3 but don't serve anything else than the plugin service + p := NewTestExamplePlugin(pluginName, registerapi.DevicePlugin, socketPath, "v1beta3") + require.NoError(t, p.Serve()) + defer func() { require.NoError(t, p.Stop()) }() - require.NoError(t, w.Start()) - - // two plugins using the same name but with different type - dp := NewTestExamplePlugin("exampleplugin", registerapi.DevicePlugin, "example-endpoint") - require.NoError(t, dp.Serve(dpSocketPath)) - require.True(t, waitForPluginRegistrationStatus(t, dp.registrationStatus)) - - csi := NewTestExamplePlugin("exampleplugin", registerapi.CSIPlugin, "example-endpoint") - require.NoError(t, csi.Serve(csiSocketPath)) - require.True(t, waitForPluginRegistrationStatus(t, csi.registrationStatus)) - - // Restarts plugin watcher should traverse the socket directory and issues a - // callback for every existing socket. - require.NoError(t, w.Stop()) - require.NoError(t, hcsi.Cleanup()) - require.NoError(t, hdp.Cleanup()) - require.NoError(t, w.Start()) - - var wg sync.WaitGroup - wg.Add(2) - var dpStatus string - var csiStatus string - go func() { - dpStatus = strconv.FormatBool(waitForPluginRegistrationStatus(t, dp.registrationStatus)) - wg.Done() - }() - go func() { - csiStatus = strconv.FormatBool(waitForPluginRegistrationStatus(t, csi.registrationStatus)) - wg.Done() - }() - - if waitTimeout(&wg, 4*time.Second) { - require.NoError(t, errors.New("Timed out waiting for wait group")) - } - - expectedSet := sets.NewString() - expectedSet.Insert("true", "true") - actualSet := sets.NewString() - actualSet.Insert(dpStatus, csiStatus) - - require.Equal(t, expectedSet, actualSet) - - select { - case err := <-hcsi.chanForHandlerAckErrors: - t.Fatalf("%v", err) - case err := <-hdp.chanForHandlerAckErrors: - t.Fatalf("%v", err) - case <-time.After(4 * time.Second): - } - - require.NoError(t, w.Stop()) - require.NoError(t, w.Cleanup()) + require.True(t, waitForEvent(t, exampleEventValidate, hdlr.EventChan(p.pluginName))) + require.False(t, waitForPluginRegistrationStatus(t, p.registrationStatus)) } -func waitForPluginRegistrationStatus(t *testing.T, statusCh chan registerapi.RegistrationStatus) bool { +func TestPlugiRegistrationFailureWithUnsupportedVersionAtKubeletStart(t *testing.T) { + defer cleanup(t) + + pluginName := fmt.Sprintf("example-plugin") + socketPath := socketDir + "/plugin.sock" + + // Advertise v1beta3 but don't serve anything else than the plugin service + p := NewTestExamplePlugin(pluginName, registerapi.DevicePlugin, socketPath, "v1beta3") + require.NoError(t, p.Serve()) + defer func() { require.NoError(t, p.Stop()) }() + + hdlr := NewExampleHandler(supportedVersions) + hdlr.AddPluginName(pluginName) + + w := newWatcherWithHandler(t, hdlr) + defer func() { require.NoError(t, w.Stop()) }() + + require.True(t, waitForEvent(t, exampleEventValidate, hdlr.EventChan(p.pluginName))) + require.False(t, waitForPluginRegistrationStatus(t, p.registrationStatus)) +} + +func waitForPluginRegistrationStatus(t *testing.T, statusChan chan registerapi.RegistrationStatus) bool { select { - case status := <-statusCh: + case status := <-statusChan: return status.PluginRegistered case <-time.After(10 * time.Second): t.Fatalf("Timed out while waiting for registration status") } return false } + +func waitForEvent(t *testing.T, expected examplePluginEvent, eventChan chan examplePluginEvent) bool { + select { + case event := <-eventChan: + return event == expected + case <-time.After(2 * time.Second): + t.Fatalf("Timed out while waiting for registration status %v", expected) + } + + return false +} + +func newWatcherWithHandler(t *testing.T, hdlr PluginHandler) *Watcher { + w := NewWatcher(socketDir) + + w.AddHandler(registerapi.DevicePlugin, hdlr) + require.NoError(t, w.Start()) + + return w +} diff --git a/pkg/kubelet/util/pluginwatcher/types.go b/pkg/kubelet/util/pluginwatcher/types.go new file mode 100644 index 00000000000..f37ed241db3 --- /dev/null +++ b/pkg/kubelet/util/pluginwatcher/types.go @@ -0,0 +1,59 @@ +/* +Copyright 2018 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package pluginwatcher + +// PluginHandler is an interface a client of the pluginwatcher API needs to implement in +// order to consume plugins +// The PluginHandler follows the simple following state machine: +// +// +--------------------------------------+ +// | ReRegistration | +// | Socket created with same plugin name | +// | | +// | | +// Socket Created v + Socket Deleted +// +------------------> Validate +---------------------------> Register +------------------> DeRegister +// + + + +// | | | +// | Error | Error | +// | | | +// v v v +// Out Out Out +// +// The pluginwatcher module follows strictly and sequentially this state machine for each *plugin name*. +// e.g: If you are Registering a plugin foo, you cannot get a DeRegister call for plugin foo +// until the Register("foo") call returns. Nor will you get a Validate("foo", "Different endpoint", ...) +// call until the Register("foo") call returns. +// +// ReRegistration: Socket created with same plugin name, usually for a plugin update +// e.g: plugin with name foo registers at foo.com/foo-1.9.7 later a plugin with name foo +// registers at foo.com/foo-1.9.9 +// +// DeRegistration: When ReRegistration happens only the deletion of the new socket will trigger a DeRegister call + +type PluginHandler interface { + // Validate returns an error if the information provided by + // the potential plugin is erroneous (unsupported version, ...) + ValidatePlugin(pluginName string, endpoint string, versions []string) error + // RegisterPlugin is called so that the plugin can be register by any + // plugin consumer + // Error encountered here can still be Notified to the plugin. + RegisterPlugin(pluginName, endpoint string) error + // DeRegister is called once the pluginwatcher observes that the socket has + // been deleted. + DeRegisterPlugin(pluginName string) +} diff --git a/pkg/volume/csi/csi_plugin.go b/pkg/volume/csi/csi_plugin.go index 258362192d1..8eefb4e69d1 100644 --- a/pkg/volume/csi/csi_plugin.go +++ b/pkg/volume/csi/csi_plugin.go @@ -28,6 +28,7 @@ import ( "context" "github.com/golang/glog" + api "k8s.io/api/core/v1" apierrs "k8s.io/apimachinery/pkg/api/errors" meta "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -89,6 +90,10 @@ type csiDriversStore struct { sync.RWMutex } +// RegistrationHandler is the handler which is fed to the pluginwatcher API. +type RegistrationHandler struct { +} + // TODO (verult) consider using a struct instead of global variables // csiDrivers map keep track of all registered CSI drivers on the node and their // corresponding sockets @@ -96,21 +101,28 @@ var csiDrivers csiDriversStore var nodeUpdater nodeupdater.Interface -// RegistrationCallback is called by kubelet's plugin watcher upon detection +// PluginHandler is the plugin registration handler interface passed to the +// pluginwatcher module in kubelet +var PluginHandler = &RegistrationHandler{} + +// ValidatePlugin is called by kubelet's plugin watcher upon detection // of a new registration socket opened by CSI Driver registrar side car. -func RegistrationCallback(pluginName string, endpoint string, versions []string, socketPath string) (chan bool, error) { +func (h *RegistrationHandler) ValidatePlugin(pluginName string, endpoint string, versions []string) error { + glog.Infof(log("Trying to register a new plugin with name: %s endpoint: %s versions: %s", + pluginName, endpoint, strings.Join(versions, ","))) - glog.Infof(log("Callback from kubelet with plugin name: %s endpoint: %s versions: %s socket path: %s", - pluginName, endpoint, strings.Join(versions, ","), socketPath)) + return nil +} - if endpoint == "" { - endpoint = socketPath - } +// RegisterPlugin is called when a plugin can be registered +func (h *RegistrationHandler) RegisterPlugin(pluginName string, endpoint string) error { + glog.Infof(log("Register new plugin with name: %s at endpoint: %s", pluginName, endpoint)) // Storing endpoint of newly registered CSI driver into the map, where CSI driver name will be the key // all other CSI components will be able to get the actual socket of CSI drivers by its name. csiDrivers.Lock() defer csiDrivers.Unlock() + csiDrivers.driversMap[pluginName] = csiDriver{driverName: pluginName, driverEndpoint: endpoint} // Get node info from the driver. @@ -118,22 +130,27 @@ func RegistrationCallback(pluginName string, endpoint string, versions []string, // TODO (verult) retry with exponential backoff, possibly added in csi client library. ctx, cancel := context.WithTimeout(context.Background(), csiTimeout) defer cancel() + driverNodeID, maxVolumePerNode, _, err := csi.NodeGetInfo(ctx) if err != nil { - return nil, fmt.Errorf("error during CSI NodeGetInfo() call: %v", err) + return fmt.Errorf("error during CSI NodeGetInfo() call: %v", err) } // Calling nodeLabelManager to update annotations and labels for newly registered CSI driver err = nodeUpdater.AddLabelsAndLimits(pluginName, driverNodeID, maxVolumePerNode) if err != nil { // Unregister the driver and return error - csiDrivers.Lock() - defer csiDrivers.Unlock() delete(csiDrivers.driversMap, pluginName) - return nil, err + return fmt.Errorf("error while adding CSI labels: %v", err) } - return nil, nil + return nil +} + +// DeRegisterPlugin is called when a plugin removed it's socket, signaling +// it is no longer available +// TODO: Handle DeRegistration +func (h *RegistrationHandler) DeRegisterPlugin(pluginName string) { } func (p *csiPlugin) Init(host volume.VolumeHost) error {