Merge pull request #64621 from RenaudWasTaken/pluginwatcher

Automatic merge from submit-queue (batch tested with PRs 68087, 68256, 64621, 68299, 68296). If you want to cherry-pick this change to another branch, please follow the instructions here: https://github.com/kubernetes/community/blob/master/contributors/devel/cherry-picks.md.

Change plugin watcher registration mechanism

**Which issue(s) this PR fixes**: #64637

**Notes For Reviewers**:
The current API the plugin watcher exposes to kubelet is the following:
```golang
type RegisterCallbackFn func(pluginName string, endpoint string,
                             versions []string, socketPath string) (chan bool, error)
```

The callback channel is here to signal the plugin watcher consumer when the plugin watcher API has notified the plugin of its successful registration.
In other words the current lifecycle of a plugin is the following:
```
(pluginwatcher) GetInfo -> (pluginwatcher) NotifyRegistrationStatus -> (deviceplugin) ListWatch
```
Rather than
```
(pluginwatcher) GetInfo (race) -> (pluginwatcher) NotifyRegistrationStatus
                        (race) -> (deviceplugin) ListWatch
```

This PR changes the callback/channel mechanism to a more explicit, interface-based contract (and more maintainable than a function to which we add more channels for more lifecycle events).

This PR also introduces three new states: {Init, Register, DeRegister}
```golang
// PluginHandler is an interface a client of the pluginwatcher API needs to implement in
// order to consume plugins
// The PluginHandler follows this simple state machine:
//
//                         +--------------------------------------+
//                         |            ReRegistration            |
//                         | Socket created with same plugin name |
//                         |                                      |
//                         |                                      |
//    Socket Created       v                                      +        Socket Deleted
// +------------------> Validate +----------> Init +---------> Register +------------------> DeRegister
//                         +                   +                                                +
//                         |                   |                                                |
//                         | Error             | Error                                          |
//                         |                   |                                                |
//                         v                   v                                                v
//                        Out                 Out                                              Out
//
// The pluginwatcher module strictly and sequentially follows this state machine for each *plugin name*.
// e.g: If you are Registering a plugin foo, you cannot get a DeRegister call for plugin foo
//      until the Register("foo") call returns. Nor will you get a Validate("foo", "Different endpoint", ...)
//      call until the Register("foo") call returns.
//
// ReRegistration: Socket created with same plugin name, usually for a plugin update
// e.g: plugin with name foo registers at foo.com/foo-1.9.7 later a plugin with name foo
//      registers at foo.com/foo-1.9.9
//
// DeRegistration: When ReRegistration happens only the deletion of the new socket will trigger a DeRegister call

type PluginHandler interface {
        // Validate returns an error if the information provided by
        // the potential plugin is erroneous (unsupported version, ...)
        ValidatePlugin(pluginName string, endpoint string, versions []string) error
        // Init starts the plugin (e.g: contact the gRPC client, gets plugin
        // specific information, ...) but if another plugin with the same name
        // exists does not switch to the newer one.
        // Any error encountered here can still be Notified to the plugin.
        InitPlugin(pluginName string, endpoint string) error
        // Register is called once the pluginwatcher has notified the plugin
        // of its successful registration.
        // Errors at this point can no longer be bubbled up to the plugin
        RegisterPlugin(pluginName, endpoint string)
        // DeRegister is called once the pluginwatcher observes that the socket has
        // been deleted.
        DeRegisterPlugin(pluginName string)
}
```

```release-note
NONE
```
/sig node
/area hw-accelerators

/cc @jiayingz @vikaschoudhary16 @vishh @vladimirvivien @sbezverk @figo (ccing the main reviewers of the original PR, feel free to cc more people)
This commit is contained in:
Kubernetes Submit Queue 2018-09-06 14:49:39 -07:00 committed by GitHub
commit 4da3bdc4eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 809 additions and 501 deletions

View File

