diff --git a/hack/.golint_failures b/hack/.golint_failures index 7bbfcddb20f..833a513222f 100644 --- a/hack/.golint_failures +++ b/hack/.golint_failures @@ -383,7 +383,6 @@ pkg/volume/azure_dd pkg/volume/azure_file pkg/volume/cephfs pkg/volume/configmap -pkg/volume/csi pkg/volume/csi/fake pkg/volume/csi/labelmanager pkg/volume/empty_dir diff --git a/pkg/kubelet/util/pluginwatcher/BUILD b/pkg/kubelet/util/pluginwatcher/BUILD index 0b62c3a4658..7b887b444f5 100644 --- a/pkg/kubelet/util/pluginwatcher/BUILD +++ b/pkg/kubelet/util/pluginwatcher/BUILD @@ -9,6 +9,7 @@ load( go_library( name = "go_default_library", srcs = [ + "example_handler.go", "example_plugin.go", "plugin_watcher.go", ], @@ -20,6 +21,7 @@ go_library( "//pkg/util/filesystem:go_default_library", "//vendor/github.com/fsnotify/fsnotify:go_default_library", "//vendor/github.com/golang/glog:go_default_library", + "//vendor/github.com/pkg/errors:go_default_library", "//vendor/golang.org/x/net/context:go_default_library", "//vendor/google.golang.org/grpc:go_default_library", ], @@ -49,10 +51,7 @@ go_test( embed = [":go_default_library"], deps = [ "//pkg/kubelet/apis/pluginregistration/v1alpha1:go_default_library", - "//pkg/kubelet/util/pluginwatcher/example_plugin_apis/v1beta1:go_default_library", - "//pkg/kubelet/util/pluginwatcher/example_plugin_apis/v1beta2:go_default_library", "//staging/src/k8s.io/apimachinery/pkg/util/sets:go_default_library", "//vendor/github.com/stretchr/testify/require:go_default_library", - "//vendor/golang.org/x/net/context:go_default_library", ], ) diff --git a/pkg/kubelet/util/pluginwatcher/README b/pkg/kubelet/util/pluginwatcher/README index 9654b2cf62a..c8b6cc28440 100644 --- a/pkg/kubelet/util/pluginwatcher/README +++ b/pkg/kubelet/util/pluginwatcher/README @@ -13,17 +13,22 @@ communication with any API version supported by the plugin. Here are the general rules that Kubelet plugin developers should follow: - Run as 'root' user. Currently creating socket under PluginsSockDir, a root owned directory, requires plugin process to be running as 'root'. + - Implements the Registration service specified in pkg/kubelet/apis/pluginregistration/v*/api.proto. + - The plugin name sent during Registration.GetInfo grpc should be unique for the given plugin type (CSIPlugin or DevicePlugin). -- The socket path needs to be unique and doesn't conflict with the path chosen - by any other potential plugins. Currently we only support flat fs namespace - under PluginsSockDir but will soon support recursive inotify watch for - hierarchical socket paths. + +- The socket path needs to be unique within one directory, in normal case, + each plugin type has its own sub directory, but the design does support socket file + under any sub directory of PluginSockDir. + - A plugin should clean up its own socket upon exiting or when a new instance comes up. A plugin should NOT remove any sockets belonging to other plugins. + - A plugin should make sure it has service ready for any supported service API version listed in the PluginInfo. + - For an example plugin implementation, take a look at example_plugin.go included in this directory. diff --git a/pkg/kubelet/util/pluginwatcher/example_handler.go b/pkg/kubelet/util/pluginwatcher/example_handler.go new file mode 100644 index 00000000000..4eae4188d69 --- /dev/null +++ b/pkg/kubelet/util/pluginwatcher/example_handler.go @@ -0,0 +1,105 @@ +/* +Copyright 2018 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package pluginwatcher + +import ( + "errors" + "fmt" + "reflect" + "sync" + "time" + + "golang.org/x/net/context" + + v1beta1 "k8s.io/kubernetes/pkg/kubelet/util/pluginwatcher/example_plugin_apis/v1beta1" + v1beta2 "k8s.io/kubernetes/pkg/kubelet/util/pluginwatcher/example_plugin_apis/v1beta2" +) + +type exampleHandler struct { + registeredPlugins map[string]struct{} + mutex sync.Mutex + chanForHandlerAckErrors chan error // for testing +} + +// NewExampleHandler provide a example handler +func NewExampleHandler() *exampleHandler { + return &exampleHandler{ + chanForHandlerAckErrors: make(chan error), + registeredPlugins: make(map[string]struct{}), + } +} + +func (h *exampleHandler) Cleanup() error { + h.mutex.Lock() + defer h.mutex.Unlock() + h.registeredPlugins = make(map[string]struct{}) + return nil +} + +func (h *exampleHandler) Handler(pluginName string, endpoint string, versions []string, sockPath string) (chan bool, error) { + + // check for supported versions + if !reflect.DeepEqual([]string{"v1beta1", "v1beta2"}, versions) { + return nil, fmt.Errorf("not the supported versions: %s", versions) + } + + // this handler expects non-empty endpoint as an example + if len(endpoint) == 0 { + return nil, errors.New("expecting non empty endpoint") + } + + _, conn, err := dial(sockPath) + if err != nil { + return nil, err + } + defer conn.Close() + + // The plugin handler should be able to use any listed service API version. + v1beta1Client := v1beta1.NewExampleClient(conn) + v1beta2Client := v1beta2.NewExampleClient(conn) + + // Tests v1beta1 GetExampleInfo + if _, err = v1beta1Client.GetExampleInfo(context.Background(), &v1beta1.ExampleRequest{}); err != nil { + return nil, err + } + + // Tests v1beta2 GetExampleInfo + if _, err = v1beta2Client.GetExampleInfo(context.Background(), &v1beta2.ExampleRequest{}); err != nil { + return nil, err + } + + // handle registered plugin + h.mutex.Lock() + if _, exist := h.registeredPlugins[pluginName]; exist { + h.mutex.Unlock() + return nil, fmt.Errorf("plugin %s already registered", pluginName) + } + h.registeredPlugins[pluginName] = struct{}{} + h.mutex.Unlock() + + chanForAckOfNotification := make(chan bool) + go func() { + select { + case <-chanForAckOfNotification: + // TODO: handle the negative scenario + close(chanForAckOfNotification) + case <-time.After(time.Second): + h.chanForHandlerAckErrors <- errors.New("Timed out while waiting for notification ack") + } + }() + return chanForAckOfNotification, nil +} diff --git a/pkg/kubelet/util/pluginwatcher/example_plugin.go b/pkg/kubelet/util/pluginwatcher/example_plugin.go index fbca43acad5..5c2dd966ba4 100644 --- a/pkg/kubelet/util/pluginwatcher/example_plugin.go +++ b/pkg/kubelet/util/pluginwatcher/example_plugin.go @@ -17,7 +17,7 @@ limitations under the License. package pluginwatcher import ( - "fmt" + "errors" "net" "sync" "time" @@ -31,17 +31,14 @@ import ( v1beta2 "k8s.io/kubernetes/pkg/kubelet/util/pluginwatcher/example_plugin_apis/v1beta2" ) -const ( - PluginName = "example-plugin" - PluginType = "example-plugin-type" -) - // examplePlugin is a sample plugin to work with plugin watcher type examplePlugin struct { grpcServer *grpc.Server wg sync.WaitGroup registrationStatus chan registerapi.RegistrationStatus // for testing endpoint string // for testing + pluginName string + pluginType string } type pluginServiceV1Beta1 struct { @@ -76,8 +73,10 @@ func NewExamplePlugin() *examplePlugin { } // NewTestExamplePlugin returns an initialized examplePlugin instance for testing -func NewTestExamplePlugin(endpoint string) *examplePlugin { +func NewTestExamplePlugin(pluginName string, pluginType string, endpoint string) *examplePlugin { return &examplePlugin{ + pluginName: pluginName, + pluginType: pluginType, registrationStatus: make(chan registerapi.RegistrationStatus), endpoint: endpoint, } @@ -86,8 +85,8 @@ func NewTestExamplePlugin(endpoint string) *examplePlugin { // GetInfo is the RPC invoked by plugin watcher func (e *examplePlugin) GetInfo(ctx context.Context, req *registerapi.InfoRequest) (*registerapi.PluginInfo, error) { return ®isterapi.PluginInfo{ - Type: PluginType, - Name: PluginName, + Type: e.pluginType, + Name: e.pluginName, Endpoint: e.endpoint, SupportedVersions: []string{"v1beta1", "v1beta2"}, }, nil @@ -145,6 +144,6 @@ func (e *examplePlugin) Stop() error { return nil case <-time.After(time.Second): glog.Errorf("Timed out on waiting for stop completion") - return fmt.Errorf("Timed out on waiting for stop completion") + return errors.New("Timed out on waiting for stop completion") } } diff --git a/pkg/kubelet/util/pluginwatcher/plugin_watcher.go b/pkg/kubelet/util/pluginwatcher/plugin_watcher.go index 9a5241cb2e5..6db743dd4fb 100644 --- a/pkg/kubelet/util/pluginwatcher/plugin_watcher.go +++ b/pkg/kubelet/util/pluginwatcher/plugin_watcher.go @@ -20,13 +20,12 @@ import ( "fmt" "net" "os" - "path" - "path/filepath" "sync" "time" "github.com/fsnotify/fsnotify" "github.com/golang/glog" + "github.com/pkg/errors" "golang.org/x/net/context" "google.golang.org/grpc" registerapi "k8s.io/kubernetes/pkg/kubelet/apis/pluginregistration/v1alpha1" @@ -34,17 +33,17 @@ import ( ) // RegisterCallbackFn is the type of the callback function that handlers will provide -type RegisterCallbackFn func(pluginName string, endpoint string, versions []string, socketPath string) (error, chan bool) +type RegisterCallbackFn func(pluginName string, endpoint string, versions []string, socketPath string) (chan bool, error) // Watcher is the plugin watcher type Watcher struct { - path string - handlers map[string]RegisterCallbackFn - stopCh chan interface{} - fs utilfs.Filesystem - watcher *fsnotify.Watcher - wg sync.WaitGroup - mutex sync.Mutex + path string + handlers map[string]RegisterCallbackFn + stopCh chan interface{} + fs utilfs.Filesystem + fsWatcher *fsnotify.Watcher + wg sync.WaitGroup + mutex sync.Mutex } // NewWatcher provides a new watcher @@ -57,40 +56,45 @@ func NewWatcher(sockDir string) Watcher { } // AddHandler registers a callback to be invoked for a particular type of plugin -func (w *Watcher) AddHandler(handlerType string, handlerCbkFn RegisterCallbackFn) { +func (w *Watcher) AddHandler(pluginType string, handlerCbkFn RegisterCallbackFn) { w.mutex.Lock() defer w.mutex.Unlock() - w.handlers[handlerType] = handlerCbkFn + w.handlers[pluginType] = handlerCbkFn } // Creates the plugin directory, if it doesn't already exist. func (w *Watcher) createPluginDir() error { glog.V(4).Infof("Ensuring Plugin directory at %s ", w.path) if err := w.fs.MkdirAll(w.path, 0755); err != nil { - return fmt.Errorf("error (re-)creating driver directory: %s", err) + return fmt.Errorf("error (re-)creating root %s: %v", w.path, err) } + return nil } -// Walks through the plugin directory to discover any existing plugin sockets. -func (w *Watcher) traversePluginDir() error { - files, err := w.fs.ReadDir(w.path) - if err != nil { - return fmt.Errorf("error reading the plugin directory: %v", err) - } - for _, f := range files { - // Currently only supports flat fs namespace under the plugin directory. - // TODO: adds support for hierarchical fs namespace. - if !f.IsDir() && filepath.Base(f.Name())[0] != '.' { - go func(sockName string) { - w.watcher.Events <- fsnotify.Event{ - Name: sockName, - Op: fsnotify.Op(uint32(1)), - } - }(path.Join(w.path, f.Name())) +// Walks through the plugin directory discover any existing plugin sockets. +func (w *Watcher) traversePluginDir(dir string) error { + return w.fs.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return fmt.Errorf("error accessing path: %s error: %v", path, err) } - } - return nil + + switch mode := info.Mode(); { + case mode.IsDir(): + if err := w.fsWatcher.Add(path); err != nil { + return fmt.Errorf("failed to watch %s, err: %v", path, err) + } + case mode&os.ModeSocket != 0: + go func() { + w.fsWatcher.Events <- fsnotify.Event{ + Name: path, + Op: fsnotify.Create, + } + }() + } + + return nil + }) } func (w *Watcher) init() error { @@ -102,7 +106,6 @@ func (w *Watcher) init() error { func (w *Watcher) registerPlugin(socketPath string) error { //TODO: Implement rate limiting to mitigate any DOS kind of attacks. - glog.V(4).Infof("registerPlugin called for socketPath: %s", socketPath) client, conn, err := dial(socketPath) if err != nil { return fmt.Errorf("dial failed at socket %s, err: %v", socketPath, err) @@ -115,11 +118,8 @@ func (w *Watcher) registerPlugin(socketPath string) error { if err != nil { return fmt.Errorf("failed to get plugin info using RPC GetInfo at socket %s, err: %v", socketPath, err) } - if err := w.invokeRegistrationCallbackAtHandler(ctx, client, infoResp, socketPath); err != nil { - return fmt.Errorf("failed to register plugin. Callback handler returned err: %v", err) - } - glog.V(4).Infof("Successfully registered plugin for plugin type: %s, name: %s, socket: %s", infoResp.Type, infoResp.Name, socketPath) - return nil + + return w.invokeRegistrationCallbackAtHandler(ctx, client, infoResp, socketPath) } func (w *Watcher) invokeRegistrationCallbackAtHandler(ctx context.Context, client registerapi.RegistrationClient, infoResp *registerapi.PluginInfo, socketPath string) error { @@ -127,13 +127,14 @@ func (w *Watcher) invokeRegistrationCallbackAtHandler(ctx context.Context, clien var ok bool handlerCbkFn, ok = w.handlers[infoResp.Type] if !ok { + errStr := fmt.Sprintf("no handler registered for plugin type: %s at socket %s", infoResp.Type, socketPath) if _, err := client.NotifyRegistrationStatus(ctx, ®isterapi.RegistrationStatus{ PluginRegistered: false, - Error: fmt.Sprintf("No handler found registered for plugin type: %s, socket: %s", infoResp.Type, socketPath), + Error: errStr, }); err != nil { - glog.Errorf("Failed to send registration status at socket %s, err: %v", socketPath, err) + return errors.Wrap(err, errStr) } - return fmt.Errorf("no handler found registered for plugin type: %s, socket: %s", infoResp.Type, socketPath) + return errors.New(errStr) } var versions []string @@ -141,27 +142,51 @@ func (w *Watcher) invokeRegistrationCallbackAtHandler(ctx context.Context, clien versions = append(versions, version) } // calls handler callback to verify registration request - err, chanForAckOfNotification := handlerCbkFn(infoResp.Name, infoResp.Endpoint, versions, socketPath) + chanForAckOfNotification, err := handlerCbkFn(infoResp.Name, infoResp.Endpoint, versions, socketPath) if err != nil { + errStr := fmt.Sprintf("plugin registration failed with err: %v", err) if _, err := client.NotifyRegistrationStatus(ctx, ®isterapi.RegistrationStatus{ PluginRegistered: false, - Error: fmt.Sprintf("Plugin registration failed with err: %v", err), + Error: errStr, }); err != nil { - glog.Errorf("Failed to send registration status at socket %s, err: %v", socketPath, err) + return errors.Wrap(err, errStr) } - chanForAckOfNotification <- false - return fmt.Errorf("plugin registration failed with err: %v", err) + return errors.New(errStr) } if _, err := client.NotifyRegistrationStatus(ctx, ®isterapi.RegistrationStatus{ PluginRegistered: true, }); err != nil { + chanForAckOfNotification <- false return fmt.Errorf("failed to send registration status at socket %s, err: %v", socketPath, err) } + chanForAckOfNotification <- true return nil } +// Handle filesystem notify event. +func (w *Watcher) handleFsNotifyEvent(event fsnotify.Event) error { + if event.Op&fsnotify.Create != fsnotify.Create { + return nil + } + + fi, err := os.Stat(event.Name) + if err != nil { + return fmt.Errorf("stat file %s failed: %v", event.Name, err) + } + + if !fi.IsDir() { + return w.registerPlugin(event.Name) + } + + if err := w.traversePluginDir(event.Name); err != nil { + return fmt.Errorf("failed to traverse plugin path %s, err: %v", event.Name, err) + } + + return nil +} + // Start watches for the creation of plugin sockets at the path func (w *Watcher) Start() error { glog.V(2).Infof("Plugin Watcher Start at %s", w.path) @@ -173,52 +198,42 @@ func (w *Watcher) Start() error { return err } - watcher, err := fsnotify.NewWatcher() + fsWatcher, err := fsnotify.NewWatcher() if err != nil { - return fmt.Errorf("failed to start plugin watcher, err: %v", err) + return fmt.Errorf("failed to start plugin fsWatcher, err: %v", err) } + w.fsWatcher = fsWatcher - if err := watcher.Add(w.path); err != nil { - watcher.Close() - return fmt.Errorf("failed to start plugin watcher, err: %v", err) - } - - w.watcher = watcher - - if err := w.traversePluginDir(); err != nil { - watcher.Close() + if err := w.traversePluginDir(w.path); err != nil { + fsWatcher.Close() return fmt.Errorf("failed to traverse plugin socket path, err: %v", err) } w.wg.Add(1) - go func(watcher *fsnotify.Watcher) { + go func(fsWatcher *fsnotify.Watcher) { defer w.wg.Done() for { select { - case event := <-watcher.Events: - if event.Op&fsnotify.Create == fsnotify.Create { - go func(eventName string) { - err := w.registerPlugin(eventName) - if err != nil { - glog.Errorf("Plugin %s registration failed with error: %v", eventName, err) - } - }(event.Name) - } - continue - case err := <-watcher.Errors: + case event := <-fsWatcher.Events: //TODO: Handle errors by taking corrective measures + go func() { + err := w.handleFsNotifyEvent(event) + if err != nil { + glog.Errorf("error %v when handle event: %s", err, event) + } + }() + continue + case err := <-fsWatcher.Errors: if err != nil { - glog.Errorf("Watcher received error: %v", err) + glog.Errorf("fsWatcher received error: %v", err) } continue - case <-w.stopCh: - watcher.Close() - break + fsWatcher.Close() + return } - break } - }(watcher) + }(fsWatcher) return nil } diff --git a/pkg/kubelet/util/pluginwatcher/plugin_watcher_test.go b/pkg/kubelet/util/pluginwatcher/plugin_watcher_test.go index 44bccf9a6f3..5bfb49568e6 100644 --- a/pkg/kubelet/util/pluginwatcher/plugin_watcher_test.go +++ b/pkg/kubelet/util/pluginwatcher/plugin_watcher_test.go @@ -17,135 +17,56 @@ limitations under the License. package pluginwatcher import ( - "fmt" + "errors" "io/ioutil" + "path/filepath" "strconv" "sync" "testing" "time" "github.com/stretchr/testify/require" - "golang.org/x/net/context" "k8s.io/apimachinery/pkg/util/sets" registerapi "k8s.io/kubernetes/pkg/kubelet/apis/pluginregistration/v1alpha1" - v1beta1 "k8s.io/kubernetes/pkg/kubelet/util/pluginwatcher/example_plugin_apis/v1beta1" - v1beta2 "k8s.io/kubernetes/pkg/kubelet/util/pluginwatcher/example_plugin_apis/v1beta2" ) -func TestExamplePlugin(t *testing.T) { - socketDir, err := ioutil.TempDir("", "plugin_test") - require.NoError(t, err) - socketPath := socketDir + "/plugin.sock" - w := NewWatcher(socketDir) - - testCases := []struct { - description string - expectedEndpoint string - returnErr error - }{ - { - description: "Successfully register plugin through inotify", - expectedEndpoint: "", - returnErr: nil, - }, - { - description: "Successfully register plugin through inotify and got expected optional endpoint", - expectedEndpoint: "dummyEndpoint", - returnErr: nil, - }, - { - description: "Fails registration because endpoint is expected to be non-empty", - expectedEndpoint: "dummyEndpoint", - returnErr: fmt.Errorf("empty endpoint received"), - }, - { - description: "Successfully register plugin through inotify after plugin restarts", - expectedEndpoint: "", - returnErr: nil, - }, - { - description: "Fails registration with conflicting plugin name", - expectedEndpoint: "", - returnErr: fmt.Errorf("conflicting plugin name"), - }, - { - description: "Successfully register plugin during initial traverse after plugin watcher restarts", - expectedEndpoint: "", - returnErr: nil, - }, - { - description: "Fails registration with conflicting plugin name during initial traverse after plugin watcher restarts", - expectedEndpoint: "", - returnErr: fmt.Errorf("conflicting plugin name"), - }, +// helper function +func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { + c := make(chan struct{}) + go func() { + defer close(c) + wg.Wait() + }() + select { + case <-c: + return false // completed normally + case <-time.After(timeout): + return true // timed out } +} - callbackCount := struct { - mutex sync.Mutex - count int32 - }{} - w.AddHandler(PluginType, func(name string, endpoint string, versions []string, sockPath string) (error, chan bool) { - callbackCount.mutex.Lock() - localCount := callbackCount.count - callbackCount.count = callbackCount.count + 1 - callbackCount.mutex.Unlock() +func TestExamplePlugin(t *testing.T) { + rootDir, err := ioutil.TempDir("", "plugin_test") + require.NoError(t, err) + w := NewWatcher(rootDir) + h := NewExampleHandler() + w.AddHandler(registerapi.DevicePlugin, h.Handler) - require.True(t, localCount <= int32((len(testCases)-1))) - require.Equal(t, PluginName, name, "Plugin name mismatched!!") - retError := testCases[localCount].returnErr - if retError == nil || retError.Error() != "empty endpoint received" { - require.Equal(t, testCases[localCount].expectedEndpoint, endpoint, "Unexpected endpoint") - } else { - require.NotEqual(t, testCases[localCount].expectedEndpoint, endpoint, "Unexpected endpoint") - } - - require.Equal(t, []string{"v1beta1", "v1beta2"}, versions, "Plugin version mismatched!!") - // Verifies the grpcServer is ready to serve services. - _, conn, err := dial(sockPath) - require.Nil(t, err) - defer conn.Close() - - // The plugin handler should be able to use any listed service API version. - v1beta1Client := v1beta1.NewExampleClient(conn) - v1beta2Client := v1beta2.NewExampleClient(conn) - - // Tests v1beta1 GetExampleInfo - _, err = v1beta1Client.GetExampleInfo(context.Background(), &v1beta1.ExampleRequest{}) - require.Nil(t, err) - - // Tests v1beta1 GetExampleInfo - _, err = v1beta2Client.GetExampleInfo(context.Background(), &v1beta2.ExampleRequest{}) - //atomic.AddInt32(&callbackCount, 1) - chanForAckOfNotification := make(chan bool) - - go func() { - select { - case <-chanForAckOfNotification: - close(chanForAckOfNotification) - case <-time.After(time.Second): - t.Fatalf("Timed out while waiting for notification ack") - } - }() - return retError, chanForAckOfNotification - }) require.NoError(t, w.Start()) - p := NewTestExamplePlugin("") - require.NoError(t, p.Serve(socketPath)) - require.True(t, waitForPluginRegistrationStatus(t, p.registrationStatus)) + socketPath := filepath.Join(rootDir, "plugin.sock") + PluginName := "example-plugin" - require.NoError(t, p.Stop()) - - p = NewTestExamplePlugin("dummyEndpoint") - require.NoError(t, p.Serve(socketPath)) - require.True(t, waitForPluginRegistrationStatus(t, p.registrationStatus)) - - require.NoError(t, p.Stop()) - - p = NewTestExamplePlugin("") + // handler expecting plugin has a non-empty endpoint + p := NewTestExamplePlugin(PluginName, registerapi.DevicePlugin, "") require.NoError(t, p.Serve(socketPath)) require.False(t, waitForPluginRegistrationStatus(t, p.registrationStatus)) + require.NoError(t, p.Stop()) + + p = NewTestExamplePlugin(PluginName, registerapi.DevicePlugin, "dummyEndpoint") + require.NoError(t, p.Serve(socketPath)) + require.True(t, waitForPluginRegistrationStatus(t, p.registrationStatus)) // Trying to start a plugin service at the same socket path should fail // with "bind: address already in use" @@ -154,27 +75,20 @@ func TestExamplePlugin(t *testing.T) { // grpcServer.Stop() will remove the socket and starting plugin service // at the same path again should succeeds and trigger another callback. require.NoError(t, p.Stop()) - p = NewTestExamplePlugin("") - go func() { - require.Nil(t, p.Serve(socketPath)) - }() - require.True(t, waitForPluginRegistrationStatus(t, p.registrationStatus)) + require.Nil(t, p.Serve(socketPath)) + require.False(t, waitForPluginRegistrationStatus(t, p.registrationStatus)) // Starting another plugin with the same name got verification error. - p2 := NewTestExamplePlugin("") - socketPath2 := socketDir + "/plugin2.sock" - go func() { - require.NoError(t, p2.Serve(socketPath2)) - }() + p2 := NewTestExamplePlugin(PluginName, registerapi.DevicePlugin, "dummyEndpoint") + socketPath2 := filepath.Join(rootDir, "plugin2.sock") + require.NoError(t, p2.Serve(socketPath2)) require.False(t, waitForPluginRegistrationStatus(t, p2.registrationStatus)) // Restarts plugin watcher should traverse the socket directory and issues a // callback for every existing socket. require.NoError(t, w.Stop()) - errCh := make(chan error) - go func() { - errCh <- w.Start() - }() + require.NoError(t, h.Cleanup()) + require.NoError(t, w.Start()) var wg sync.WaitGroup wg.Add(2) @@ -188,7 +102,11 @@ func TestExamplePlugin(t *testing.T) { p2Status = strconv.FormatBool(waitForPluginRegistrationStatus(t, p2.registrationStatus)) wg.Done() }() - wg.Wait() + + if waitTimeout(&wg, 2*time.Second) { + t.Fatalf("Timed out waiting for wait group") + } + expectedSet := sets.NewString() expectedSet.Insert("true", "false") actualSet := sets.NewString() @@ -197,16 +115,86 @@ func TestExamplePlugin(t *testing.T) { require.Equal(t, expectedSet, actualSet) select { - case err = <-errCh: - require.NoError(t, err) - case <-time.After(time.Second): - t.Fatalf("Timed out while waiting for watcher start") - + case err := <-h.chanForHandlerAckErrors: + t.Fatalf("%v", err) + case <-time.After(2 * time.Second): } require.NoError(t, w.Stop()) - err = w.Cleanup() + require.NoError(t, w.Cleanup()) +} + +func TestPluginWithSubDir(t *testing.T) { + rootDir, err := ioutil.TempDir("", "plugin_test") require.NoError(t, err) + + w := NewWatcher(rootDir) + hcsi := NewExampleHandler() + hdp := NewExampleHandler() + + w.AddHandler(registerapi.CSIPlugin, hcsi.Handler) + w.AddHandler(registerapi.DevicePlugin, hdp.Handler) + + err = w.fs.MkdirAll(filepath.Join(rootDir, registerapi.DevicePlugin), 0755) + require.NoError(t, err) + err = w.fs.MkdirAll(filepath.Join(rootDir, registerapi.CSIPlugin), 0755) + require.NoError(t, err) + + dpSocketPath := filepath.Join(rootDir, registerapi.DevicePlugin, "plugin.sock") + csiSocketPath := filepath.Join(rootDir, registerapi.CSIPlugin, "plugin.sock") + + require.NoError(t, w.Start()) + + // two plugins using the same name but with different type + dp := NewTestExamplePlugin("exampleplugin", registerapi.DevicePlugin, "example-endpoint") + require.NoError(t, dp.Serve(dpSocketPath)) + require.True(t, waitForPluginRegistrationStatus(t, dp.registrationStatus)) + + csi := NewTestExamplePlugin("exampleplugin", registerapi.CSIPlugin, "example-endpoint") + require.NoError(t, csi.Serve(csiSocketPath)) + require.True(t, waitForPluginRegistrationStatus(t, csi.registrationStatus)) + + // Restarts plugin watcher should traverse the socket directory and issues a + // callback for every existing socket. + require.NoError(t, w.Stop()) + require.NoError(t, hcsi.Cleanup()) + require.NoError(t, hdp.Cleanup()) + require.NoError(t, w.Start()) + + var wg sync.WaitGroup + wg.Add(2) + var dpStatus string + var csiStatus string + go func() { + dpStatus = strconv.FormatBool(waitForPluginRegistrationStatus(t, dp.registrationStatus)) + wg.Done() + }() + go func() { + csiStatus = strconv.FormatBool(waitForPluginRegistrationStatus(t, csi.registrationStatus)) + wg.Done() + }() + + if waitTimeout(&wg, 4*time.Second) { + require.NoError(t, errors.New("Timed out waiting for wait group")) + } + + expectedSet := sets.NewString() + expectedSet.Insert("true", "true") + actualSet := sets.NewString() + actualSet.Insert(dpStatus, csiStatus) + + require.Equal(t, expectedSet, actualSet) + + select { + case err := <-hcsi.chanForHandlerAckErrors: + t.Fatalf("%v", err) + case err := <-hdp.chanForHandlerAckErrors: + t.Fatalf("%v", err) + case <-time.After(4 * time.Second): + } + + require.NoError(t, w.Stop()) + require.NoError(t, w.Cleanup()) } func waitForPluginRegistrationStatus(t *testing.T, statusCh chan registerapi.RegistrationStatus) bool { diff --git a/pkg/volume/csi/csi_plugin.go b/pkg/volume/csi/csi_plugin.go index 85d513d7c42..60f04a2419b 100644 --- a/pkg/volume/csi/csi_plugin.go +++ b/pkg/volume/csi/csi_plugin.go @@ -84,7 +84,7 @@ var lm labelmanager.Interface // RegistrationCallback is called by kubelet's plugin watcher upon detection // of a new registration socket opened by CSI Driver registrar side car. -func RegistrationCallback(pluginName string, endpoint string, versions []string, socketPath string) (error, chan bool) { +func RegistrationCallback(pluginName string, endpoint string, versions []string, socketPath string) (chan bool, error) { glog.Infof(log("Callback from kubelet with plugin name: %s endpoint: %s versions: %s socket path: %s", pluginName, endpoint, strings.Join(versions, ","), socketPath)) @@ -95,7 +95,7 @@ func RegistrationCallback(pluginName string, endpoint string, versions []string, // Calling nodeLabelManager to update label for newly registered CSI driver err := lm.AddLabels(pluginName) if err != nil { - return err, nil + return nil, err } // Storing endpoint of newly registered CSI driver into the map, where CSI driver name will be the key // all other CSI components will be able to get the actual socket of CSI drivers by its name.