@ -95,7 +95,11 @@ type ContainerManager interface {
// GetPodCgroupRoot returns the cgroup which contains all pods. // GetPodCgroupRoot returns the cgroup which contains all pods.
GetPodCgroupRoot() string GetPodCgroupRoot() string
GetPluginRegistrationHandlerCallback() pluginwatcher.RegisterCallbackFn
// GetPluginRegistrationHandler returns a plugin registration handler
// The pluginwatcher's Handlers allow to have a single module for handling
// registration.
GetPluginRegistrationHandler() pluginwatcher.PluginHandler
} }
type NodeConfig struct { type NodeConfig struct {

View File

@ -605,8 +605,8 @@ func (cm *containerManagerImpl) Start(node *v1.Node,
return nil return nil
} }
func (cm *containerManagerImpl) GetPluginRegistrationHandlerCallback() pluginwatcher.RegisterCallbackFn { func (cm *containerManagerImpl) GetPluginRegistrationHandler() pluginwatcher.PluginHandler {
return cm.deviceManager.GetWatcherCallback() return cm.deviceManager.GetWatcherHandler()
} }
// TODO: move the GetResources logic to PodContainerManager. // TODO: move the GetResources logic to PodContainerManager.

View File

@ -77,10 +77,8 @@ func (cm *containerManagerStub) GetCapacity() v1.ResourceList {
return c return c
} }
func (cm *containerManagerStub) GetPluginRegistrationHandlerCallback() pluginwatcher.RegisterCallbackFn { func (cm *containerManagerStub) GetPluginRegistrationHandler() pluginwatcher.PluginHandler {
return func(name string, endpoint string, versions []string, sockPath string) (chan bool, error) { return nil
return nil, nil
}
} }
func (cm *containerManagerStub) GetDevicePluginResourceCapacity() (v1.ResourceList, v1.ResourceList, []string) { func (cm *containerManagerStub) GetDevicePluginResourceCapacity() (v1.ResourceList, v1.ResourceList, []string) {

View File

@ -56,7 +56,7 @@ type ManagerImpl struct {
socketname string socketname string
socketdir string socketdir string
endpoints map[string]endpoint // Key is ResourceName endpoints map[string]endpointInfo // Key is ResourceName
mutex sync.Mutex mutex sync.Mutex
server *grpc.Server server *grpc.Server
@ -86,10 +86,14 @@ type ManagerImpl struct {
// podDevices contains pod to allocated device mapping. // podDevices contains pod to allocated device mapping.
podDevices podDevices podDevices podDevices
pluginOpts map[string]*pluginapi.DevicePluginOptions
checkpointManager checkpointmanager.CheckpointManager checkpointManager checkpointmanager.CheckpointManager
} }
type endpointInfo struct {
e endpoint
opts *pluginapi.DevicePluginOptions
}
type sourcesReadyStub struct{} type sourcesReadyStub struct{}
func (s *sourcesReadyStub) AddSource(source string) {} func (s *sourcesReadyStub) AddSource(source string) {}
@ -109,13 +113,13 @@ func newManagerImpl(socketPath string) (*ManagerImpl, error) {
dir, file := filepath.Split(socketPath) dir, file := filepath.Split(socketPath)
manager := &ManagerImpl{ manager := &ManagerImpl{
endpoints: make(map[string]endpoint), endpoints: make(map[string]endpointInfo),
socketname: file, socketname: file,
socketdir: dir, socketdir: dir,
healthyDevices: make(map[string]sets.String), healthyDevices: make(map[string]sets.String),
unhealthyDevices: make(map[string]sets.String), unhealthyDevices: make(map[string]sets.String),
allocatedDevices: make(map[string]sets.String), allocatedDevices: make(map[string]sets.String),
pluginOpts: make(map[string]*pluginapi.DevicePluginOptions),
podDevices: make(podDevices), podDevices: make(podDevices),
} }
manager.callback = manager.genericDeviceUpdateCallback manager.callback = manager.genericDeviceUpdateCallback
@ -228,8 +232,8 @@ func (m *ManagerImpl) Start(activePods ActivePodsFunc, sourcesReady config.Sourc
return nil return nil
} }
// GetWatcherCallback returns callback function to be registered with plugin watcher // GetWatcherHandler returns the plugin handler
func (m *ManagerImpl) GetWatcherCallback() watcher.RegisterCallbackFn { func (m *ManagerImpl) GetWatcherHandler() watcher.PluginHandler {
if f, err := os.Create(m.socketdir + "DEPRECATION"); err != nil { if f, err := os.Create(m.socketdir + "DEPRECATION"); err != nil {
glog.Errorf("Failed to create deprecation file at %s", m.socketdir) glog.Errorf("Failed to create deprecation file at %s", m.socketdir)
} else { } else {
@ -237,16 +241,57 @@ func (m *ManagerImpl) GetWatcherCallback() watcher.RegisterCallbackFn {
glog.V(4).Infof("created deprecation file %s", f.Name()) glog.V(4).Infof("created deprecation file %s", f.Name())
} }
return func(name string, endpoint string, versions []string, sockPath string) (chan bool, error) { return watcher.PluginHandler(m)
}
// ValidatePlugin validates a plugin if the version is correct and the name has the format of an extended resource
func (m *ManagerImpl) ValidatePlugin(pluginName string, endpoint string, versions []string) error {
glog.V(2).Infof("Got Plugin %s at endpoint %s with versions %v", pluginName, endpoint, versions)
if !m.isVersionCompatibleWithPlugin(versions) { if !m.isVersionCompatibleWithPlugin(versions) {
return nil, fmt.Errorf("manager version, %s, is not among plugin supported versions %v", pluginapi.Version, versions) return fmt.Errorf("manager version, %s, is not among plugin supported versions %v", pluginapi.Version, versions)
} }
if !v1helper.IsExtendedResourceName(v1.ResourceName(name)) { if !v1helper.IsExtendedResourceName(v1.ResourceName(pluginName)) {
return nil, fmt.Errorf("invalid name of device plugin socket: %s", fmt.Sprintf(errInvalidResourceName, name)) return fmt.Errorf("invalid name of device plugin socket: %s", fmt.Sprintf(errInvalidResourceName, pluginName))
} }
return m.addEndpointProbeMode(name, sockPath) return nil
}
// RegisterPlugin starts the endpoint and registers it
// TODO: Start the endpoint and wait for the First ListAndWatch call
// before registering the plugin
func (m *ManagerImpl) RegisterPlugin(pluginName string, endpoint string) error {
glog.V(2).Infof("Registering Plugin %s at endpoint %s", pluginName, endpoint)
e, err := newEndpointImpl(endpoint, pluginName, m.callback)
if err != nil {
return fmt.Errorf("Failed to dial device plugin with socketPath %s: %v", endpoint, err)
}
options, err := e.client.GetDevicePluginOptions(context.Background(), &pluginapi.Empty{})
if err != nil {
return fmt.Errorf("Failed to get device plugin options: %v", err)
}
m.registerEndpoint(pluginName, options, e)
go m.runEndpoint(pluginName, e)
return nil
}
// DeRegisterPlugin deregisters the plugin
// TODO work on the behavior for deregistering plugins
// e.g: Should we delete the resource
func (m *ManagerImpl) DeRegisterPlugin(pluginName string) {
m.mutex.Lock()
defer m.mutex.Unlock()
// Note: This will mark the resource unhealthy as per the behavior
// in runEndpoint
if eI, ok := m.endpoints[pluginName]; ok {
eI.e.stop()
} }
} }
@ -333,8 +378,8 @@ func (m *ManagerImpl) Register(ctx context.Context, r *pluginapi.RegisterRequest
func (m *ManagerImpl) Stop() error { func (m *ManagerImpl) Stop() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
for _, e := range m.endpoints { for _, eI := range m.endpoints {
e.stop() eI.e.stop()
} }
if m.server == nil { if m.server == nil {
@ -346,51 +391,26 @@ func (m *ManagerImpl) Stop() error {
return nil return nil
} }
func (m *ManagerImpl) addEndpointProbeMode(resourceName string, socketPath string) (chan bool, error) { func (m *ManagerImpl) registerEndpoint(resourceName string, options *pluginapi.DevicePluginOptions, e endpoint) {
chanForAckOfNotification := make(chan bool)
new, err := newEndpointImpl(socketPath, resourceName, m.callback)
if err != nil {
glog.Errorf("Failed to dial device plugin with socketPath %s: %v", socketPath, err)
return nil, fmt.Errorf("Failed to dial device plugin with socketPath %s: %v", socketPath, err)
}
options, err := new.client.GetDevicePluginOptions(context.Background(), &pluginapi.Empty{})
if err != nil {
glog.Errorf("Failed to get device plugin options: %v", err)
return nil, fmt.Errorf("Failed to get device plugin options: %v", err)
}
m.registerEndpoint(resourceName, options, new)
go func() {
select {
case <-chanForAckOfNotification:
close(chanForAckOfNotification)
m.runEndpoint(resourceName, new)
case <-time.After(time.Second):
glog.Errorf("Timed out while waiting for notification ack from plugin")
}
}()
return chanForAckOfNotification, nil
}
func (m *ManagerImpl) registerEndpoint(resourceName string, options *pluginapi.DevicePluginOptions, e *endpointImpl) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.pluginOpts[resourceName] = options
m.endpoints[resourceName] = e m.endpoints[resourceName] = endpointInfo{e: e, opts: options}
glog.V(2).Infof("Registered endpoint %v", e) glog.V(2).Infof("Registered endpoint %v", e)
} }
func (m *ManagerImpl) runEndpoint(resourceName string, e *endpointImpl) { func (m *ManagerImpl) runEndpoint(resourceName string, e endpoint) {
e.run() e.run()
e.stop() e.stop()
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if old, ok := m.endpoints[resourceName]; ok && old == e {
if old, ok := m.endpoints[resourceName]; ok && old.e == e {
m.markResourceUnhealthy(resourceName) m.markResourceUnhealthy(resourceName)
} }
glog.V(2).Infof("Unregistered endpoint %v", e)
glog.V(2).Infof("Endpoint (%s, %v) became unhealthy", resourceName, e)
} }
func (m *ManagerImpl) addEndpoint(r *pluginapi.RegisterRequest) { func (m *ManagerImpl) addEndpoint(r *pluginapi.RegisterRequest) {
@ -437,8 +457,8 @@ func (m *ManagerImpl) GetCapacity() (v1.ResourceList, v1.ResourceList, []string)
deletedResources := sets.NewString() deletedResources := sets.NewString()
m.mutex.Lock() m.mutex.Lock()
for resourceName, devices := range m.healthyDevices { for resourceName, devices := range m.healthyDevices {
e, ok := m.endpoints[resourceName] eI, ok := m.endpoints[resourceName]
if (ok && e.stopGracePeriodExpired()) || !ok { if (ok && eI.e.stopGracePeriodExpired()) || !ok {
// The resources contained in endpoints and (un)healthyDevices // The resources contained in endpoints and (un)healthyDevices
// should always be consistent. Otherwise, we run with the risk // should always be consistent. Otherwise, we run with the risk
// of failing to garbage collect non-existing resources or devices. // of failing to garbage collect non-existing resources or devices.
@ -455,8 +475,8 @@ func (m *ManagerImpl) GetCapacity() (v1.ResourceList, v1.ResourceList, []string)
} }
} }
for resourceName, devices := range m.unhealthyDevices { for resourceName, devices := range m.unhealthyDevices {
e, ok := m.endpoints[resourceName] eI, ok := m.endpoints[resourceName]
if (ok && e.stopGracePeriodExpired()) || !ok { if (ok && eI.e.stopGracePeriodExpired()) || !ok {
if !ok { if !ok {
glog.Errorf("unexpected: unhealthyDevices and endpoints are out of sync") glog.Errorf("unexpected: unhealthyDevices and endpoints are out of sync")
} }
@ -519,7 +539,7 @@ func (m *ManagerImpl) readCheckpoint() error {
// will stay zero till the corresponding device plugin re-registers. // will stay zero till the corresponding device plugin re-registers.
m.healthyDevices[resource] = sets.NewString() m.healthyDevices[resource] = sets.NewString()
m.unhealthyDevices[resource] = sets.NewString() m.unhealthyDevices[resource] = sets.NewString()
m.endpoints[resource] = newStoppedEndpointImpl(resource) m.endpoints[resource] = endpointInfo{e: newStoppedEndpointImpl(resource), opts: nil}
} }
return nil return nil
} }
@ -652,7 +672,7 @@ func (m *ManagerImpl) allocateContainerResources(pod *v1.Pod, container *v1.Cont
// plugin Allocate grpc calls if it becomes common that a container may require // plugin Allocate grpc calls if it becomes common that a container may require
// resources from multiple device plugins. // resources from multiple device plugins.
m.mutex.Lock() m.mutex.Lock()
e, ok := m.endpoints[resource] eI, ok := m.endpoints[resource]
m.mutex.Unlock() m.mutex.Unlock()
if !ok { if !ok {
m.mutex.Lock() m.mutex.Lock()
@ -665,7 +685,7 @@ func (m *ManagerImpl) allocateContainerResources(pod *v1.Pod, container *v1.Cont
// TODO: refactor this part of code to just append a ContainerAllocationRequest // TODO: refactor this part of code to just append a ContainerAllocationRequest
// in a passed in AllocateRequest pointer, and issues a single Allocate call per pod. // in a passed in AllocateRequest pointer, and issues a single Allocate call per pod.
glog.V(3).Infof("Making allocation request for devices %v for device plugin %s", devs, resource) glog.V(3).Infof("Making allocation request for devices %v for device plugin %s", devs, resource)
resp, err := e.allocate(devs) resp, err := eI.e.allocate(devs)
metrics.DevicePluginAllocationLatency.WithLabelValues(resource).Observe(metrics.SinceInMicroseconds(startRPCTime)) metrics.DevicePluginAllocationLatency.WithLabelValues(resource).Observe(metrics.SinceInMicroseconds(startRPCTime))
if err != nil { if err != nil {
// In case of allocation failure, we want to restore m.allocatedDevices // In case of allocation failure, we want to restore m.allocatedDevices
@ -715,11 +735,13 @@ func (m *ManagerImpl) GetDeviceRunContainerOptions(pod *v1.Pod, container *v1.Co
// with PreStartRequired option set. // with PreStartRequired option set.
func (m *ManagerImpl) callPreStartContainerIfNeeded(podUID, contName, resource string) error { func (m *ManagerImpl) callPreStartContainerIfNeeded(podUID, contName, resource string) error {
m.mutex.Lock() m.mutex.Lock()
opts, ok := m.pluginOpts[resource] eI, ok := m.endpoints[resource]
if !ok { if !ok {
m.mutex.Unlock() m.mutex.Unlock()
return fmt.Errorf("Plugin options not found in cache for resource: %s", resource) return fmt.Errorf("endpoint not found in cache for a registered resource: %s", resource)
} else if opts == nil || !opts.PreStartRequired { }
if eI.opts == nil || !eI.opts.PreStartRequired {
m.mutex.Unlock() m.mutex.Unlock()
glog.V(4).Infof("Plugin options indicate to skip PreStartContainer for resource: %s", resource) glog.V(4).Infof("Plugin options indicate to skip PreStartContainer for resource: %s", resource)
return nil return nil
@ -731,16 +753,10 @@ func (m *ManagerImpl) callPreStartContainerIfNeeded(podUID, contName, resource s
return fmt.Errorf("no devices found allocated in local cache for pod %s, container %s, resource %s", podUID, contName, resource) return fmt.Errorf("no devices found allocated in local cache for pod %s, container %s, resource %s", podUID, contName, resource)
} }
e, ok := m.endpoints[resource]
if !ok {
m.mutex.Unlock()
return fmt.Errorf("endpoint not found in cache for a registered resource: %s", resource)
}
m.mutex.Unlock() m.mutex.Unlock()
devs := devices.UnsortedList() devs := devices.UnsortedList()
glog.V(4).Infof("Issuing an PreStartContainer call for container, %s, of pod %s", contName, podUID) glog.V(4).Infof("Issuing an PreStartContainer call for container, %s, of pod %s", contName, podUID)
_, err := e.preStartContainer(devs) _, err := eI.e.preStartContainer(devs)
if err != nil { if err != nil {
return fmt.Errorf("device plugin PreStartContainer rpc failed with err: %v", err) return fmt.Errorf("device plugin PreStartContainer rpc failed with err: %v", err)
} }

View File

@ -57,9 +57,7 @@ func (h *ManagerStub) GetCapacity() (v1.ResourceList, v1.ResourceList, []string)
return nil, nil, []string{} return nil, nil, []string{}
} }
// GetWatcherCallback returns plugin watcher callback // GetWatcherHandler returns plugin watcher interface
func (h *ManagerStub) GetWatcherCallback() pluginwatcher.RegisterCallbackFn { func (h *ManagerStub) GetWatcherHandler() pluginwatcher.PluginHandler {
return func(name string, endpoint string, versions []string, sockPath string) (chan bool, error) { return nil
return nil, nil
}
} }

View File

@ -249,9 +249,10 @@ func setupDevicePlugin(t *testing.T, devs []*pluginapi.Device, pluginSocketName
func setupPluginWatcher(pluginSocketName string, m Manager) *pluginwatcher.Watcher { func setupPluginWatcher(pluginSocketName string, m Manager) *pluginwatcher.Watcher {
w := pluginwatcher.NewWatcher(filepath.Dir(pluginSocketName)) w := pluginwatcher.NewWatcher(filepath.Dir(pluginSocketName))
w.AddHandler(watcherapi.DevicePlugin, m.GetWatcherCallback()) w.AddHandler(watcherapi.DevicePlugin, m.GetWatcherHandler())
w.Start() w.Start()
return &w
return w
} }
func setup(t *testing.T, devs []*pluginapi.Device, callback monitorCallback, socketName string, pluginSocketName string) (Manager, <-chan interface{}, *Stub) { func setup(t *testing.T, devs []*pluginapi.Device, callback monitorCallback, socketName string, pluginSocketName string) (Manager, <-chan interface{}, *Stub) {
@ -295,7 +296,7 @@ func TestUpdateCapacityAllocatable(t *testing.T) {
// Expects capacity for resource1 to be 2. // Expects capacity for resource1 to be 2.
resourceName1 := "domain1.com/resource1" resourceName1 := "domain1.com/resource1"
e1 := &endpointImpl{} e1 := &endpointImpl{}
testManager.endpoints[resourceName1] = e1 testManager.endpoints[resourceName1] = endpointInfo{e: e1, opts: nil}
callback(resourceName1, devs) callback(resourceName1, devs)
capacity, allocatable, removedResources := testManager.GetCapacity() capacity, allocatable, removedResources := testManager.GetCapacity()
resource1Capacity, ok := capacity[v1.ResourceName(resourceName1)] resource1Capacity, ok := capacity[v1.ResourceName(resourceName1)]
@ -345,7 +346,7 @@ func TestUpdateCapacityAllocatable(t *testing.T) {
// Tests adding another resource. // Tests adding another resource.
resourceName2 := "resource2" resourceName2 := "resource2"
e2 := &endpointImpl{} e2 := &endpointImpl{}
testManager.endpoints[resourceName2] = e2 testManager.endpoints[resourceName2] = endpointInfo{e: e2, opts: nil}
callback(resourceName2, devs) callback(resourceName2, devs)
capacity, allocatable, removedResources = testManager.GetCapacity() capacity, allocatable, removedResources = testManager.GetCapacity()
as.Equal(2, len(capacity)) as.Equal(2, len(capacity))
@ -456,7 +457,7 @@ func TestCheckpoint(t *testing.T) {
ckm, err := checkpointmanager.NewCheckpointManager(tmpDir) ckm, err := checkpointmanager.NewCheckpointManager(tmpDir)
as.Nil(err) as.Nil(err)
testManager := &ManagerImpl{ testManager := &ManagerImpl{
endpoints: make(map[string]endpoint), endpoints: make(map[string]endpointInfo),
healthyDevices: make(map[string]sets.String), healthyDevices: make(map[string]sets.String),
unhealthyDevices: make(map[string]sets.String), unhealthyDevices: make(map[string]sets.String),
allocatedDevices: make(map[string]sets.String), allocatedDevices: make(map[string]sets.String),
@ -577,7 +578,7 @@ func makePod(limits v1.ResourceList) *v1.Pod {
} }
} }
func getTestManager(tmpDir string, activePods ActivePodsFunc, testRes []TestResource, opts map[string]*pluginapi.DevicePluginOptions) (*ManagerImpl, error) { func getTestManager(tmpDir string, activePods ActivePodsFunc, testRes []TestResource) (*ManagerImpl, error) {
monitorCallback := func(resourceName string, devices []pluginapi.Device) {} monitorCallback := func(resourceName string, devices []pluginapi.Device) {}
ckm, err := checkpointmanager.NewCheckpointManager(tmpDir) ckm, err := checkpointmanager.NewCheckpointManager(tmpDir)
if err != nil { if err != nil {
@ -589,25 +590,27 @@ func getTestManager(tmpDir string, activePods ActivePodsFunc, testRes []TestReso
healthyDevices: make(map[string]sets.String), healthyDevices: make(map[string]sets.String),
unhealthyDevices: make(map[string]sets.String), unhealthyDevices: make(map[string]sets.String),
allocatedDevices: make(map[string]sets.String), allocatedDevices: make(map[string]sets.String),
endpoints: make(map[string]endpoint), endpoints: make(map[string]endpointInfo),
pluginOpts: opts,
podDevices: make(podDevices), podDevices: make(podDevices),
activePods: activePods, activePods: activePods,
sourcesReady: &sourcesReadyStub{}, sourcesReady: &sourcesReadyStub{},
checkpointManager: ckm, checkpointManager: ckm,
} }
for _, res := range testRes { for _, res := range testRes {
testManager.healthyDevices[res.resourceName] = sets.NewString() testManager.healthyDevices[res.resourceName] = sets.NewString()
for _, dev := range res.devs { for _, dev := range res.devs {
testManager.healthyDevices[res.resourceName].Insert(dev) testManager.healthyDevices[res.resourceName].Insert(dev)
} }
if res.resourceName == "domain1.com/resource1" { if res.resourceName == "domain1.com/resource1" {
testManager.endpoints[res.resourceName] = &MockEndpoint{ testManager.endpoints[res.resourceName] = endpointInfo{
allocateFunc: allocateStubFunc(), e: &MockEndpoint{allocateFunc: allocateStubFunc()},
opts: nil,
} }
} }
if res.resourceName == "domain2.com/resource2" { if res.resourceName == "domain2.com/resource2" {
testManager.endpoints[res.resourceName] = &MockEndpoint{ testManager.endpoints[res.resourceName] = endpointInfo{
e: &MockEndpoint{
allocateFunc: func(devs []string) (*pluginapi.AllocateResponse, error) { allocateFunc: func(devs []string) (*pluginapi.AllocateResponse, error) {
resp := new(pluginapi.ContainerAllocateResponse) resp := new(pluginapi.ContainerAllocateResponse)
resp.Envs = make(map[string]string) resp.Envs = make(map[string]string)
@ -624,6 +627,8 @@ func getTestManager(tmpDir string, activePods ActivePodsFunc, testRes []TestReso
resps.ContainerResponses = append(resps.ContainerResponses, resp) resps.ContainerResponses = append(resps.ContainerResponses, resp)
return resps, nil return resps, nil
}, },
},
opts: nil,
} }
} }
} }
@ -669,10 +674,7 @@ func TestPodContainerDeviceAllocation(t *testing.T) {
as.Nil(err) as.Nil(err)
defer os.RemoveAll(tmpDir) defer os.RemoveAll(tmpDir)
nodeInfo := getTestNodeInfo(v1.ResourceList{}) nodeInfo := getTestNodeInfo(v1.ResourceList{})
pluginOpts := make(map[string]*pluginapi.DevicePluginOptions) testManager, err := getTestManager(tmpDir, podsStub.getActivePods, testResources)
pluginOpts[res1.resourceName] = nil
pluginOpts[res2.resourceName] = nil
testManager, err := getTestManager(tmpDir, podsStub.getActivePods, testResources, pluginOpts)
as.Nil(err) as.Nil(err)
testPods := []*v1.Pod{ testPods := []*v1.Pod{
@ -767,10 +769,8 @@ func TestInitContainerDeviceAllocation(t *testing.T) {
tmpDir, err := ioutil.TempDir("", "checkpoint") tmpDir, err := ioutil.TempDir("", "checkpoint")
as.Nil(err) as.Nil(err)
defer os.RemoveAll(tmpDir) defer os.RemoveAll(tmpDir)
pluginOpts := make(map[string]*pluginapi.DevicePluginOptions)
pluginOpts[res1.resourceName] = nil testManager, err := getTestManager(tmpDir, podsStub.getActivePods, testResources)
pluginOpts[res2.resourceName] = nil
testManager, err := getTestManager(tmpDir, podsStub.getActivePods, testResources, pluginOpts)
as.Nil(err) as.Nil(err)
podWithPluginResourcesInInitContainers := &v1.Pod{ podWithPluginResourcesInInitContainers := &v1.Pod{
@ -904,18 +904,18 @@ func TestDevicePreStartContainer(t *testing.T) {
as.Nil(err) as.Nil(err)
defer os.RemoveAll(tmpDir) defer os.RemoveAll(tmpDir)
nodeInfo := getTestNodeInfo(v1.ResourceList{}) nodeInfo := getTestNodeInfo(v1.ResourceList{})
pluginOpts := make(map[string]*pluginapi.DevicePluginOptions)
pluginOpts[res1.resourceName] = &pluginapi.DevicePluginOptions{PreStartRequired: true}
testManager, err := getTestManager(tmpDir, podsStub.getActivePods, []TestResource{res1}, pluginOpts) testManager, err := getTestManager(tmpDir, podsStub.getActivePods, []TestResource{res1})
as.Nil(err) as.Nil(err)
ch := make(chan []string, 1) ch := make(chan []string, 1)
testManager.endpoints[res1.resourceName] = &MockEndpoint{ testManager.endpoints[res1.resourceName] = endpointInfo{
e: &MockEndpoint{
initChan: ch, initChan: ch,
allocateFunc: allocateStubFunc(), allocateFunc: allocateStubFunc(),
},
opts: &pluginapi.DevicePluginOptions{PreStartRequired: true},
} }
pod := makePod(v1.ResourceList{ pod := makePod(v1.ResourceList{
v1.ResourceName(res1.resourceName): res1.resourceQuantity}) v1.ResourceName(res1.resourceName): res1.resourceQuantity})
activePods := []*v1.Pod{} activePods := []*v1.Pod{}

View File

@ -53,7 +53,7 @@ type Manager interface {
// GetCapacity returns the amount of available device plugin resource capacity, resource allocatable // GetCapacity returns the amount of available device plugin resource capacity, resource allocatable
// and inactive device plugin resources previously registered on the node. // and inactive device plugin resources previously registered on the node.
GetCapacity() (v1.ResourceList, v1.ResourceList, []string) GetCapacity() (v1.ResourceList, v1.ResourceList, []string)
GetWatcherCallback() watcher.RegisterCallbackFn GetWatcherHandler() watcher.PluginHandler
} }
// DeviceRunContainerOptions contains the combined container runtime settings to consume its allocated devices. // DeviceRunContainerOptions contains the combined container runtime settings to consume its allocated devices.

View File

@ -1194,7 +1194,7 @@ type Kubelet struct {
// pluginwatcher is a utility for Kubelet to register different types of node-level plugins // pluginwatcher is a utility for Kubelet to register different types of node-level plugins
// such as device plugins or CSI plugins. It discovers plugins by monitoring inotify events under the // such as device plugins or CSI plugins. It discovers plugins by monitoring inotify events under the
// directory returned by kubelet.getPluginsDir() // directory returned by kubelet.getPluginsDir()
pluginWatcher pluginwatcher.Watcher pluginWatcher *pluginwatcher.Watcher
// This flag sets a maximum number of images to report in the node status. // This flag sets a maximum number of images to report in the node status.
nodeStatusMaxImages int32 nodeStatusMaxImages int32
@ -1365,9 +1365,9 @@ func (kl *Kubelet) initializeRuntimeDependentModules() {
kl.containerLogManager.Start() kl.containerLogManager.Start()
if kl.enablePluginsWatcher { if kl.enablePluginsWatcher {
// Adding Registration Callback function for CSI Driver // Adding Registration Callback function for CSI Driver
kl.pluginWatcher.AddHandler("CSIPlugin", csi.RegistrationCallback) kl.pluginWatcher.AddHandler("CSIPlugin", pluginwatcher.PluginHandler(csi.PluginHandler))
// Adding Registration Callback function for Device Manager // Adding Registration Callback function for Device Manager
kl.pluginWatcher.AddHandler(pluginwatcherapi.DevicePlugin, kl.containerManager.GetPluginRegistrationHandlerCallback()) kl.pluginWatcher.AddHandler(pluginwatcherapi.DevicePlugin, kl.containerManager.GetPluginRegistrationHandler())
// Start the plugin watcher // Start the plugin watcher
glog.V(4).Infof("starting watcher") glog.V(4).Infof("starting watcher")
if err := kl.pluginWatcher.Start(); err != nil { if err := kl.pluginWatcher.Start(); err != nil {

View File

@ -1,10 +1,4 @@
package(default_visibility = ["//visibility:public"]) load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
"go_test",
)
go_library( go_library(
name = "go_default_library", name = "go_default_library",
@ -12,8 +6,10 @@ go_library(
"example_handler.go", "example_handler.go",
"example_plugin.go", "example_plugin.go",
"plugin_watcher.go", "plugin_watcher.go",
"types.go",
], ],
importpath = "k8s.io/kubernetes/pkg/kubelet/util/pluginwatcher", importpath = "k8s.io/kubernetes/pkg/kubelet/util/pluginwatcher",
visibility = ["//visibility:public"],
deps = [ deps = [
"//pkg/kubelet/apis/pluginregistration/v1alpha1:go_default_library", "//pkg/kubelet/apis/pluginregistration/v1alpha1:go_default_library",
"//pkg/kubelet/util/pluginwatcher/example_plugin_apis/v1beta1:go_default_library", "//pkg/kubelet/util/pluginwatcher/example_plugin_apis/v1beta1:go_default_library",
@ -27,6 +23,16 @@ go_library(
], ],
) )
go_test(
name = "go_default_test",
srcs = ["plugin_watcher_test.go"],
embed = [":go_default_library"],
deps = [
"//pkg/kubelet/apis/pluginregistration/v1alpha1:go_default_library",
"//vendor/github.com/stretchr/testify/require:go_default_library",
],
)
filegroup( filegroup(
name = "package-srcs", name = "package-srcs",
srcs = glob(["**"]), srcs = glob(["**"]),
@ -44,14 +50,3 @@ filegroup(
tags = ["automanaged"], tags = ["automanaged"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )
go_test(
name = "go_default_test",
srcs = ["plugin_watcher_test.go"],
embed = [":go_default_library"],
deps = [
"//pkg/kubelet/apis/pluginregistration/v1alpha1:go_default_library",
"//staging/src/k8s.io/apimachinery/pkg/util/sets:go_default_library",
"//vendor/github.com/stretchr/testify/require:go_default_library",
],
)

View File

@ -23,6 +23,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/golang/glog"
"golang.org/x/net/context" "golang.org/x/net/context"
v1beta1 "k8s.io/kubernetes/pkg/kubelet/util/pluginwatcher/example_plugin_apis/v1beta1" v1beta1 "k8s.io/kubernetes/pkg/kubelet/util/pluginwatcher/example_plugin_apis/v1beta1"
@ -30,41 +31,61 @@ import (
) )
type exampleHandler struct { type exampleHandler struct {
registeredPlugins map[string]struct{} SupportedVersions []string
mutex sync.Mutex ExpectedNames map[string]int
chanForHandlerAckErrors chan error // for testing
eventChans map[string]chan examplePluginEvent // map[pluginName]eventChan
m sync.Mutex
count int
} }
type examplePluginEvent int
const (
exampleEventValidate examplePluginEvent = 0
exampleEventRegister examplePluginEvent = 1
exampleEventDeRegister examplePluginEvent = 2
exampleEventError examplePluginEvent = 3
)
// NewExampleHandler provide a example handler // NewExampleHandler provide a example handler
func NewExampleHandler() *exampleHandler { func NewExampleHandler(supportedVersions []string) *exampleHandler {
return &exampleHandler{ return &exampleHandler{
chanForHandlerAckErrors: make(chan error), SupportedVersions: supportedVersions,
registeredPlugins: make(map[string]struct{}), ExpectedNames: make(map[string]int),
eventChans: make(map[string]chan examplePluginEvent),
} }
} }
func (h *exampleHandler) Cleanup() error { func (p *exampleHandler) ValidatePlugin(pluginName string, endpoint string, versions []string) error {
h.mutex.Lock() p.SendEvent(pluginName, exampleEventValidate)
defer h.mutex.Unlock()
h.registeredPlugins = make(map[string]struct{})
return nil
}
func (h *exampleHandler) Handler(pluginName string, endpoint string, versions []string, sockPath string) (chan bool, error) { n, ok := p.DecreasePluginCount(pluginName)
if !ok && n > 0 {
return fmt.Errorf("pluginName('%s') wasn't expected (count is %d)", pluginName, n)
}
// check for supported versions if !reflect.DeepEqual(versions, p.SupportedVersions) {
if !reflect.DeepEqual([]string{"v1beta1", "v1beta2"}, versions) { return fmt.Errorf("versions('%v') != supported versions('%v')", versions, p.SupportedVersions)
return nil, fmt.Errorf("not the supported versions: %s", versions)
} }
// this handler expects non-empty endpoint as an example // this handler expects non-empty endpoint as an example
if len(endpoint) == 0 { if len(endpoint) == 0 {
return nil, errors.New("expecting non empty endpoint") return errors.New("expecting non empty endpoint")
} }
_, conn, err := dial(sockPath) return nil
}
func (p *exampleHandler) RegisterPlugin(pluginName, endpoint string) error {
p.SendEvent(pluginName, exampleEventRegister)
// Verifies the grpcServer is ready to serve services.
_, conn, err := dial(endpoint, time.Second)
if err != nil { if err != nil {
return nil, err return fmt.Errorf("Failed dialing endpoint (%s): %v", endpoint, err)
} }
defer conn.Close() defer conn.Close()
@ -73,33 +94,54 @@ func (h *exampleHandler) Handler(pluginName string, endpoint string, versions []
v1beta2Client := v1beta2.NewExampleClient(conn) v1beta2Client := v1beta2.NewExampleClient(conn)
// Tests v1beta1 GetExampleInfo // Tests v1beta1 GetExampleInfo
if _, err = v1beta1Client.GetExampleInfo(context.Background(), &v1beta1.ExampleRequest{}); err != nil { _, err = v1beta1Client.GetExampleInfo(context.Background(), &v1beta1.ExampleRequest{})
return nil, err if err != nil {
return fmt.Errorf("Failed GetExampleInfo for v1beta2Client(%s): %v", endpoint, err)
} }
// Tests v1beta2 GetExampleInfo // Tests v1beta1 GetExampleInfo
if _, err = v1beta2Client.GetExampleInfo(context.Background(), &v1beta2.ExampleRequest{}); err != nil { _, err = v1beta2Client.GetExampleInfo(context.Background(), &v1beta2.ExampleRequest{})
return nil, err if err != nil {
return fmt.Errorf("Failed GetExampleInfo for v1beta2Client(%s): %v", endpoint, err)
} }
// handle registered plugin return nil
h.mutex.Lock() }
if _, exist := h.registeredPlugins[pluginName]; exist {
h.mutex.Unlock() func (p *exampleHandler) DeRegisterPlugin(pluginName string) {
return nil, fmt.Errorf("plugin %s already registered", pluginName) p.SendEvent(pluginName, exampleEventDeRegister)
} }
h.registeredPlugins[pluginName] = struct{}{}
h.mutex.Unlock() func (p *exampleHandler) EventChan(pluginName string) chan examplePluginEvent {
return p.eventChans[pluginName]
chanForAckOfNotification := make(chan bool) }
go func() {
select { func (p *exampleHandler) SendEvent(pluginName string, event examplePluginEvent) {
case <-chanForAckOfNotification: glog.V(2).Infof("Sending %v for plugin %s over chan %v", event, pluginName, p.eventChans[pluginName])
// TODO: handle the negative scenario p.eventChans[pluginName] <- event
close(chanForAckOfNotification) }
case <-time.After(time.Second):
h.chanForHandlerAckErrors <- errors.New("Timed out while waiting for notification ack") func (p *exampleHandler) AddPluginName(pluginName string) {
} p.m.Lock()
}() defer p.m.Unlock()
return chanForAckOfNotification, nil
v, ok := p.ExpectedNames[pluginName]
if !ok {
p.eventChans[pluginName] = make(chan examplePluginEvent)
v = 1
}
p.ExpectedNames[pluginName] = v
}
func (p *exampleHandler) DecreasePluginCount(pluginName string) (old int, ok bool) {
p.m.Lock()
defer p.m.Unlock()
v, ok := p.ExpectedNames[pluginName]
if !ok {
v = -1
}
return v, ok
} }

View File

@ -18,7 +18,9 @@ package pluginwatcher
import ( import (
"errors" "errors"
"fmt"
"net" "net"
"os"
"sync" "sync"
"time" "time"
@ -39,6 +41,7 @@ type examplePlugin struct {
endpoint string // for testing endpoint string // for testing
pluginName string pluginName string
pluginType string pluginType string
versions []string
} }
type pluginServiceV1Beta1 struct { type pluginServiceV1Beta1 struct {
@ -73,12 +76,13 @@ func NewExamplePlugin() *examplePlugin {
} }
// NewTestExamplePlugin returns an initialized examplePlugin instance for testing // NewTestExamplePlugin returns an initialized examplePlugin instance for testing
func NewTestExamplePlugin(pluginName string, pluginType string, endpoint string) *examplePlugin { func NewTestExamplePlugin(pluginName string, pluginType string, endpoint string, advertisedVersions ...string) *examplePlugin {
return &examplePlugin{ return &examplePlugin{
pluginName: pluginName, pluginName: pluginName,
pluginType: pluginType, pluginType: pluginType,
registrationStatus: make(chan registerapi.RegistrationStatus),
endpoint: endpoint, endpoint: endpoint,
versions: advertisedVersions,
registrationStatus: make(chan registerapi.RegistrationStatus),
} }
} }
@ -88,36 +92,48 @@ func (e *examplePlugin) GetInfo(ctx context.Context, req *registerapi.InfoReques
Type: e.pluginType, Type: e.pluginType,
Name: e.pluginName, Name: e.pluginName,
Endpoint: e.endpoint, Endpoint: e.endpoint,
SupportedVersions: []string{"v1beta1", "v1beta2"}, SupportedVersions: e.versions,
}, nil }, nil
} }
func (e *examplePlugin) NotifyRegistrationStatus(ctx context.Context, status *registerapi.RegistrationStatus) (*registerapi.RegistrationStatusResponse, error) { func (e *examplePlugin) NotifyRegistrationStatus(ctx context.Context, status *registerapi.RegistrationStatus) (*registerapi.RegistrationStatusResponse, error) {
glog.Errorf("Registration is: %v\n", status)
if e.registrationStatus != nil { if e.registrationStatus != nil {
e.registrationStatus <- *status e.registrationStatus <- *status
} }
if !status.PluginRegistered {
glog.Errorf("Registration failed: %s\n", status.Error)
}
return &registerapi.RegistrationStatusResponse{}, nil return &registerapi.RegistrationStatusResponse{}, nil
} }
// Serve starts example plugin grpc server // Serve starts a pluginwatcher server and one or more of the plugin services
func (e *examplePlugin) Serve(socketPath string) error { func (e *examplePlugin) Serve(services ...string) error {
glog.Infof("starting example server at: %s\n", socketPath) glog.Infof("starting example server at: %s\n", e.endpoint)
lis, err := net.Listen("unix", socketPath) lis, err := net.Listen("unix", e.endpoint)
if err != nil { if err != nil {
return err return err
} }
glog.Infof("example server started at: %s\n", socketPath)
glog.Infof("example server started at: %s\n", e.endpoint)
e.grpcServer = grpc.NewServer() e.grpcServer = grpc.NewServer()
// Registers kubelet plugin watcher api. // Registers kubelet plugin watcher api.
registerapi.RegisterRegistrationServer(e.grpcServer, e) registerapi.RegisterRegistrationServer(e.grpcServer, e)
// Registers services for both v1beta1 and v1beta2 versions.
for _, service := range services {
switch service {
case "v1beta1":
v1beta1 := &pluginServiceV1Beta1{server: e} v1beta1 := &pluginServiceV1Beta1{server: e}
v1beta1.RegisterService() v1beta1.RegisterService()
break
case "v1beta2":
v1beta2 := &pluginServiceV1Beta2{server: e} v1beta2 := &pluginServiceV1Beta2{server: e}
v1beta2.RegisterService() v1beta2.RegisterService()
break
default:
return fmt.Errorf("Unsupported service: '%s'", service)
}
}
// Starts service // Starts service
e.wg.Add(1) e.wg.Add(1)
@ -128,22 +144,30 @@ func (e *examplePlugin) Serve(socketPath string) error {
glog.Errorf("example server stopped serving: %v", err) glog.Errorf("example server stopped serving: %v", err)
} }
}() }()
return nil return nil
} }
func (e *examplePlugin) Stop() error { func (e *examplePlugin) Stop() error {
glog.Infof("Stopping example server\n") glog.Infof("Stopping example server at: %s\n", e.endpoint)
e.grpcServer.Stop() e.grpcServer.Stop()
c := make(chan struct{}) c := make(chan struct{})
go func() { go func() {
defer close(c) defer close(c)
e.wg.Wait() e.wg.Wait()
}() }()
select { select {
case <-c: case <-c:
return nil break
case <-time.After(time.Second): case <-time.After(time.Second):
glog.Errorf("Timed out on waiting for stop completion")
return errors.New("Timed out on waiting for stop completion") return errors.New("Timed out on waiting for stop completion")
} }
if err := os.Remove(e.endpoint); err != nil && !os.IsNotExist(err) {
return err
}
return nil
} }

View File

@ -20,6 +20,7 @@ import (
"fmt" "fmt"
"net" "net"
"os" "os"
"strings"
"sync" "sync"
"time" "time"
@ -28,43 +29,144 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"golang.org/x/net/context" "golang.org/x/net/context"
"google.golang.org/grpc" "google.golang.org/grpc"
registerapi "k8s.io/kubernetes/pkg/kubelet/apis/pluginregistration/v1alpha1" registerapi "k8s.io/kubernetes/pkg/kubelet/apis/pluginregistration/v1alpha1"
utilfs "k8s.io/kubernetes/pkg/util/filesystem" utilfs "k8s.io/kubernetes/pkg/util/filesystem"
) )
// RegisterCallbackFn is the type of the callback function that handlers will provide
type RegisterCallbackFn func(pluginName string, endpoint string, versions []string, socketPath string) (chan bool, error)
// Watcher is the plugin watcher // Watcher is the plugin watcher
type Watcher struct { type Watcher struct {
path string path string
handlers map[string]RegisterCallbackFn
stopCh chan interface{} stopCh chan interface{}
fs utilfs.Filesystem fs utilfs.Filesystem
fsWatcher *fsnotify.Watcher fsWatcher *fsnotify.Watcher
wg sync.WaitGroup wg sync.WaitGroup
mutex sync.Mutex mutex sync.Mutex
handlers map[string]PluginHandler
plugins map[string]pathInfo
pluginsPool map[string]map[string]*sync.Mutex // map[pluginType][pluginName]
}
type pathInfo struct {
pluginType string
pluginName string
} }
// NewWatcher provides a new watcher // NewWatcher provides a new watcher
func NewWatcher(sockDir string) Watcher { func NewWatcher(sockDir string) *Watcher {
return Watcher{ return &Watcher{
path: sockDir, path: sockDir,
handlers: make(map[string]RegisterCallbackFn),
fs: &utilfs.DefaultFs{}, fs: &utilfs.DefaultFs{},
handlers: make(map[string]PluginHandler),
plugins: make(map[string]pathInfo),
pluginsPool: make(map[string]map[string]*sync.Mutex),
} }
} }
// AddHandler registers a callback to be invoked for a particular type of plugin func (w *Watcher) AddHandler(pluginType string, handler PluginHandler) {
func (w *Watcher) AddHandler(pluginType string, handlerCbkFn RegisterCallbackFn) {
w.mutex.Lock() w.mutex.Lock()
defer w.mutex.Unlock() defer w.mutex.Unlock()
w.handlers[pluginType] = handlerCbkFn
w.handlers[pluginType] = handler
} }
// Creates the plugin directory, if it doesn't already exist. func (w *Watcher) getHandler(pluginType string) (PluginHandler, bool) {
func (w *Watcher) createPluginDir() error { w.mutex.Lock()
defer w.mutex.Unlock()
h, ok := w.handlers[pluginType]
return h, ok
}
// Start watches for the creation of plugin sockets at the path
func (w *Watcher) Start() error {
glog.V(2).Infof("Plugin Watcher Start at %s", w.path)
w.stopCh = make(chan interface{})
// Creating the directory to be watched if it doesn't exist yet,
// and walks through the directory to discover the existing plugins.
if err := w.init(); err != nil {
return err
}
fsWatcher, err := fsnotify.NewWatcher()
if err != nil {
return fmt.Errorf("failed to start plugin fsWatcher, err: %v", err)
}
w.fsWatcher = fsWatcher
w.wg.Add(1)
go func(fsWatcher *fsnotify.Watcher) {
defer w.wg.Done()
for {
select {
case event := <-fsWatcher.Events:
//TODO: Handle errors by taking corrective measures
w.wg.Add(1)
go func() {
defer w.wg.Done()
if event.Op&fsnotify.Create == fsnotify.Create {
err := w.handleCreateEvent(event)
if err != nil {
glog.Errorf("error %v when handling create event: %s", err, event)
}
} else if event.Op&fsnotify.Remove == fsnotify.Remove {
err := w.handleDeleteEvent(event)
if err != nil {
glog.Errorf("error %v when handling delete event: %s", err, event)
}
}
return
}()
continue
case err := <-fsWatcher.Errors:
if err != nil {
glog.Errorf("fsWatcher received error: %v", err)
}
continue
case <-w.stopCh:
return
}
}
}(fsWatcher)
// Traverse plugin dir after starting the plugin processing goroutine
if err := w.traversePluginDir(w.path); err != nil {
w.Stop()
return fmt.Errorf("failed to traverse plugin socket path, err: %v", err)
}
return nil
}
// Stop stops probing the creation of plugin sockets at the path
func (w *Watcher) Stop() error {
close(w.stopCh)
c := make(chan struct{})
go func() {
defer close(c)
w.wg.Wait()
}()
select {
case <-c:
case <-time.After(11 * time.Second):
return fmt.Errorf("timeout on stopping watcher")
}
w.fsWatcher.Close()
return nil
}
func (w *Watcher) init() error {
glog.V(4).Infof("Ensuring Plugin directory at %s ", w.path) glog.V(4).Infof("Ensuring Plugin directory at %s ", w.path)
if err := w.fs.MkdirAll(w.path, 0755); err != nil { if err := w.fs.MkdirAll(w.path, 0755); err != nil {
return fmt.Errorf("error (re-)creating root %s: %v", w.path, err) return fmt.Errorf("error (re-)creating root %s: %v", w.path, err)
} }
@ -91,22 +193,38 @@ func (w *Watcher) traversePluginDir(dir string) error {
Op: fsnotify.Create, Op: fsnotify.Create,
} }
}() }()
default:
glog.V(5).Infof("Ignoring file %s with mode %v", path, mode)
} }
return nil return nil
}) })
} }
func (w *Watcher) init() error { // Handle filesystem notify event.
if err := w.createPluginDir(); err != nil { func (w *Watcher) handleCreateEvent(event fsnotify.Event) error {
return err glog.V(6).Infof("Handling create event: %v", event)
fi, err := os.Stat(event.Name)
if err != nil {
return fmt.Errorf("stat file %s failed: %v", event.Name, err)
} }
if strings.HasPrefix(fi.Name(), ".") {
glog.Errorf("Ignoring file: %s", fi.Name())
return nil return nil
}
if !fi.IsDir() {
return w.handlePluginRegistration(event.Name)
}
return w.traversePluginDir(event.Name)
} }
func (w *Watcher) registerPlugin(socketPath string) error { func (w *Watcher) handlePluginRegistration(socketPath string) error {
//TODO: Implement rate limiting to mitigate any DOS kind of attacks. //TODO: Implement rate limiting to mitigate any DOS kind of attacks.
client, conn, err := dial(socketPath) client, conn, err := dial(socketPath, 10*time.Second)
if err != nil { if err != nil {
return fmt.Errorf("dial failed at socket %s, err: %v", socketPath, err) return fmt.Errorf("dial failed at socket %s, err: %v", socketPath, err)
} }
@ -114,154 +232,161 @@ func (w *Watcher) registerPlugin(socketPath string) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel() defer cancel()
infoResp, err := client.GetInfo(ctx, &registerapi.InfoRequest{}) infoResp, err := client.GetInfo(ctx, &registerapi.InfoRequest{})
if err != nil { if err != nil {
return fmt.Errorf("failed to get plugin info using RPC GetInfo at socket %s, err: %v", socketPath, err) return fmt.Errorf("failed to get plugin info using RPC GetInfo at socket %s, err: %v", socketPath, err)
} }
return w.invokeRegistrationCallbackAtHandler(ctx, client, infoResp, socketPath) handler, ok := w.handlers[infoResp.Type]
}
func (w *Watcher) invokeRegistrationCallbackAtHandler(ctx context.Context, client registerapi.RegistrationClient, infoResp *registerapi.PluginInfo, socketPath string) error {
var handlerCbkFn RegisterCallbackFn
var ok bool
handlerCbkFn, ok = w.handlers[infoResp.Type]
if !ok { if !ok {
errStr := fmt.Sprintf("no handler registered for plugin type: %s at socket %s", infoResp.Type, socketPath) return w.notifyPlugin(client, false, fmt.Sprintf("no handler registered for plugin type: %s at socket %s", infoResp.Type, socketPath))
if _, err := client.NotifyRegistrationStatus(ctx, &registerapi.RegistrationStatus{
PluginRegistered: false,
Error: errStr,
}); err != nil {
return errors.Wrap(err, errStr)
}
return errors.New(errStr)
} }
var versions []string // ReRegistration: We want to handle multiple plugins registering at the same time with the same name sequentially.
for _, version := range infoResp.SupportedVersions { // See the state machine for more information.
versions = append(versions, version) // This is done by using a Lock for each plugin with the same name and type
pool := w.getPluginPool(infoResp.Type, infoResp.Name)
pool.Lock()
defer pool.Unlock()
if infoResp.Endpoint == "" {
infoResp.Endpoint = socketPath
} }
// calls handler callback to verify registration request // calls handler callback to verify registration request
chanForAckOfNotification, err := handlerCbkFn(infoResp.Name, infoResp.Endpoint, versions, socketPath) if err := handler.ValidatePlugin(infoResp.Name, infoResp.Endpoint, infoResp.SupportedVersions); err != nil {
if err != nil { return w.notifyPlugin(client, false, fmt.Sprintf("plugin validation failed with err: %v", err))
errStr := fmt.Sprintf("plugin registration failed with err: %v", err)
if _, err := client.NotifyRegistrationStatus(ctx, &registerapi.RegistrationStatus{
PluginRegistered: false,
Error: errStr,
}); err != nil {
return errors.Wrap(err, errStr)
}
return errors.New(errStr)
} }
if _, err := client.NotifyRegistrationStatus(ctx, &registerapi.RegistrationStatus{ // We add the plugin to the pluginwatcher's map before calling a plugin consumer's Register handle
PluginRegistered: true, // so that if we receive a delete event during Register Plugin, we can process it as a DeRegister call.
}); err != nil { w.registerPlugin(socketPath, infoResp.Type, infoResp.Name)
chanForAckOfNotification <- false
if err := handler.RegisterPlugin(infoResp.Name, infoResp.Endpoint); err != nil {
return w.notifyPlugin(client, false, fmt.Sprintf("plugin registration failed with err: %v", err))
}
// Notify is called after register to guarantee that even if notify throws an error Register will always be called after validate
if err := w.notifyPlugin(client, true, ""); err != nil {
return fmt.Errorf("failed to send registration status at socket %s, err: %v", socketPath, err) return fmt.Errorf("failed to send registration status at socket %s, err: %v", socketPath, err)
} }
chanForAckOfNotification <- true
return nil return nil
} }
// Handle filesystem notify event. func (w *Watcher) handleDeleteEvent(event fsnotify.Event) error {
func (w *Watcher) handleFsNotifyEvent(event fsnotify.Event) error { glog.V(6).Infof("Handling delete event: %v", event)
if event.Op&fsnotify.Create != fsnotify.Create {
plugin, ok := w.getPlugin(event.Name)
if !ok {
return fmt.Errorf("could not find plugin for deleted file %s", event.Name)
}
// You should not get a Deregister call while registering a plugin
pool := w.getPluginPool(plugin.pluginType, plugin.pluginName)
pool.Lock()
defer pool.Unlock()
// ReRegisteration: When waiting for the lock a plugin with the same name (not socketPath) could have registered
// In that case, we don't want to issue a DeRegister call for that plugin
// When ReRegistering, the new plugin will have removed the current mapping (map[socketPath] = plugin) and replaced
// it with it's own socketPath.
if _, ok = w.getPlugin(event.Name); !ok {
glog.V(2).Infof("A newer plugin watcher has been registered for plugin %v, dropping DeRegister call", plugin)
return nil return nil
} }
fi, err := os.Stat(event.Name) h, ok := w.getHandler(plugin.pluginType)
if err != nil { if !ok {
return fmt.Errorf("stat file %s failed: %v", event.Name, err) return fmt.Errorf("could not find handler %s for plugin %s at path %s", plugin.pluginType, plugin.pluginName, event.Name)
} }
if !fi.IsDir() { glog.V(2).Infof("DeRegistering plugin %v at path %s", plugin, event.Name)
return w.registerPlugin(event.Name) w.deRegisterPlugin(event.Name, plugin.pluginType, plugin.pluginName)
} h.DeRegisterPlugin(plugin.pluginName)
if err := w.traversePluginDir(event.Name); err != nil {
return fmt.Errorf("failed to traverse plugin path %s, err: %v", event.Name, err)
}
return nil return nil
} }
// Start watches for the creation of plugin sockets at the path func (w *Watcher) registerPlugin(socketPath, pluginType, pluginName string) {
func (w *Watcher) Start() error { w.mutex.Lock()
glog.V(2).Infof("Plugin Watcher Start at %s", w.path) defer w.mutex.Unlock()
w.stopCh = make(chan interface{})
// Creating the directory to be watched if it doesn't exist yet, // Reregistration case, if this plugin is already in the map, remove it
// and walks through the directory to discover the existing plugins. // This will prevent handleDeleteEvent to issue a DeRegister call
if err := w.init(); err != nil { for path, info := range w.plugins {
return err if info.pluginType != pluginType || info.pluginName != pluginName {
}
fsWatcher, err := fsnotify.NewWatcher()
if err != nil {
return fmt.Errorf("failed to start plugin fsWatcher, err: %v", err)
}
w.fsWatcher = fsWatcher
if err := w.traversePluginDir(w.path); err != nil {
fsWatcher.Close()
return fmt.Errorf("failed to traverse plugin socket path, err: %v", err)
}
w.wg.Add(1)
go func(fsWatcher *fsnotify.Watcher) {
defer w.wg.Done()
for {
select {
case event := <-fsWatcher.Events:
//TODO: Handle errors by taking corrective measures
go func() {
err := w.handleFsNotifyEvent(event)
if err != nil {
glog.Errorf("error %v when handle event: %s", err, event)
}
}()
continue continue
case err := <-fsWatcher.Errors:
if err != nil {
glog.Errorf("fsWatcher received error: %v", err)
} }
continue
case <-w.stopCh: delete(w.plugins, path)
fsWatcher.Close() break
return
} }
w.plugins[socketPath] = pathInfo{
pluginType: pluginType,
pluginName: pluginName,
} }
}(fsWatcher)
return nil
} }
// Stop stops probing the creation of plugin sockets at the path func (w *Watcher) deRegisterPlugin(socketPath, pluginType, pluginName string) {
func (w *Watcher) Stop() error { w.mutex.Lock()
close(w.stopCh) defer w.mutex.Unlock()
c := make(chan struct{})
go func() { delete(w.plugins, socketPath)
defer close(c) delete(w.pluginsPool[pluginType], pluginName)
w.wg.Wait()
}()
select {
case <-c:
case <-time.After(10 * time.Second):
return fmt.Errorf("timeout on stopping watcher")
}
return nil
} }
// Cleanup cleans the path by removing sockets func (w *Watcher) getPlugin(socketPath string) (pathInfo, bool) {
func (w *Watcher) Cleanup() error { w.mutex.Lock()
return os.RemoveAll(w.path) defer w.mutex.Unlock()
plugin, ok := w.plugins[socketPath]
return plugin, ok
}
func (w *Watcher) getPluginPool(pluginType, pluginName string) *sync.Mutex {
w.mutex.Lock()
defer w.mutex.Unlock()
if _, ok := w.pluginsPool[pluginType]; !ok {
w.pluginsPool[pluginType] = make(map[string]*sync.Mutex)
}
if _, ok := w.pluginsPool[pluginType][pluginName]; !ok {
w.pluginsPool[pluginType][pluginName] = &sync.Mutex{}
}
return w.pluginsPool[pluginType][pluginName]
}
func (w *Watcher) notifyPlugin(client registerapi.RegistrationClient, registered bool, errStr string) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
status := &registerapi.RegistrationStatus{
PluginRegistered: registered,
Error: errStr,
}
if _, err := client.NotifyRegistrationStatus(ctx, status); err != nil {
return errors.Wrap(err, errStr)
}
if errStr != "" {
return errors.New(errStr)
}
return nil
} }
// Dial establishes the gRPC communication with the picked up plugin socket. https://godoc.org/google.golang.org/grpc#Dial // Dial establishes the gRPC communication with the picked up plugin socket. https://godoc.org/google.golang.org/grpc#Dial
func dial(unixSocketPath string) (registerapi.RegistrationClient, *grpc.ClientConn, error) { func dial(unixSocketPath string, timeout time.Duration) (registerapi.RegistrationClient, *grpc.ClientConn, error) {
c, err := grpc.Dial(unixSocketPath, grpc.WithInsecure(), grpc.WithBlock(), c, err := grpc.Dial(unixSocketPath, grpc.WithInsecure(), grpc.WithBlock(),
grpc.WithTimeout(10*time.Second), grpc.WithTimeout(timeout),
grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) { grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout("unix", addr, timeout) return net.DialTimeout("unix", addr, timeout)
}), }),

View File

@ -17,192 +17,222 @@ limitations under the License.
package pluginwatcher package pluginwatcher
import ( import (
"errors" "flag"
"fmt"
"io/ioutil" "io/ioutil"
"path/filepath" "os"
"strconv"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/util/sets"
registerapi "k8s.io/kubernetes/pkg/kubelet/apis/pluginregistration/v1alpha1" registerapi "k8s.io/kubernetes/pkg/kubelet/apis/pluginregistration/v1alpha1"
) )
// helper function var (
func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { socketDir string
supportedVersions = []string{"v1beta1", "v1beta2"}
)
func init() {
var logLevel string
flag.Set("alsologtostderr", fmt.Sprintf("%t", true))
flag.StringVar(&logLevel, "logLevel", "6", "test")
flag.Lookup("v").Value.Set(logLevel)
d, err := ioutil.TempDir("", "plugin_test")
if err != nil {
panic(fmt.Sprintf("Could not create a temp directory: %s", d))
}
socketDir = d
}
func cleanup(t *testing.T) {
require.NoError(t, os.RemoveAll(socketDir))
os.MkdirAll(socketDir, 0755)
}
func TestPluginRegistration(t *testing.T) {
defer cleanup(t)
hdlr := NewExampleHandler(supportedVersions)
w := newWatcherWithHandler(t, hdlr)
defer func() { require.NoError(t, w.Stop()) }()
for i := 0; i < 10; i++ {
socketPath := fmt.Sprintf("%s/plugin-%d.sock", socketDir, i)
pluginName := fmt.Sprintf("example-plugin-%d", i)
hdlr.AddPluginName(pluginName)
p := NewTestExamplePlugin(pluginName, registerapi.DevicePlugin, socketPath, supportedVersions...)
require.NoError(t, p.Serve("v1beta1", "v1beta2"))
require.True(t, waitForEvent(t, exampleEventValidate, hdlr.EventChan(p.pluginName)))
require.True(t, waitForEvent(t, exampleEventRegister, hdlr.EventChan(p.pluginName)))
require.True(t, waitForPluginRegistrationStatus(t, p.registrationStatus))
require.NoError(t, p.Stop())
require.True(t, waitForEvent(t, exampleEventDeRegister, hdlr.EventChan(p.pluginName)))
}
}
func TestPluginReRegistration(t *testing.T) {
defer cleanup(t)
pluginName := fmt.Sprintf("example-plugin")
hdlr := NewExampleHandler(supportedVersions)
w := newWatcherWithHandler(t, hdlr)
defer func() { require.NoError(t, w.Stop()) }()
plugins := make([]*examplePlugin, 10)
for i := 0; i < 10; i++ {
socketPath := fmt.Sprintf("%s/plugin-%d.sock", socketDir, i)
hdlr.AddPluginName(pluginName)
p := NewTestExamplePlugin(pluginName, registerapi.DevicePlugin, socketPath, supportedVersions...)
require.NoError(t, p.Serve("v1beta1", "v1beta2"))
require.True(t, waitForEvent(t, exampleEventValidate, hdlr.EventChan(p.pluginName)))
require.True(t, waitForEvent(t, exampleEventRegister, hdlr.EventChan(p.pluginName)))
require.True(t, waitForPluginRegistrationStatus(t, p.registrationStatus))
plugins[i] = p
}
plugins[len(plugins)-1].Stop()
require.True(t, waitForEvent(t, exampleEventDeRegister, hdlr.EventChan(pluginName)))
close(hdlr.EventChan(pluginName))
for i := 0; i < len(plugins)-1; i++ {
plugins[i].Stop()
}
}
func TestPluginRegistrationAtKubeletStart(t *testing.T) {
defer cleanup(t)
hdlr := NewExampleHandler(supportedVersions)
plugins := make([]*examplePlugin, 10)
for i := 0; i < len(plugins); i++ {
socketPath := fmt.Sprintf("%s/plugin-%d.sock", socketDir, i)
pluginName := fmt.Sprintf("example-plugin-%d", i)
hdlr.AddPluginName(pluginName)
p := NewTestExamplePlugin(pluginName, registerapi.DevicePlugin, socketPath, supportedVersions...)
require.NoError(t, p.Serve("v1beta1", "v1beta2"))
defer func(p *examplePlugin) { require.NoError(t, p.Stop()) }(p)
plugins[i] = p
}
w := newWatcherWithHandler(t, hdlr)
defer func() { require.NoError(t, w.Stop()) }()
var wg sync.WaitGroup
for i := 0; i < len(plugins); i++ {
wg.Add(1)
go func(p *examplePlugin) {
defer wg.Done()
require.True(t, waitForEvent(t, exampleEventValidate, hdlr.EventChan(p.pluginName)))
require.True(t, waitForEvent(t, exampleEventRegister, hdlr.EventChan(p.pluginName)))
require.True(t, waitForPluginRegistrationStatus(t, p.registrationStatus))
}(plugins[i])
}
c := make(chan struct{}) c := make(chan struct{})
go func() { go func() {
defer close(c) defer close(c)
wg.Wait() wg.Wait()
}() }()
select { select {
case <-c: case <-c:
return false // completed normally return
case <-time.After(timeout):
return true // timed out
}
}
func TestExamplePlugin(t *testing.T) {
rootDir, err := ioutil.TempDir("", "plugin_test")
require.NoError(t, err)
w := NewWatcher(rootDir)
h := NewExampleHandler()
w.AddHandler(registerapi.DevicePlugin, h.Handler)
require.NoError(t, w.Start())
socketPath := filepath.Join(rootDir, "plugin.sock")
PluginName := "example-plugin"
// handler expecting plugin has a non-empty endpoint
p := NewTestExamplePlugin(PluginName, registerapi.DevicePlugin, "")
require.NoError(t, p.Serve(socketPath))
require.False(t, waitForPluginRegistrationStatus(t, p.registrationStatus))
require.NoError(t, p.Stop())
p = NewTestExamplePlugin(PluginName, registerapi.DevicePlugin, "dummyEndpoint")
require.NoError(t, p.Serve(socketPath))
require.True(t, waitForPluginRegistrationStatus(t, p.registrationStatus))
// Trying to start a plugin service at the same socket path should fail
// with "bind: address already in use"
require.NotNil(t, p.Serve(socketPath))
// grpcServer.Stop() will remove the socket and starting plugin service
// at the same path again should succeeds and trigger another callback.
require.NoError(t, p.Stop())
require.Nil(t, p.Serve(socketPath))
require.False(t, waitForPluginRegistrationStatus(t, p.registrationStatus))
// Starting another plugin with the same name got verification error.
p2 := NewTestExamplePlugin(PluginName, registerapi.DevicePlugin, "dummyEndpoint")
socketPath2 := filepath.Join(rootDir, "plugin2.sock")
require.NoError(t, p2.Serve(socketPath2))
require.False(t, waitForPluginRegistrationStatus(t, p2.registrationStatus))
// Restarts plugin watcher should traverse the socket directory and issues a
// callback for every existing socket.
require.NoError(t, w.Stop())
require.NoError(t, h.Cleanup())
require.NoError(t, w.Start())
var wg sync.WaitGroup
wg.Add(2)
var pStatus string
var p2Status string
go func() {
pStatus = strconv.FormatBool(waitForPluginRegistrationStatus(t, p.registrationStatus))
wg.Done()
}()
go func() {
p2Status = strconv.FormatBool(waitForPluginRegistrationStatus(t, p2.registrationStatus))
wg.Done()
}()
if waitTimeout(&wg, 2*time.Second) {
t.Fatalf("Timed out waiting for wait group")
}
expectedSet := sets.NewString()
expectedSet.Insert("true", "false")
actualSet := sets.NewString()
actualSet.Insert(pStatus, p2Status)
require.Equal(t, expectedSet, actualSet)
select {
case err := <-h.chanForHandlerAckErrors:
t.Fatalf("%v", err)
case <-time.After(2 * time.Second): case <-time.After(2 * time.Second):
t.Fatalf("Timeout while waiting for the plugin registration status")
} }
require.NoError(t, w.Stop())
require.NoError(t, w.Cleanup())
} }
func TestPluginWithSubDir(t *testing.T) { func TestPluginRegistrationFailureWithUnsupportedVersion(t *testing.T) {
rootDir, err := ioutil.TempDir("", "plugin_test") defer cleanup(t)
require.NoError(t, err)
w := NewWatcher(rootDir) pluginName := fmt.Sprintf("example-plugin")
hcsi := NewExampleHandler() socketPath := socketDir + "/plugin.sock"
hdp := NewExampleHandler()
w.AddHandler(registerapi.CSIPlugin, hcsi.Handler) hdlr := NewExampleHandler(supportedVersions)
w.AddHandler(registerapi.DevicePlugin, hdp.Handler) hdlr.AddPluginName(pluginName)
err = w.fs.MkdirAll(filepath.Join(rootDir, registerapi.DevicePlugin), 0755) w := newWatcherWithHandler(t, hdlr)
require.NoError(t, err) defer func() { require.NoError(t, w.Stop()) }()
err = w.fs.MkdirAll(filepath.Join(rootDir, registerapi.CSIPlugin), 0755)
require.NoError(t, err)
dpSocketPath := filepath.Join(rootDir, registerapi.DevicePlugin, "plugin.sock") // Advertise v1beta3 but don't serve anything else than the plugin service
csiSocketPath := filepath.Join(rootDir, registerapi.CSIPlugin, "plugin.sock") p := NewTestExamplePlugin(pluginName, registerapi.DevicePlugin, socketPath, "v1beta3")
require.NoError(t, p.Serve())
defer func() { require.NoError(t, p.Stop()) }()
require.NoError(t, w.Start()) require.True(t, waitForEvent(t, exampleEventValidate, hdlr.EventChan(p.pluginName)))
require.False(t, waitForPluginRegistrationStatus(t, p.registrationStatus))
// two plugins using the same name but with different type
dp := NewTestExamplePlugin("exampleplugin", registerapi.DevicePlugin, "example-endpoint")
require.NoError(t, dp.Serve(dpSocketPath))
require.True(t, waitForPluginRegistrationStatus(t, dp.registrationStatus))
csi := NewTestExamplePlugin("exampleplugin", registerapi.CSIPlugin, "example-endpoint")
require.NoError(t, csi.Serve(csiSocketPath))
require.True(t, waitForPluginRegistrationStatus(t, csi.registrationStatus))
// Restarts plugin watcher should traverse the socket directory and issues a
// callback for every existing socket.
require.NoError(t, w.Stop())
require.NoError(t, hcsi.Cleanup())
require.NoError(t, hdp.Cleanup())
require.NoError(t, w.Start())
var wg sync.WaitGroup
wg.Add(2)
var dpStatus string
var csiStatus string
go func() {
dpStatus = strconv.FormatBool(waitForPluginRegistrationStatus(t, dp.registrationStatus))
wg.Done()
}()
go func() {
csiStatus = strconv.FormatBool(waitForPluginRegistrationStatus(t, csi.registrationStatus))
wg.Done()
}()
if waitTimeout(&wg, 4*time.Second) {
require.NoError(t, errors.New("Timed out waiting for wait group"))
}
expectedSet := sets.NewString()
expectedSet.Insert("true", "true")
actualSet := sets.NewString()
actualSet.Insert(dpStatus, csiStatus)
require.Equal(t, expectedSet, actualSet)
select {
case err := <-hcsi.chanForHandlerAckErrors:
t.Fatalf("%v", err)
case err := <-hdp.chanForHandlerAckErrors:
t.Fatalf("%v", err)
case <-time.After(4 * time.Second):
}
require.NoError(t, w.Stop())
require.NoError(t, w.Cleanup())
} }
func waitForPluginRegistrationStatus(t *testing.T, statusCh chan registerapi.RegistrationStatus) bool { func TestPlugiRegistrationFailureWithUnsupportedVersionAtKubeletStart(t *testing.T) {
defer cleanup(t)
pluginName := fmt.Sprintf("example-plugin")
socketPath := socketDir + "/plugin.sock"
// Advertise v1beta3 but don't serve anything else than the plugin service
p := NewTestExamplePlugin(pluginName, registerapi.DevicePlugin, socketPath, "v1beta3")
require.NoError(t, p.Serve())
defer func() { require.NoError(t, p.Stop()) }()
hdlr := NewExampleHandler(supportedVersions)
hdlr.AddPluginName(pluginName)
w := newWatcherWithHandler(t, hdlr)
defer func() { require.NoError(t, w.Stop()) }()
require.True(t, waitForEvent(t, exampleEventValidate, hdlr.EventChan(p.pluginName)))
require.False(t, waitForPluginRegistrationStatus(t, p.registrationStatus))
}
func waitForPluginRegistrationStatus(t *testing.T, statusChan chan registerapi.RegistrationStatus) bool {
select { select {
case status := <-statusCh: case status := <-statusChan:
return status.PluginRegistered return status.PluginRegistered
case <-time.After(10 * time.Second): case <-time.After(10 * time.Second):
t.Fatalf("Timed out while waiting for registration status") t.Fatalf("Timed out while waiting for registration status")
} }
return false return false
} }
// waitForEvent blocks until the handler emits the next event on eventChan and
// reports whether it matches expected. Fails the test if no event arrives
// within two seconds.
func waitForEvent(t *testing.T, expected examplePluginEvent, eventChan chan examplePluginEvent) bool {
	select {
	case event := <-eventChan:
		return event == expected
	case <-time.After(2 * time.Second):
		// Fixed message: this helper waits for a plugin event, not for a
		// registration status (previous text was a copy-paste).
		t.Fatalf("Timed out while waiting for event %v", expected)
	}
	// Unreachable after Fatalf; kept to satisfy the compiler.
	return false
}
// newWatcherWithHandler builds a Watcher over socketDir, registers hdlr for
// device plugins, starts it (failing the test on error) and returns it.
func newWatcherWithHandler(t *testing.T, hdlr PluginHandler) *Watcher {
	watcher := NewWatcher(socketDir)
	watcher.AddHandler(registerapi.DevicePlugin, hdlr)
	require.NoError(t, watcher.Start())
	return watcher
}

View File

@ -0,0 +1,59 @@
/*
Copyright 2018 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package pluginwatcher
// PluginHandler is an interface a client of the pluginwatcher API needs to implement in
// order to consume plugins
// The PluginHandler follows the following simple state machine:
//
// +--------------------------------------+
// | ReRegistration |
// | Socket created with same plugin name |
// | |
// | |
// Socket Created v + Socket Deleted
// +------------------> Validate +---------------------------> Register +------------------> DeRegister
// + + +
// | | |
// | Error | Error |
// | | |
// v v v
// Out Out Out
//
// The pluginwatcher module follows strictly and sequentially this state machine for each *plugin name*.
// e.g: If you are Registering a plugin foo, you cannot get a DeRegister call for plugin foo
// until the Register("foo") call returns. Nor will you get a Validate("foo", "Different endpoint", ...)
// call until the Register("foo") call returns.
//
// ReRegistration: Socket created with same plugin name, usually for a plugin update
// e.g: plugin with name foo registers at foo.com/foo-1.9.7 later a plugin with name foo
// registers at foo.com/foo-1.9.9
//
// DeRegistration: When ReRegistration happens only the deletion of the new socket will trigger a DeRegister call
type PluginHandler interface {
	// ValidatePlugin returns an error if the information provided by
	// the potential plugin is erroneous (unsupported version, ...).
	// Returning an error aborts the registration of this plugin.
	ValidatePlugin(pluginName string, endpoint string, versions []string) error
	// RegisterPlugin is called so that the plugin can be registered by any
	// plugin consumer.
	// Errors encountered here can still be notified to the plugin.
	RegisterPlugin(pluginName, endpoint string) error
	// DeRegisterPlugin is called once the pluginwatcher observes that the
	// socket has been deleted.
	DeRegisterPlugin(pluginName string)
}

View File

@ -28,6 +28,7 @@ import (
"context" "context"
"github.com/golang/glog" "github.com/golang/glog"
api "k8s.io/api/core/v1" api "k8s.io/api/core/v1"
apierrs "k8s.io/apimachinery/pkg/api/errors" apierrs "k8s.io/apimachinery/pkg/api/errors"
meta "k8s.io/apimachinery/pkg/apis/meta/v1" meta "k8s.io/apimachinery/pkg/apis/meta/v1"
@ -89,6 +90,10 @@ type csiDriversStore struct {
sync.RWMutex sync.RWMutex
} }
// RegistrationHandler is the handler which is fed to the pluginwatcher API.
// NOTE(review): it carries no state of its own; registered drivers appear to
// be tracked in the package-level csiDrivers map — confirm before adding
// fields here.
type RegistrationHandler struct {
}
// TODO (verult) consider using a struct instead of global variables // TODO (verult) consider using a struct instead of global variables
// csiDrivers map keep track of all registered CSI drivers on the node and their // csiDrivers map keep track of all registered CSI drivers on the node and their
// corresponding sockets // corresponding sockets
@ -96,21 +101,28 @@ var csiDrivers csiDriversStore
var nodeUpdater nodeupdater.Interface var nodeUpdater nodeupdater.Interface
// RegistrationCallback is called by kubelet's plugin watcher upon detection // PluginHandler is the plugin registration handler interface passed to the
// pluginwatcher module in kubelet
var PluginHandler = &RegistrationHandler{}
// ValidatePlugin is called by kubelet's plugin watcher upon detection
// of a new registration socket opened by CSI Driver registrar side car. // of a new registration socket opened by CSI Driver registrar side car.
func RegistrationCallback(pluginName string, endpoint string, versions []string, socketPath string) (chan bool, error) { func (h *RegistrationHandler) ValidatePlugin(pluginName string, endpoint string, versions []string) error {
glog.Infof(log("Trying to register a new plugin with name: %s endpoint: %s versions: %s",
pluginName, endpoint, strings.Join(versions, ",")))
glog.Infof(log("Callback from kubelet with plugin name: %s endpoint: %s versions: %s socket path: %s", return nil
pluginName, endpoint, strings.Join(versions, ","), socketPath)) }
if endpoint == "" { // RegisterPlugin is called when a plugin can be registered
endpoint = socketPath func (h *RegistrationHandler) RegisterPlugin(pluginName string, endpoint string) error {
} glog.Infof(log("Register new plugin with name: %s at endpoint: %s", pluginName, endpoint))
// Storing endpoint of newly registered CSI driver into the map, where CSI driver name will be the key // Storing endpoint of newly registered CSI driver into the map, where CSI driver name will be the key
// all other CSI components will be able to get the actual socket of CSI drivers by its name. // all other CSI components will be able to get the actual socket of CSI drivers by its name.
csiDrivers.Lock() csiDrivers.Lock()
defer csiDrivers.Unlock() defer csiDrivers.Unlock()
csiDrivers.driversMap[pluginName] = csiDriver{driverName: pluginName, driverEndpoint: endpoint} csiDrivers.driversMap[pluginName] = csiDriver{driverName: pluginName, driverEndpoint: endpoint}
// Get node info from the driver. // Get node info from the driver.
@ -118,22 +130,27 @@ func RegistrationCallback(pluginName string, endpoint string, versions []string,
// TODO (verult) retry with exponential backoff, possibly added in csi client library. // TODO (verult) retry with exponential backoff, possibly added in csi client library.
ctx, cancel := context.WithTimeout(context.Background(), csiTimeout) ctx, cancel := context.WithTimeout(context.Background(), csiTimeout)
defer cancel() defer cancel()
driverNodeID, maxVolumePerNode, _, err := csi.NodeGetInfo(ctx) driverNodeID, maxVolumePerNode, _, err := csi.NodeGetInfo(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("error during CSI NodeGetInfo() call: %v", err) return fmt.Errorf("error during CSI NodeGetInfo() call: %v", err)
} }
// Calling nodeLabelManager to update annotations and labels for newly registered CSI driver // Calling nodeLabelManager to update annotations and labels for newly registered CSI driver
err = nodeUpdater.AddLabelsAndLimits(pluginName, driverNodeID, maxVolumePerNode) err = nodeUpdater.AddLabelsAndLimits(pluginName, driverNodeID, maxVolumePerNode)
if err != nil { if err != nil {
// Unregister the driver and return error // Unregister the driver and return error
csiDrivers.Lock()
defer csiDrivers.Unlock()
delete(csiDrivers.driversMap, pluginName) delete(csiDrivers.driversMap, pluginName)
return nil, err return fmt.Errorf("error while adding CSI labels: %v", err)
} }
return nil, nil return nil
}
// DeRegisterPlugin is called when a plugin removed it's socket, signaling
// it is no longer available
// TODO: Handle DeRegistration
func (h *RegistrationHandler) DeRegisterPlugin(pluginName string) {
} }
func (p *csiPlugin) Init(host volume.VolumeHost) error { func (p *csiPlugin) Init(host volume.VolumeHost) error {