Merge pull request #64621 from RenaudWasTaken/pluginwatcher

Automatic merge from submit-queue (batch tested with PRs 68087, 68256, 64621, 68299, 68296). If you want to cherry-pick this change to another branch, please follow the instructions here: https://github.com/kubernetes/community/blob/master/contributors/devel/cherry-picks.md.

Change plugin watcher registration mechanism

**Which issue(s) this PR fixes**: #64637

**Notes For Reviewers**:
The current API the plugin watcher exposes to kubelet is the following:
```golang
type RegisterCallbackFn func(pluginName string, endpoint string,
                             versions []string, socketPath string) (error, chan bool)	
```

The callback channel is here to signal the plugin watcher consumer when the plugin watcher API has notified the plugin of it's successful registration.
In other words the current lifecycle of a plugin is the following:
```
(pluginwatcher) GetInfo -> (pluginwatcher) NotifyRegistrationStatus -> (deviceplugin) ListWatch
```
Rather than
```
(pluginwatcher) GetInfo (race) -> (pluginwatcher) NotifyRegistrationStatus
                        (race) -> (deviceplugin) ListWatch
```

This PR changes the callback/channel mechanism to a more explicit, interfaced based contract (and more maintainable than a function to which we add more channels for more lifecycle events).

This PR also introduces three new states: {Init, Register, DeRegister}
```golang
// PluginHandler is an interface a client of the pluginwatcher API needs to implement in
// order to consume plugins
// The PluginHandler follows the simple following state machine:
//
//                         +--------------------------------------+
//                         |            ReRegistration            |
//                         | Socket created with same plugin name |
//                         |                                      |
//                         |                                      |
//    Socket Created       v                                      +        Socket Deleted
// +------------------> Validate +----------> Init +---------> Register +------------------> DeRegister
//                         +                   +                                                +
//                         |                   |                                                |
//                         | Error             | Error                                          |
//                         |                   |                                                |
//                         v                   v                                                v
//                        Out                 Out                                              Out
//
// The pluginwatcher module follows strictly and sequentially this state machine for each *plugin name*.
// e.g: If you are Registering a plugin foo, you cannot get a DeRegister call for plugin foo
//      until the Register("foo") call returns. Nor will you get a Validate("foo", "Different endpoint", ...)
//      call until the Register("foo") call returns.
//
// ReRegistration: Socket created with same plugin name, usually for a plugin update
// e.g: plugin with name foo registers at foo.com/foo-1.9.7 later a plugin with name foo
//      registers at foo.com/foo-1.9.9
//
// DeRegistration: When ReRegistration happens only the deletion of the new socket will trigger a DeRegister call

type PluginHandler interface {
        // Validate returns an error if the information provided by
        // the potential plugin is erroneous (unsupported version, ...)
        ValidatePlugin(pluginName string, endpoint string, versions []string) error
        // Init starts the plugin (e.g: contact the gRPC client, gets plugin
        // specific information, ...) but if another plugin with the same name
        // exists does not switch to the newer one.
        // Any error encountered here can still be Notified to the plugin.
        InitPlugin(pluginName string, endpoint string) error
        // Register is called once the pluginwatcher has notified the plugin
        // of its successful registration.
        // Errors at this point can no longer be bubbled up to the plugin
        RegisterPlugin(pluginName, endpoint string)
        // DeRegister is called once the pluginwatcher observes that the socket has
        // been deleted.
        DeRegisterPlugin(pluginName string)
}
```

```release-note
NONE
```
/sig node
/area hw-accelerators

/cc @jiayingz @vikaschoudhary16 @vishh @vladimirvivien @sbezverk @figo (ccing the main reviewers of the original PR, feel free to cc more people)
This commit is contained in:
Kubernetes Submit Queue 2018-09-06 14:49:39 -07:00 committed by GitHub
commit 4da3bdc4eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 809 additions and 501 deletions

View File

@ -95,7 +95,11 @@ type ContainerManager interface {
// GetPodCgroupRoot returns the cgroup which contains all pods.
GetPodCgroupRoot() string
GetPluginRegistrationHandlerCallback() pluginwatcher.RegisterCallbackFn
// GetPluginRegistrationHandler returns a plugin registration handler
// The pluginwatcher's Handlers allow to have a single module for handling
// registration.
GetPluginRegistrationHandler() pluginwatcher.PluginHandler
}
type NodeConfig struct {

View File

@ -605,8 +605,8 @@ func (cm *containerManagerImpl) Start(node *v1.Node,
return nil
}
func (cm *containerManagerImpl) GetPluginRegistrationHandlerCallback() pluginwatcher.RegisterCallbackFn {
return cm.deviceManager.GetWatcherCallback()
func (cm *containerManagerImpl) GetPluginRegistrationHandler() pluginwatcher.PluginHandler {
return cm.deviceManager.GetWatcherHandler()
}
// TODO: move the GetResources logic to PodContainerManager.

View File

@ -77,10 +77,8 @@ func (cm *containerManagerStub) GetCapacity() v1.ResourceList {
return c
}
func (cm *containerManagerStub) GetPluginRegistrationHandlerCallback() pluginwatcher.RegisterCallbackFn {
return func(name string, endpoint string, versions []string, sockPath string) (chan bool, error) {
return nil, nil
}
func (cm *containerManagerStub) GetPluginRegistrationHandler() pluginwatcher.PluginHandler {
return nil
}
func (cm *containerManagerStub) GetDevicePluginResourceCapacity() (v1.ResourceList, v1.ResourceList, []string) {

View File

@ -56,7 +56,7 @@ type ManagerImpl struct {
socketname string
socketdir string
endpoints map[string]endpoint // Key is ResourceName
endpoints map[string]endpointInfo // Key is ResourceName
mutex sync.Mutex
server *grpc.Server
@ -86,10 +86,14 @@ type ManagerImpl struct {
// podDevices contains pod to allocated device mapping.
podDevices podDevices
pluginOpts map[string]*pluginapi.DevicePluginOptions
checkpointManager checkpointmanager.CheckpointManager
}
type endpointInfo struct {
e endpoint
opts *pluginapi.DevicePluginOptions
}
type sourcesReadyStub struct{}
func (s *sourcesReadyStub) AddSource(source string) {}
@ -109,13 +113,13 @@ func newManagerImpl(socketPath string) (*ManagerImpl, error) {
dir, file := filepath.Split(socketPath)
manager := &ManagerImpl{
endpoints: make(map[string]endpoint),
endpoints: make(map[string]endpointInfo),
socketname: file,
socketdir: dir,
healthyDevices: make(map[string]sets.String),
unhealthyDevices: make(map[string]sets.String),
allocatedDevices: make(map[string]sets.String),
pluginOpts: make(map[string]*pluginapi.DevicePluginOptions),
podDevices: make(podDevices),
}
manager.callback = manager.genericDeviceUpdateCallback
@ -228,8 +232,8 @@ func (m *ManagerImpl) Start(activePods ActivePodsFunc, sourcesReady config.Sourc
return nil
}
// GetWatcherCallback returns callback function to be registered with plugin watcher
func (m *ManagerImpl) GetWatcherCallback() watcher.RegisterCallbackFn {
// GetWatcherHandler returns the plugin handler
func (m *ManagerImpl) GetWatcherHandler() watcher.PluginHandler {
if f, err := os.Create(m.socketdir + "DEPRECATION"); err != nil {
glog.Errorf("Failed to create deprecation file at %s", m.socketdir)
} else {
@ -237,16 +241,57 @@ func (m *ManagerImpl) GetWatcherCallback() watcher.RegisterCallbackFn {
glog.V(4).Infof("created deprecation file %s", f.Name())
}
return func(name string, endpoint string, versions []string, sockPath string) (chan bool, error) {
if !m.isVersionCompatibleWithPlugin(versions) {
return nil, fmt.Errorf("manager version, %s, is not among plugin supported versions %v", pluginapi.Version, versions)
}
return watcher.PluginHandler(m)
}
if !v1helper.IsExtendedResourceName(v1.ResourceName(name)) {
return nil, fmt.Errorf("invalid name of device plugin socket: %s", fmt.Sprintf(errInvalidResourceName, name))
}
// ValidatePlugin validates a plugin if the version is correct and the name has the format of an extended resource
func (m *ManagerImpl) ValidatePlugin(pluginName string, endpoint string, versions []string) error {
glog.V(2).Infof("Got Plugin %s at endpoint %s with versions %v", pluginName, endpoint, versions)
return m.addEndpointProbeMode(name, sockPath)
if !m.isVersionCompatibleWithPlugin(versions) {
return fmt.Errorf("manager version, %s, is not among plugin supported versions %v", pluginapi.Version, versions)
}
if !v1helper.IsExtendedResourceName(v1.ResourceName(pluginName)) {
return fmt.Errorf("invalid name of device plugin socket: %s", fmt.Sprintf(errInvalidResourceName, pluginName))
}
return nil
}
// RegisterPlugin starts the endpoint and registers it
// TODO: Start the endpoint and wait for the First ListAndWatch call
// before registering the plugin
func (m *ManagerImpl) RegisterPlugin(pluginName string, endpoint string) error {
glog.V(2).Infof("Registering Plugin %s at endpoint %s", pluginName, endpoint)
e, err := newEndpointImpl(endpoint, pluginName, m.callback)
if err != nil {
return fmt.Errorf("Failed to dial device plugin with socketPath %s: %v", endpoint, err)
}
options, err := e.client.GetDevicePluginOptions(context.Background(), &pluginapi.Empty{})
if err != nil {
return fmt.Errorf("Failed to get device plugin options: %v", err)
}
m.registerEndpoint(pluginName, options, e)
go m.runEndpoint(pluginName, e)
return nil
}
// DeRegisterPlugin deregisters the plugin
// TODO work on the behavior for deregistering plugins
// e.g: Should we delete the resource
func (m *ManagerImpl) DeRegisterPlugin(pluginName string) {
m.mutex.Lock()
defer m.mutex.Unlock()
// Note: This will mark the resource unhealthy as per the behavior
// in runEndpoint
if eI, ok := m.endpoints[pluginName]; ok {
eI.e.stop()
}
}
@ -333,8 +378,8 @@ func (m *ManagerImpl) Register(ctx context.Context, r *pluginapi.RegisterRequest
func (m *ManagerImpl) Stop() error {
m.mutex.Lock()
defer m.mutex.Unlock()
for _, e := range m.endpoints {
e.stop()
for _, eI := range m.endpoints {
eI.e.stop()
}
if m.server == nil {
@ -346,51 +391,26 @@ func (m *ManagerImpl) Stop() error {
return nil
}
func (m *ManagerImpl) addEndpointProbeMode(resourceName string, socketPath string) (chan bool, error) {
chanForAckOfNotification := make(chan bool)
new, err := newEndpointImpl(socketPath, resourceName, m.callback)
if err != nil {
glog.Errorf("Failed to dial device plugin with socketPath %s: %v", socketPath, err)
return nil, fmt.Errorf("Failed to dial device plugin with socketPath %s: %v", socketPath, err)
}
options, err := new.client.GetDevicePluginOptions(context.Background(), &pluginapi.Empty{})
if err != nil {
glog.Errorf("Failed to get device plugin options: %v", err)
return nil, fmt.Errorf("Failed to get device plugin options: %v", err)
}
m.registerEndpoint(resourceName, options, new)
go func() {
select {
case <-chanForAckOfNotification:
close(chanForAckOfNotification)
m.runEndpoint(resourceName, new)
case <-time.After(time.Second):
glog.Errorf("Timed out while waiting for notification ack from plugin")
}
}()
return chanForAckOfNotification, nil
}
func (m *ManagerImpl) registerEndpoint(resourceName string, options *pluginapi.DevicePluginOptions, e *endpointImpl) {
func (m *ManagerImpl) registerEndpoint(resourceName string, options *pluginapi.DevicePluginOptions, e endpoint) {
m.mutex.Lock()
defer m.mutex.Unlock()
m.pluginOpts[resourceName] = options
m.endpoints[resourceName] = e
m.endpoints[resourceName] = endpointInfo{e: e, opts: options}
glog.V(2).Infof("Registered endpoint %v", e)
}
func (m *ManagerImpl) runEndpoint(resourceName string, e *endpointImpl) {
func (m *ManagerImpl) runEndpoint(resourceName string, e endpoint) {
e.run()
e.stop()
m.mutex.Lock()
defer m.mutex.Unlock()
if old, ok := m.endpoints[resourceName]; ok && old == e {
if old, ok := m.endpoints[resourceName]; ok && old.e == e {
m.markResourceUnhealthy(resourceName)
}
glog.V(2).Infof("Unregistered endpoint %v", e)
glog.V(2).Infof("Endpoint (%s, %v) became unhealthy", resourceName, e)
}
func (m *ManagerImpl) addEndpoint(r *pluginapi.RegisterRequest) {
@ -437,8 +457,8 @@ func (m *ManagerImpl) GetCapacity() (v1.ResourceList, v1.ResourceList, []string)
deletedResources := sets.NewString()
m.mutex.Lock()
for resourceName, devices := range m.healthyDevices {
e, ok := m.endpoints[resourceName]
if (ok && e.stopGracePeriodExpired()) || !ok {
eI, ok := m.endpoints[resourceName]
if (ok && eI.e.stopGracePeriodExpired()) || !ok {
// The resources contained in endpoints and (un)healthyDevices
// should always be consistent. Otherwise, we run with the risk
// of failing to garbage collect non-existing resources or devices.
@ -455,8 +475,8 @@ func (m *ManagerImpl) GetCapacity() (v1.ResourceList, v1.ResourceList, []string)
}
}
for resourceName, devices := range m.unhealthyDevices {
e, ok := m.endpoints[resourceName]
if (ok && e.stopGracePeriodExpired()) || !ok {
eI, ok := m.endpoints[resourceName]
if (ok && eI.e.stopGracePeriodExpired()) || !ok {
if !ok {
glog.Errorf("unexpected: unhealthyDevices and endpoints are out of sync")
}
@ -519,7 +539,7 @@ func (m *ManagerImpl) readCheckpoint() error {
// will stay zero till the corresponding device plugin re-registers.
m.healthyDevices[resource] = sets.NewString()
m.unhealthyDevices[resource] = sets.NewString()
m.endpoints[resource] = newStoppedEndpointImpl(resource)
m.endpoints[resource] = endpointInfo{e: newStoppedEndpointImpl(resource), opts: nil}
}
return nil
}
@ -652,7 +672,7 @@ func (m *ManagerImpl) allocateContainerResources(pod *v1.Pod, container *v1.Cont
// plugin Allocate grpc calls if it becomes common that a container may require
// resources from multiple device plugins.
m.mutex.Lock()
e, ok := m.endpoints[resource]
eI, ok := m.endpoints[resource]
m.mutex.Unlock()
if !ok {
m.mutex.Lock()
@ -665,7 +685,7 @@ func (m *ManagerImpl) allocateContainerResources(pod *v1.Pod, container *v1.Cont
// TODO: refactor this part of code to just append a ContainerAllocationRequest
// in a passed in AllocateRequest pointer, and issues a single Allocate call per pod.
glog.V(3).Infof("Making allocation request for devices %v for device plugin %s", devs, resource)
resp, err := e.allocate(devs)
resp, err := eI.e.allocate(devs)
metrics.DevicePluginAllocationLatency.WithLabelValues(resource).Observe(metrics.SinceInMicroseconds(startRPCTime))
if err != nil {
// In case of allocation failure, we want to restore m.allocatedDevices
@ -715,11 +735,13 @@ func (m *ManagerImpl) GetDeviceRunContainerOptions(pod *v1.Pod, container *v1.Co
// with PreStartRequired option set.
func (m *ManagerImpl) callPreStartContainerIfNeeded(podUID, contName, resource string) error {
m.mutex.Lock()
opts, ok := m.pluginOpts[resource]
eI, ok := m.endpoints[resource]
if !ok {
m.mutex.Unlock()
return fmt.Errorf("Plugin options not found in cache for resource: %s", resource)
} else if opts == nil || !opts.PreStartRequired {
return fmt.Errorf("endpoint not found in cache for a registered resource: %s", resource)
}
if eI.opts == nil || !eI.opts.PreStartRequired {
m.mutex.Unlock()
glog.V(4).Infof("Plugin options indicate to skip PreStartContainer for resource: %s", resource)
return nil
@ -731,16 +753,10 @@ func (m *ManagerImpl) callPreStartContainerIfNeeded(podUID, contName, resource s
return fmt.Errorf("no devices found allocated in local cache for pod %s, container %s, resource %s", podUID, contName, resource)
}
e, ok := m.endpoints[resource]
if !ok {
m.mutex.Unlock()
return fmt.Errorf("endpoint not found in cache for a registered resource: %s", resource)
}
m.mutex.Unlock()
devs := devices.UnsortedList()
glog.V(4).Infof("Issuing an PreStartContainer call for container, %s, of pod %s", contName, podUID)
_, err := e.preStartContainer(devs)
_, err := eI.e.preStartContainer(devs)
if err != nil {
return fmt.Errorf("device plugin PreStartContainer rpc failed with err: %v", err)
}

View File

@ -57,9 +57,7 @@ func (h *ManagerStub) GetCapacity() (v1.ResourceList, v1.ResourceList, []string)
return nil, nil, []string{}
}
// GetWatcherCallback returns plugin watcher callback
func (h *ManagerStub) GetWatcherCallback() pluginwatcher.RegisterCallbackFn {
return func(name string, endpoint string, versions []string, sockPath string) (chan bool, error) {
return nil, nil
}
// GetWatcherHandler returns plugin watcher interface
func (h *ManagerStub) GetWatcherHandler() pluginwatcher.PluginHandler {
return nil
}

View File

@ -249,9 +249,10 @@ func setupDevicePlugin(t *testing.T, devs []*pluginapi.Device, pluginSocketName
func setupPluginWatcher(pluginSocketName string, m Manager) *pluginwatcher.Watcher {
w := pluginwatcher.NewWatcher(filepath.Dir(pluginSocketName))
w.AddHandler(watcherapi.DevicePlugin, m.GetWatcherCallback())
w.AddHandler(watcherapi.DevicePlugin, m.GetWatcherHandler())
w.Start()
return &w
return w
}
func setup(t *testing.T, devs []*pluginapi.Device, callback monitorCallback, socketName string, pluginSocketName string) (Manager, <-chan interface{}, *Stub) {
@ -295,7 +296,7 @@ func TestUpdateCapacityAllocatable(t *testing.T) {
// Expects capacity for resource1 to be 2.
resourceName1 := "domain1.com/resource1"
e1 := &endpointImpl{}
testManager.endpoints[resourceName1] = e1
testManager.endpoints[resourceName1] = endpointInfo{e: e1, opts: nil}
callback(resourceName1, devs)
capacity, allocatable, removedResources := testManager.GetCapacity()
resource1Capacity, ok := capacity[v1.ResourceName(resourceName1)]
@ -345,7 +346,7 @@ func TestUpdateCapacityAllocatable(t *testing.T) {
// Tests adding another resource.
resourceName2 := "resource2"
e2 := &endpointImpl{}
testManager.endpoints[resourceName2] = e2
testManager.endpoints[resourceName2] = endpointInfo{e: e2, opts: nil}
callback(resourceName2, devs)
capacity, allocatable, removedResources = testManager.GetCapacity()
as.Equal(2, len(capacity))
@ -456,7 +457,7 @@ func TestCheckpoint(t *testing.T) {
ckm, err := checkpointmanager.NewCheckpointManager(tmpDir)
as.Nil(err)
testManager := &ManagerImpl{
endpoints: make(map[string]endpoint),
endpoints: make(map[string]endpointInfo),
healthyDevices: make(map[string]sets.String),
unhealthyDevices: make(map[string]sets.String),
allocatedDevices: make(map[string]sets.String),
@ -577,7 +578,7 @@ func makePod(limits v1.ResourceList) *v1.Pod {
}
}
func getTestManager(tmpDir string, activePods ActivePodsFunc, testRes []TestResource, opts map[string]*pluginapi.DevicePluginOptions) (*ManagerImpl, error) {
func getTestManager(tmpDir string, activePods ActivePodsFunc, testRes []TestResource) (*ManagerImpl, error) {
monitorCallback := func(resourceName string, devices []pluginapi.Device) {}
ckm, err := checkpointmanager.NewCheckpointManager(tmpDir)
if err != nil {
@ -589,41 +590,45 @@ func getTestManager(tmpDir string, activePods ActivePodsFunc, testRes []TestReso
healthyDevices: make(map[string]sets.String),
unhealthyDevices: make(map[string]sets.String),
allocatedDevices: make(map[string]sets.String),
endpoints: make(map[string]endpoint),
pluginOpts: opts,
endpoints: make(map[string]endpointInfo),
podDevices: make(podDevices),
activePods: activePods,
sourcesReady: &sourcesReadyStub{},
checkpointManager: ckm,
}
for _, res := range testRes {
testManager.healthyDevices[res.resourceName] = sets.NewString()
for _, dev := range res.devs {
testManager.healthyDevices[res.resourceName].Insert(dev)
}
if res.resourceName == "domain1.com/resource1" {
testManager.endpoints[res.resourceName] = &MockEndpoint{
allocateFunc: allocateStubFunc(),
testManager.endpoints[res.resourceName] = endpointInfo{
e: &MockEndpoint{allocateFunc: allocateStubFunc()},
opts: nil,
}
}
if res.resourceName == "domain2.com/resource2" {
testManager.endpoints[res.resourceName] = &MockEndpoint{
allocateFunc: func(devs []string) (*pluginapi.AllocateResponse, error) {
resp := new(pluginapi.ContainerAllocateResponse)
resp.Envs = make(map[string]string)
for _, dev := range devs {
switch dev {
case "dev3":
resp.Envs["key2"] = "val2"
testManager.endpoints[res.resourceName] = endpointInfo{
e: &MockEndpoint{
allocateFunc: func(devs []string) (*pluginapi.AllocateResponse, error) {
resp := new(pluginapi.ContainerAllocateResponse)
resp.Envs = make(map[string]string)
for _, dev := range devs {
switch dev {
case "dev3":
resp.Envs["key2"] = "val2"
case "dev4":
resp.Envs["key2"] = "val3"
case "dev4":
resp.Envs["key2"] = "val3"
}
}
}
resps := new(pluginapi.AllocateResponse)
resps.ContainerResponses = append(resps.ContainerResponses, resp)
return resps, nil
resps := new(pluginapi.AllocateResponse)
resps.ContainerResponses = append(resps.ContainerResponses, resp)
return resps, nil
},
},
opts: nil,
}
}
}
@ -669,10 +674,7 @@ func TestPodContainerDeviceAllocation(t *testing.T) {
as.Nil(err)
defer os.RemoveAll(tmpDir)
nodeInfo := getTestNodeInfo(v1.ResourceList{})
pluginOpts := make(map[string]*pluginapi.DevicePluginOptions)
pluginOpts[res1.resourceName] = nil
pluginOpts[res2.resourceName] = nil
testManager, err := getTestManager(tmpDir, podsStub.getActivePods, testResources, pluginOpts)
testManager, err := getTestManager(tmpDir, podsStub.getActivePods, testResources)
as.Nil(err)
testPods := []*v1.Pod{
@ -767,10 +769,8 @@ func TestInitContainerDeviceAllocation(t *testing.T) {
tmpDir, err := ioutil.TempDir("", "checkpoint")
as.Nil(err)
defer os.RemoveAll(tmpDir)
pluginOpts := make(map[string]*pluginapi.DevicePluginOptions)
pluginOpts[res1.resourceName] = nil
pluginOpts[res2.resourceName] = nil
testManager, err := getTestManager(tmpDir, podsStub.getActivePods, testResources, pluginOpts)
testManager, err := getTestManager(tmpDir, podsStub.getActivePods, testResources)
as.Nil(err)
podWithPluginResourcesInInitContainers := &v1.Pod{
@ -904,18 +904,18 @@ func TestDevicePreStartContainer(t *testing.T) {
as.Nil(err)
defer os.RemoveAll(tmpDir)
nodeInfo := getTestNodeInfo(v1.ResourceList{})
pluginOpts := make(map[string]*pluginapi.DevicePluginOptions)
pluginOpts[res1.resourceName] = &pluginapi.DevicePluginOptions{PreStartRequired: true}
testManager, err := getTestManager(tmpDir, podsStub.getActivePods, []TestResource{res1}, pluginOpts)
testManager, err := getTestManager(tmpDir, podsStub.getActivePods, []TestResource{res1})
as.Nil(err)
ch := make(chan []string, 1)
testManager.endpoints[res1.resourceName] = &MockEndpoint{
initChan: ch,
allocateFunc: allocateStubFunc(),
testManager.endpoints[res1.resourceName] = endpointInfo{
e: &MockEndpoint{
initChan: ch,
allocateFunc: allocateStubFunc(),
},
opts: &pluginapi.DevicePluginOptions{PreStartRequired: true},
}
pod := makePod(v1.ResourceList{
v1.ResourceName(res1.resourceName): res1.resourceQuantity})
activePods := []*v1.Pod{}

View File

@ -53,7 +53,7 @@ type Manager interface {
// GetCapacity returns the amount of available device plugin resource capacity, resource allocatable
// and inactive device plugin resources previously registered on the node.
GetCapacity() (v1.ResourceList, v1.ResourceList, []string)
GetWatcherCallback() watcher.RegisterCallbackFn
GetWatcherHandler() watcher.PluginHandler
}
// DeviceRunContainerOptions contains the combined container runtime settings to consume its allocated devices.

View File

@ -1194,7 +1194,7 @@ type Kubelet struct {
// pluginwatcher is a utility for Kubelet to register different types of node-level plugins
// such as device plugins or CSI plugins. It discovers plugins by monitoring inotify events under the
// directory returned by kubelet.getPluginsDir()
pluginWatcher pluginwatcher.Watcher
pluginWatcher *pluginwatcher.Watcher
// This flag sets a maximum number of images to report in the node status.
nodeStatusMaxImages int32
@ -1365,9 +1365,9 @@ func (kl *Kubelet) initializeRuntimeDependentModules() {
kl.containerLogManager.Start()
if kl.enablePluginsWatcher {
// Adding Registration Callback function for CSI Driver
kl.pluginWatcher.AddHandler("CSIPlugin", csi.RegistrationCallback)
kl.pluginWatcher.AddHandler("CSIPlugin", pluginwatcher.PluginHandler(csi.PluginHandler))
// Adding Registration Callback function for Device Manager
kl.pluginWatcher.AddHandler(pluginwatcherapi.DevicePlugin, kl.containerManager.GetPluginRegistrationHandlerCallback())
kl.pluginWatcher.AddHandler(pluginwatcherapi.DevicePlugin, kl.containerManager.GetPluginRegistrationHandler())
// Start the plugin watcher
glog.V(4).Infof("starting watcher")
if err := kl.pluginWatcher.Start(); err != nil {

View File

@ -1,10 +1,4 @@
package(default_visibility = ["//visibility:public"])
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
"go_test",
)
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
go_library(
name = "go_default_library",
@ -12,8 +6,10 @@ go_library(
"example_handler.go",
"example_plugin.go",
"plugin_watcher.go",
"types.go",
],
importpath = "k8s.io/kubernetes/pkg/kubelet/util/pluginwatcher",
visibility = ["//visibility:public"],
deps = [
"//pkg/kubelet/apis/pluginregistration/v1alpha1:go_default_library",
"//pkg/kubelet/util/pluginwatcher/example_plugin_apis/v1beta1:go_default_library",
@ -27,6 +23,16 @@ go_library(
],
)
go_test(
name = "go_default_test",
srcs = ["plugin_watcher_test.go"],
embed = [":go_default_library"],
deps = [
"//pkg/kubelet/apis/pluginregistration/v1alpha1:go_default_library",
"//vendor/github.com/stretchr/testify/require:go_default_library",
],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
@ -44,14 +50,3 @@ filegroup(
tags = ["automanaged"],
visibility = ["//visibility:public"],
)
go_test(
name = "go_default_test",
srcs = ["plugin_watcher_test.go"],
embed = [":go_default_library"],
deps = [
"//pkg/kubelet/apis/pluginregistration/v1alpha1:go_default_library",
"//staging/src/k8s.io/apimachinery/pkg/util/sets:go_default_library",
"//vendor/github.com/stretchr/testify/require:go_default_library",
],
)

View File

@ -23,6 +23,7 @@ import (
"sync"
"time"
"github.com/golang/glog"
"golang.org/x/net/context"
v1beta1 "k8s.io/kubernetes/pkg/kubelet/util/pluginwatcher/example_plugin_apis/v1beta1"
@ -30,41 +31,61 @@ import (
)
type exampleHandler struct {
registeredPlugins map[string]struct{}
mutex sync.Mutex
chanForHandlerAckErrors chan error // for testing
SupportedVersions []string
ExpectedNames map[string]int
eventChans map[string]chan examplePluginEvent // map[pluginName]eventChan
m sync.Mutex
count int
}
type examplePluginEvent int
const (
exampleEventValidate examplePluginEvent = 0
exampleEventRegister examplePluginEvent = 1
exampleEventDeRegister examplePluginEvent = 2
exampleEventError examplePluginEvent = 3
)
// NewExampleHandler provide a example handler
func NewExampleHandler() *exampleHandler {
func NewExampleHandler(supportedVersions []string) *exampleHandler {
return &exampleHandler{
chanForHandlerAckErrors: make(chan error),
registeredPlugins: make(map[string]struct{}),
SupportedVersions: supportedVersions,
ExpectedNames: make(map[string]int),
eventChans: make(map[string]chan examplePluginEvent),
}
}
func (h *exampleHandler) Cleanup() error {
h.mutex.Lock()
defer h.mutex.Unlock()
h.registeredPlugins = make(map[string]struct{})
return nil
}
func (p *exampleHandler) ValidatePlugin(pluginName string, endpoint string, versions []string) error {
p.SendEvent(pluginName, exampleEventValidate)
func (h *exampleHandler) Handler(pluginName string, endpoint string, versions []string, sockPath string) (chan bool, error) {
n, ok := p.DecreasePluginCount(pluginName)
if !ok && n > 0 {
return fmt.Errorf("pluginName('%s') wasn't expected (count is %d)", pluginName, n)
}
// check for supported versions
if !reflect.DeepEqual([]string{"v1beta1", "v1beta2"}, versions) {
return nil, fmt.Errorf("not the supported versions: %s", versions)
if !reflect.DeepEqual(versions, p.SupportedVersions) {
return fmt.Errorf("versions('%v') != supported versions('%v')", versions, p.SupportedVersions)
}
// this handler expects non-empty endpoint as an example
if len(endpoint) == 0 {
return nil, errors.New("expecting non empty endpoint")
return errors.New("expecting non empty endpoint")
}
_, conn, err := dial(sockPath)
return nil
}
func (p *exampleHandler) RegisterPlugin(pluginName, endpoint string) error {
p.SendEvent(pluginName, exampleEventRegister)
// Verifies the grpcServer is ready to serve services.
_, conn, err := dial(endpoint, time.Second)
if err != nil {
return nil, err
return fmt.Errorf("Failed dialing endpoint (%s): %v", endpoint, err)
}
defer conn.Close()
@ -73,33 +94,54 @@ func (h *exampleHandler) Handler(pluginName string, endpoint string, versions []
v1beta2Client := v1beta2.NewExampleClient(conn)
// Tests v1beta1 GetExampleInfo
if _, err = v1beta1Client.GetExampleInfo(context.Background(), &v1beta1.ExampleRequest{}); err != nil {
return nil, err
_, err = v1beta1Client.GetExampleInfo(context.Background(), &v1beta1.ExampleRequest{})
if err != nil {
return fmt.Errorf("Failed GetExampleInfo for v1beta2Client(%s): %v", endpoint, err)
}
// Tests v1beta2 GetExampleInfo
if _, err = v1beta2Client.GetExampleInfo(context.Background(), &v1beta2.ExampleRequest{}); err != nil {
return nil, err
// Tests v1beta1 GetExampleInfo
_, err = v1beta2Client.GetExampleInfo(context.Background(), &v1beta2.ExampleRequest{})
if err != nil {
return fmt.Errorf("Failed GetExampleInfo for v1beta2Client(%s): %v", endpoint, err)
}
// handle registered plugin
h.mutex.Lock()
if _, exist := h.registeredPlugins[pluginName]; exist {
h.mutex.Unlock()
return nil, fmt.Errorf("plugin %s already registered", pluginName)
}
h.registeredPlugins[pluginName] = struct{}{}
h.mutex.Unlock()
chanForAckOfNotification := make(chan bool)
go func() {
select {
case <-chanForAckOfNotification:
// TODO: handle the negative scenario
close(chanForAckOfNotification)
case <-time.After(time.Second):
h.chanForHandlerAckErrors <- errors.New("Timed out while waiting for notification ack")
}
}()
return chanForAckOfNotification, nil
return nil
}
func (p *exampleHandler) DeRegisterPlugin(pluginName string) {
p.SendEvent(pluginName, exampleEventDeRegister)
}
func (p *exampleHandler) EventChan(pluginName string) chan examplePluginEvent {
return p.eventChans[pluginName]
}
func (p *exampleHandler) SendEvent(pluginName string, event examplePluginEvent) {
glog.V(2).Infof("Sending %v for plugin %s over chan %v", event, pluginName, p.eventChans[pluginName])
p.eventChans[pluginName] <- event
}
func (p *exampleHandler) AddPluginName(pluginName string) {
p.m.Lock()
defer p.m.Unlock()
v, ok := p.ExpectedNames[pluginName]
if !ok {
p.eventChans[pluginName] = make(chan examplePluginEvent)
v = 1
}
p.ExpectedNames[pluginName] = v
}
func (p *exampleHandler) DecreasePluginCount(pluginName string) (old int, ok bool) {
p.m.Lock()
defer p.m.Unlock()
v, ok := p.ExpectedNames[pluginName]
if !ok {
v = -1
}
return v, ok
}

View File

@ -18,7 +18,9 @@ package pluginwatcher
import (
"errors"
"fmt"
"net"
"os"
"sync"
"time"
@ -39,6 +41,7 @@ type examplePlugin struct {
endpoint string // for testing
pluginName string
pluginType string
versions []string
}
type pluginServiceV1Beta1 struct {
@ -73,12 +76,13 @@ func NewExamplePlugin() *examplePlugin {
}
// NewTestExamplePlugin returns an initialized examplePlugin instance for testing
func NewTestExamplePlugin(pluginName string, pluginType string, endpoint string) *examplePlugin {
func NewTestExamplePlugin(pluginName string, pluginType string, endpoint string, advertisedVersions ...string) *examplePlugin {
return &examplePlugin{
pluginName: pluginName,
pluginType: pluginType,
registrationStatus: make(chan registerapi.RegistrationStatus),
endpoint: endpoint,
versions: advertisedVersions,
registrationStatus: make(chan registerapi.RegistrationStatus),
}
}
@ -88,36 +92,48 @@ func (e *examplePlugin) GetInfo(ctx context.Context, req *registerapi.InfoReques
Type: e.pluginType,
Name: e.pluginName,
Endpoint: e.endpoint,
SupportedVersions: []string{"v1beta1", "v1beta2"},
SupportedVersions: e.versions,
}, nil
}
func (e *examplePlugin) NotifyRegistrationStatus(ctx context.Context, status *registerapi.RegistrationStatus) (*registerapi.RegistrationStatusResponse, error) {
glog.Errorf("Registration is: %v\n", status)
if e.registrationStatus != nil {
e.registrationStatus <- *status
}
if !status.PluginRegistered {
glog.Errorf("Registration failed: %s\n", status.Error)
}
return &registerapi.RegistrationStatusResponse{}, nil
}
// Serve starts example plugin grpc server
func (e *examplePlugin) Serve(socketPath string) error {
glog.Infof("starting example server at: %s\n", socketPath)
lis, err := net.Listen("unix", socketPath)
// Serve starts a pluginwatcher server and one or more of the plugin services
func (e *examplePlugin) Serve(services ...string) error {
glog.Infof("starting example server at: %s\n", e.endpoint)
lis, err := net.Listen("unix", e.endpoint)
if err != nil {
return err
}
glog.Infof("example server started at: %s\n", socketPath)
glog.Infof("example server started at: %s\n", e.endpoint)
e.grpcServer = grpc.NewServer()
// Registers kubelet plugin watcher api.
registerapi.RegisterRegistrationServer(e.grpcServer, e)
// Registers services for both v1beta1 and v1beta2 versions.
v1beta1 := &pluginServiceV1Beta1{server: e}
v1beta1.RegisterService()
v1beta2 := &pluginServiceV1Beta2{server: e}
v1beta2.RegisterService()
for _, service := range services {
switch service {
case "v1beta1":
v1beta1 := &pluginServiceV1Beta1{server: e}
v1beta1.RegisterService()
break
case "v1beta2":
v1beta2 := &pluginServiceV1Beta2{server: e}
v1beta2.RegisterService()
break
default:
return fmt.Errorf("Unsupported service: '%s'", service)
}
}
// Starts service
e.wg.Add(1)
@ -128,22 +144,30 @@ func (e *examplePlugin) Serve(socketPath string) error {
glog.Errorf("example server stopped serving: %v", err)
}
}()
return nil
}
func (e *examplePlugin) Stop() error {
glog.Infof("Stopping example server\n")
glog.Infof("Stopping example server at: %s\n", e.endpoint)
e.grpcServer.Stop()
c := make(chan struct{})
go func() {
defer close(c)
e.wg.Wait()
}()
select {
case <-c:
return nil
break
case <-time.After(time.Second):
glog.Errorf("Timed out on waiting for stop completion")
return errors.New("Timed out on waiting for stop completion")
}
if err := os.Remove(e.endpoint); err != nil && !os.IsNotExist(err) {
return err
}
return nil
}

View File

@ -20,6 +20,7 @@ import (
"fmt"
"net"
"os"
"strings"
"sync"
"time"
@ -28,43 +29,144 @@ import (
"github.com/pkg/errors"
"golang.org/x/net/context"
"google.golang.org/grpc"
registerapi "k8s.io/kubernetes/pkg/kubelet/apis/pluginregistration/v1alpha1"
utilfs "k8s.io/kubernetes/pkg/util/filesystem"
)
// RegisterCallbackFn is the type of the callback function that handlers will provide
type RegisterCallbackFn func(pluginName string, endpoint string, versions []string, socketPath string) (chan bool, error)
// Watcher is the plugin watcher
type Watcher struct {
path string
handlers map[string]RegisterCallbackFn
stopCh chan interface{}
fs utilfs.Filesystem
fsWatcher *fsnotify.Watcher
wg sync.WaitGroup
mutex sync.Mutex
mutex sync.Mutex
handlers map[string]PluginHandler
plugins map[string]pathInfo
pluginsPool map[string]map[string]*sync.Mutex // map[pluginType][pluginName]
}
type pathInfo struct {
pluginType string
pluginName string
}
// NewWatcher provides a new watcher
func NewWatcher(sockDir string) Watcher {
return Watcher{
path: sockDir,
handlers: make(map[string]RegisterCallbackFn),
fs: &utilfs.DefaultFs{},
func NewWatcher(sockDir string) *Watcher {
return &Watcher{
path: sockDir,
fs: &utilfs.DefaultFs{},
handlers: make(map[string]PluginHandler),
plugins: make(map[string]pathInfo),
pluginsPool: make(map[string]map[string]*sync.Mutex),
}
}
// AddHandler registers a callback to be invoked for a particular type of plugin
func (w *Watcher) AddHandler(pluginType string, handlerCbkFn RegisterCallbackFn) {
func (w *Watcher) AddHandler(pluginType string, handler PluginHandler) {
w.mutex.Lock()
defer w.mutex.Unlock()
w.handlers[pluginType] = handlerCbkFn
w.handlers[pluginType] = handler
}
// Creates the plugin directory, if it doesn't already exist.
func (w *Watcher) createPluginDir() error {
func (w *Watcher) getHandler(pluginType string) (PluginHandler, bool) {
w.mutex.Lock()
defer w.mutex.Unlock()
h, ok := w.handlers[pluginType]
return h, ok
}
// Start watches for the creation of plugin sockets at the path
func (w *Watcher) Start() error {
glog.V(2).Infof("Plugin Watcher Start at %s", w.path)
w.stopCh = make(chan interface{})
// Creating the directory to be watched if it doesn't exist yet,
// and walks through the directory to discover the existing plugins.
if err := w.init(); err != nil {
return err
}
fsWatcher, err := fsnotify.NewWatcher()
if err != nil {
return fmt.Errorf("failed to start plugin fsWatcher, err: %v", err)
}
w.fsWatcher = fsWatcher
w.wg.Add(1)
go func(fsWatcher *fsnotify.Watcher) {
defer w.wg.Done()
for {
select {
case event := <-fsWatcher.Events:
//TODO: Handle errors by taking corrective measures
w.wg.Add(1)
go func() {
defer w.wg.Done()
if event.Op&fsnotify.Create == fsnotify.Create {
err := w.handleCreateEvent(event)
if err != nil {
glog.Errorf("error %v when handling create event: %s", err, event)
}
} else if event.Op&fsnotify.Remove == fsnotify.Remove {
err := w.handleDeleteEvent(event)
if err != nil {
glog.Errorf("error %v when handling delete event: %s", err, event)
}
}
return
}()
continue
case err := <-fsWatcher.Errors:
if err != nil {
glog.Errorf("fsWatcher received error: %v", err)
}
continue
case <-w.stopCh:
return
}
}
}(fsWatcher)
// Traverse plugin dir after starting the plugin processing goroutine
if err := w.traversePluginDir(w.path); err != nil {
w.Stop()
return fmt.Errorf("failed to traverse plugin socket path, err: %v", err)
}
return nil
}
// Stop stops probing the creation of plugin sockets at the path
func (w *Watcher) Stop() error {
close(w.stopCh)
c := make(chan struct{})
go func() {
defer close(c)
w.wg.Wait()
}()
select {
case <-c:
case <-time.After(11 * time.Second):
return fmt.Errorf("timeout on stopping watcher")
}
w.fsWatcher.Close()
return nil
}
func (w *Watcher) init() error {
glog.V(4).Infof("Ensuring Plugin directory at %s ", w.path)
if err := w.fs.MkdirAll(w.path, 0755); err != nil {
return fmt.Errorf("error (re-)creating root %s: %v", w.path, err)
}
@ -91,22 +193,38 @@ func (w *Watcher) traversePluginDir(dir string) error {
Op: fsnotify.Create,
}
}()
default:
glog.V(5).Infof("Ignoring file %s with mode %v", path, mode)
}
return nil
})
}
func (w *Watcher) init() error {
if err := w.createPluginDir(); err != nil {
return err
// Handle filesystem notify event.
func (w *Watcher) handleCreateEvent(event fsnotify.Event) error {
glog.V(6).Infof("Handling create event: %v", event)
fi, err := os.Stat(event.Name)
if err != nil {
return fmt.Errorf("stat file %s failed: %v", event.Name, err)
}
return nil
if strings.HasPrefix(fi.Name(), ".") {
glog.Errorf("Ignoring file: %s", fi.Name())
return nil
}
if !fi.IsDir() {
return w.handlePluginRegistration(event.Name)
}
return w.traversePluginDir(event.Name)
}
func (w *Watcher) registerPlugin(socketPath string) error {
func (w *Watcher) handlePluginRegistration(socketPath string) error {
//TODO: Implement rate limiting to mitigate any DOS kind of attacks.
client, conn, err := dial(socketPath)
client, conn, err := dial(socketPath, 10*time.Second)
if err != nil {
return fmt.Errorf("dial failed at socket %s, err: %v", socketPath, err)
}
@ -114,154 +232,161 @@ func (w *Watcher) registerPlugin(socketPath string) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
infoResp, err := client.GetInfo(ctx, &registerapi.InfoRequest{})
if err != nil {
return fmt.Errorf("failed to get plugin info using RPC GetInfo at socket %s, err: %v", socketPath, err)
}
return w.invokeRegistrationCallbackAtHandler(ctx, client, infoResp, socketPath)
}
func (w *Watcher) invokeRegistrationCallbackAtHandler(ctx context.Context, client registerapi.RegistrationClient, infoResp *registerapi.PluginInfo, socketPath string) error {
var handlerCbkFn RegisterCallbackFn
var ok bool
handlerCbkFn, ok = w.handlers[infoResp.Type]
handler, ok := w.handlers[infoResp.Type]
if !ok {
errStr := fmt.Sprintf("no handler registered for plugin type: %s at socket %s", infoResp.Type, socketPath)
if _, err := client.NotifyRegistrationStatus(ctx, &registerapi.RegistrationStatus{
PluginRegistered: false,
Error: errStr,
}); err != nil {
return errors.Wrap(err, errStr)
}
return errors.New(errStr)
return w.notifyPlugin(client, false, fmt.Sprintf("no handler registered for plugin type: %s at socket %s", infoResp.Type, socketPath))
}
var versions []string
for _, version := range infoResp.SupportedVersions {
versions = append(versions, version)
// ReRegistration: We want to handle multiple plugins registering at the same time with the same name sequentially.
// See the state machine for more information.
// This is done by using a Lock for each plugin with the same name and type
pool := w.getPluginPool(infoResp.Type, infoResp.Name)
pool.Lock()
defer pool.Unlock()
if infoResp.Endpoint == "" {
infoResp.Endpoint = socketPath
}
// calls handler callback to verify registration request
chanForAckOfNotification, err := handlerCbkFn(infoResp.Name, infoResp.Endpoint, versions, socketPath)
if err != nil {
errStr := fmt.Sprintf("plugin registration failed with err: %v", err)
if _, err := client.NotifyRegistrationStatus(ctx, &registerapi.RegistrationStatus{
PluginRegistered: false,
Error: errStr,
}); err != nil {
return errors.Wrap(err, errStr)
}
return errors.New(errStr)
if err := handler.ValidatePlugin(infoResp.Name, infoResp.Endpoint, infoResp.SupportedVersions); err != nil {
return w.notifyPlugin(client, false, fmt.Sprintf("plugin validation failed with err: %v", err))
}
if _, err := client.NotifyRegistrationStatus(ctx, &registerapi.RegistrationStatus{
PluginRegistered: true,
}); err != nil {
chanForAckOfNotification <- false
// We add the plugin to the pluginwatcher's map before calling a plugin consumer's Register handle
// so that if we receive a delete event during Register Plugin, we can process it as a DeRegister call.
w.registerPlugin(socketPath, infoResp.Type, infoResp.Name)
if err := handler.RegisterPlugin(infoResp.Name, infoResp.Endpoint); err != nil {
return w.notifyPlugin(client, false, fmt.Sprintf("plugin registration failed with err: %v", err))
}
// Notify is called after register to guarantee that even if notify throws an error Register will always be called after validate
if err := w.notifyPlugin(client, true, ""); err != nil {
return fmt.Errorf("failed to send registration status at socket %s, err: %v", socketPath, err)
}
chanForAckOfNotification <- true
return nil
}
// Handle filesystem notify event.
func (w *Watcher) handleFsNotifyEvent(event fsnotify.Event) error {
if event.Op&fsnotify.Create != fsnotify.Create {
func (w *Watcher) handleDeleteEvent(event fsnotify.Event) error {
glog.V(6).Infof("Handling delete event: %v", event)
plugin, ok := w.getPlugin(event.Name)
if !ok {
return fmt.Errorf("could not find plugin for deleted file %s", event.Name)
}
// You should not get a Deregister call while registering a plugin
pool := w.getPluginPool(plugin.pluginType, plugin.pluginName)
pool.Lock()
defer pool.Unlock()
// ReRegisteration: When waiting for the lock a plugin with the same name (not socketPath) could have registered
// In that case, we don't want to issue a DeRegister call for that plugin
// When ReRegistering, the new plugin will have removed the current mapping (map[socketPath] = plugin) and replaced
// it with it's own socketPath.
if _, ok = w.getPlugin(event.Name); !ok {
glog.V(2).Infof("A newer plugin watcher has been registered for plugin %v, dropping DeRegister call", plugin)
return nil
}
fi, err := os.Stat(event.Name)
if err != nil {
return fmt.Errorf("stat file %s failed: %v", event.Name, err)
h, ok := w.getHandler(plugin.pluginType)
if !ok {
return fmt.Errorf("could not find handler %s for plugin %s at path %s", plugin.pluginType, plugin.pluginName, event.Name)
}
if !fi.IsDir() {
return w.registerPlugin(event.Name)
}
if err := w.traversePluginDir(event.Name); err != nil {
return fmt.Errorf("failed to traverse plugin path %s, err: %v", event.Name, err)
}
glog.V(2).Infof("DeRegistering plugin %v at path %s", plugin, event.Name)
w.deRegisterPlugin(event.Name, plugin.pluginType, plugin.pluginName)
h.DeRegisterPlugin(plugin.pluginName)
return nil
}
// Start watches for the creation of plugin sockets at the path
func (w *Watcher) Start() error {
glog.V(2).Infof("Plugin Watcher Start at %s", w.path)
w.stopCh = make(chan interface{})
func (w *Watcher) registerPlugin(socketPath, pluginType, pluginName string) {
w.mutex.Lock()
defer w.mutex.Unlock()
// Creating the directory to be watched if it doesn't exist yet,
// and walks through the directory to discover the existing plugins.
if err := w.init(); err != nil {
return err
}
fsWatcher, err := fsnotify.NewWatcher()
if err != nil {
return fmt.Errorf("failed to start plugin fsWatcher, err: %v", err)
}
w.fsWatcher = fsWatcher
if err := w.traversePluginDir(w.path); err != nil {
fsWatcher.Close()
return fmt.Errorf("failed to traverse plugin socket path, err: %v", err)
}
w.wg.Add(1)
go func(fsWatcher *fsnotify.Watcher) {
defer w.wg.Done()
for {
select {
case event := <-fsWatcher.Events:
//TODO: Handle errors by taking corrective measures
go func() {
err := w.handleFsNotifyEvent(event)
if err != nil {
glog.Errorf("error %v when handle event: %s", err, event)
}
}()
continue
case err := <-fsWatcher.Errors:
if err != nil {
glog.Errorf("fsWatcher received error: %v", err)
}
continue
case <-w.stopCh:
fsWatcher.Close()
return
}
// Reregistration case, if this plugin is already in the map, remove it
// This will prevent handleDeleteEvent to issue a DeRegister call
for path, info := range w.plugins {
if info.pluginType != pluginType || info.pluginName != pluginName {
continue
}
}(fsWatcher)
return nil
}
// Stop stops probing the creation of plugin sockets at the path
func (w *Watcher) Stop() error {
close(w.stopCh)
c := make(chan struct{})
go func() {
defer close(c)
w.wg.Wait()
}()
select {
case <-c:
case <-time.After(10 * time.Second):
return fmt.Errorf("timeout on stopping watcher")
delete(w.plugins, path)
break
}
w.plugins[socketPath] = pathInfo{
pluginType: pluginType,
pluginName: pluginName,
}
return nil
}
// Cleanup cleans the path by removing sockets
func (w *Watcher) Cleanup() error {
return os.RemoveAll(w.path)
func (w *Watcher) deRegisterPlugin(socketPath, pluginType, pluginName string) {
w.mutex.Lock()
defer w.mutex.Unlock()
delete(w.plugins, socketPath)
delete(w.pluginsPool[pluginType], pluginName)
}
func (w *Watcher) getPlugin(socketPath string) (pathInfo, bool) {
w.mutex.Lock()
defer w.mutex.Unlock()
plugin, ok := w.plugins[socketPath]
return plugin, ok
}
func (w *Watcher) getPluginPool(pluginType, pluginName string) *sync.Mutex {
w.mutex.Lock()
defer w.mutex.Unlock()
if _, ok := w.pluginsPool[pluginType]; !ok {
w.pluginsPool[pluginType] = make(map[string]*sync.Mutex)
}
if _, ok := w.pluginsPool[pluginType][pluginName]; !ok {
w.pluginsPool[pluginType][pluginName] = &sync.Mutex{}
}
return w.pluginsPool[pluginType][pluginName]
}
func (w *Watcher) notifyPlugin(client registerapi.RegistrationClient, registered bool, errStr string) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
status := &registerapi.RegistrationStatus{
PluginRegistered: registered,
Error: errStr,
}
if _, err := client.NotifyRegistrationStatus(ctx, status); err != nil {
return errors.Wrap(err, errStr)
}
if errStr != "" {
return errors.New(errStr)
}
return nil
}
// Dial establishes the gRPC communication with the picked up plugin socket. https://godoc.org/google.golang.org/grpc#Dial
func dial(unixSocketPath string) (registerapi.RegistrationClient, *grpc.ClientConn, error) {
func dial(unixSocketPath string, timeout time.Duration) (registerapi.RegistrationClient, *grpc.ClientConn, error) {
c, err := grpc.Dial(unixSocketPath, grpc.WithInsecure(), grpc.WithBlock(),
grpc.WithTimeout(10*time.Second),
grpc.WithTimeout(timeout),
grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout("unix", addr, timeout)
}),

View File

@ -17,192 +17,222 @@ limitations under the License.
package pluginwatcher
import (
"errors"
"flag"
"fmt"
"io/ioutil"
"path/filepath"
"strconv"
"os"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/util/sets"
registerapi "k8s.io/kubernetes/pkg/kubelet/apis/pluginregistration/v1alpha1"
)
// helper function
func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
var (
socketDir string
supportedVersions = []string{"v1beta1", "v1beta2"}
)
func init() {
var logLevel string
flag.Set("alsologtostderr", fmt.Sprintf("%t", true))
flag.StringVar(&logLevel, "logLevel", "6", "test")
flag.Lookup("v").Value.Set(logLevel)
d, err := ioutil.TempDir("", "plugin_test")
if err != nil {
panic(fmt.Sprintf("Could not create a temp directory: %s", d))
}
socketDir = d
}
func cleanup(t *testing.T) {
require.NoError(t, os.RemoveAll(socketDir))
os.MkdirAll(socketDir, 0755)
}
func TestPluginRegistration(t *testing.T) {
defer cleanup(t)
hdlr := NewExampleHandler(supportedVersions)
w := newWatcherWithHandler(t, hdlr)
defer func() { require.NoError(t, w.Stop()) }()
for i := 0; i < 10; i++ {
socketPath := fmt.Sprintf("%s/plugin-%d.sock", socketDir, i)
pluginName := fmt.Sprintf("example-plugin-%d", i)
hdlr.AddPluginName(pluginName)
p := NewTestExamplePlugin(pluginName, registerapi.DevicePlugin, socketPath, supportedVersions...)
require.NoError(t, p.Serve("v1beta1", "v1beta2"))
require.True(t, waitForEvent(t, exampleEventValidate, hdlr.EventChan(p.pluginName)))
require.True(t, waitForEvent(t, exampleEventRegister, hdlr.EventChan(p.pluginName)))
require.True(t, waitForPluginRegistrationStatus(t, p.registrationStatus))
require.NoError(t, p.Stop())
require.True(t, waitForEvent(t, exampleEventDeRegister, hdlr.EventChan(p.pluginName)))
}
}
func TestPluginReRegistration(t *testing.T) {
defer cleanup(t)
pluginName := fmt.Sprintf("example-plugin")
hdlr := NewExampleHandler(supportedVersions)
w := newWatcherWithHandler(t, hdlr)
defer func() { require.NoError(t, w.Stop()) }()
plugins := make([]*examplePlugin, 10)
for i := 0; i < 10; i++ {
socketPath := fmt.Sprintf("%s/plugin-%d.sock", socketDir, i)
hdlr.AddPluginName(pluginName)
p := NewTestExamplePlugin(pluginName, registerapi.DevicePlugin, socketPath, supportedVersions...)
require.NoError(t, p.Serve("v1beta1", "v1beta2"))
require.True(t, waitForEvent(t, exampleEventValidate, hdlr.EventChan(p.pluginName)))
require.True(t, waitForEvent(t, exampleEventRegister, hdlr.EventChan(p.pluginName)))
require.True(t, waitForPluginRegistrationStatus(t, p.registrationStatus))
plugins[i] = p
}
plugins[len(plugins)-1].Stop()
require.True(t, waitForEvent(t, exampleEventDeRegister, hdlr.EventChan(pluginName)))
close(hdlr.EventChan(pluginName))
for i := 0; i < len(plugins)-1; i++ {
plugins[i].Stop()
}
}
func TestPluginRegistrationAtKubeletStart(t *testing.T) {
defer cleanup(t)
hdlr := NewExampleHandler(supportedVersions)
plugins := make([]*examplePlugin, 10)
for i := 0; i < len(plugins); i++ {
socketPath := fmt.Sprintf("%s/plugin-%d.sock", socketDir, i)
pluginName := fmt.Sprintf("example-plugin-%d", i)
hdlr.AddPluginName(pluginName)
p := NewTestExamplePlugin(pluginName, registerapi.DevicePlugin, socketPath, supportedVersions...)
require.NoError(t, p.Serve("v1beta1", "v1beta2"))
defer func(p *examplePlugin) { require.NoError(t, p.Stop()) }(p)
plugins[i] = p
}
w := newWatcherWithHandler(t, hdlr)
defer func() { require.NoError(t, w.Stop()) }()
var wg sync.WaitGroup
for i := 0; i < len(plugins); i++ {
wg.Add(1)
go func(p *examplePlugin) {
defer wg.Done()
require.True(t, waitForEvent(t, exampleEventValidate, hdlr.EventChan(p.pluginName)))
require.True(t, waitForEvent(t, exampleEventRegister, hdlr.EventChan(p.pluginName)))
require.True(t, waitForPluginRegistrationStatus(t, p.registrationStatus))
}(plugins[i])
}
c := make(chan struct{})
go func() {
defer close(c)
wg.Wait()
}()
select {
case <-c:
return false // completed normally
case <-time.After(timeout):
return true // timed out
}
}
func TestExamplePlugin(t *testing.T) {
rootDir, err := ioutil.TempDir("", "plugin_test")
require.NoError(t, err)
w := NewWatcher(rootDir)
h := NewExampleHandler()
w.AddHandler(registerapi.DevicePlugin, h.Handler)
require.NoError(t, w.Start())
socketPath := filepath.Join(rootDir, "plugin.sock")
PluginName := "example-plugin"
// handler expecting plugin has a non-empty endpoint
p := NewTestExamplePlugin(PluginName, registerapi.DevicePlugin, "")
require.NoError(t, p.Serve(socketPath))
require.False(t, waitForPluginRegistrationStatus(t, p.registrationStatus))
require.NoError(t, p.Stop())
p = NewTestExamplePlugin(PluginName, registerapi.DevicePlugin, "dummyEndpoint")
require.NoError(t, p.Serve(socketPath))
require.True(t, waitForPluginRegistrationStatus(t, p.registrationStatus))
// Trying to start a plugin service at the same socket path should fail
// with "bind: address already in use"
require.NotNil(t, p.Serve(socketPath))
// grpcServer.Stop() will remove the socket and starting plugin service
// at the same path again should succeeds and trigger another callback.
require.NoError(t, p.Stop())
require.Nil(t, p.Serve(socketPath))
require.False(t, waitForPluginRegistrationStatus(t, p.registrationStatus))
// Starting another plugin with the same name got verification error.
p2 := NewTestExamplePlugin(PluginName, registerapi.DevicePlugin, "dummyEndpoint")
socketPath2 := filepath.Join(rootDir, "plugin2.sock")
require.NoError(t, p2.Serve(socketPath2))
require.False(t, waitForPluginRegistrationStatus(t, p2.registrationStatus))
// Restarts plugin watcher should traverse the socket directory and issues a
// callback for every existing socket.
require.NoError(t, w.Stop())
require.NoError(t, h.Cleanup())
require.NoError(t, w.Start())
var wg sync.WaitGroup
wg.Add(2)
var pStatus string
var p2Status string
go func() {
pStatus = strconv.FormatBool(waitForPluginRegistrationStatus(t, p.registrationStatus))
wg.Done()
}()
go func() {
p2Status = strconv.FormatBool(waitForPluginRegistrationStatus(t, p2.registrationStatus))
wg.Done()
}()
if waitTimeout(&wg, 2*time.Second) {
t.Fatalf("Timed out waiting for wait group")
}
expectedSet := sets.NewString()
expectedSet.Insert("true", "false")
actualSet := sets.NewString()
actualSet.Insert(pStatus, p2Status)
require.Equal(t, expectedSet, actualSet)
select {
case err := <-h.chanForHandlerAckErrors:
t.Fatalf("%v", err)
return
case <-time.After(2 * time.Second):
t.Fatalf("Timeout while waiting for the plugin registration status")
}
require.NoError(t, w.Stop())
require.NoError(t, w.Cleanup())
}
func TestPluginWithSubDir(t *testing.T) {
rootDir, err := ioutil.TempDir("", "plugin_test")
require.NoError(t, err)
func TestPluginRegistrationFailureWithUnsupportedVersion(t *testing.T) {
defer cleanup(t)
w := NewWatcher(rootDir)
hcsi := NewExampleHandler()
hdp := NewExampleHandler()
pluginName := fmt.Sprintf("example-plugin")
socketPath := socketDir + "/plugin.sock"
w.AddHandler(registerapi.CSIPlugin, hcsi.Handler)
w.AddHandler(registerapi.DevicePlugin, hdp.Handler)
hdlr := NewExampleHandler(supportedVersions)
hdlr.AddPluginName(pluginName)
err = w.fs.MkdirAll(filepath.Join(rootDir, registerapi.DevicePlugin), 0755)
require.NoError(t, err)
err = w.fs.MkdirAll(filepath.Join(rootDir, registerapi.CSIPlugin), 0755)
require.NoError(t, err)
w := newWatcherWithHandler(t, hdlr)
defer func() { require.NoError(t, w.Stop()) }()
dpSocketPath := filepath.Join(rootDir, registerapi.DevicePlugin, "plugin.sock")
csiSocketPath := filepath.Join(rootDir, registerapi.CSIPlugin, "plugin.sock")
// Advertise v1beta3 but don't serve anything else than the plugin service
p := NewTestExamplePlugin(pluginName, registerapi.DevicePlugin, socketPath, "v1beta3")
require.NoError(t, p.Serve())
defer func() { require.NoError(t, p.Stop()) }()
require.NoError(t, w.Start())
// two plugins using the same name but with different type
dp := NewTestExamplePlugin("exampleplugin", registerapi.DevicePlugin, "example-endpoint")
require.NoError(t, dp.Serve(dpSocketPath))
require.True(t, waitForPluginRegistrationStatus(t, dp.registrationStatus))
csi := NewTestExamplePlugin("exampleplugin", registerapi.CSIPlugin, "example-endpoint")
require.NoError(t, csi.Serve(csiSocketPath))
require.True(t, waitForPluginRegistrationStatus(t, csi.registrationStatus))
// Restarts plugin watcher should traverse the socket directory and issues a
// callback for every existing socket.
require.NoError(t, w.Stop())
require.NoError(t, hcsi.Cleanup())
require.NoError(t, hdp.Cleanup())
require.NoError(t, w.Start())
var wg sync.WaitGroup
wg.Add(2)
var dpStatus string
var csiStatus string
go func() {
dpStatus = strconv.FormatBool(waitForPluginRegistrationStatus(t, dp.registrationStatus))
wg.Done()
}()
go func() {
csiStatus = strconv.FormatBool(waitForPluginRegistrationStatus(t, csi.registrationStatus))
wg.Done()
}()
if waitTimeout(&wg, 4*time.Second) {
require.NoError(t, errors.New("Timed out waiting for wait group"))
}
expectedSet := sets.NewString()
expectedSet.Insert("true", "true")
actualSet := sets.NewString()
actualSet.Insert(dpStatus, csiStatus)
require.Equal(t, expectedSet, actualSet)
select {
case err := <-hcsi.chanForHandlerAckErrors:
t.Fatalf("%v", err)
case err := <-hdp.chanForHandlerAckErrors:
t.Fatalf("%v", err)
case <-time.After(4 * time.Second):
}
require.NoError(t, w.Stop())
require.NoError(t, w.Cleanup())
require.True(t, waitForEvent(t, exampleEventValidate, hdlr.EventChan(p.pluginName)))
require.False(t, waitForPluginRegistrationStatus(t, p.registrationStatus))
}
func waitForPluginRegistrationStatus(t *testing.T, statusCh chan registerapi.RegistrationStatus) bool {
func TestPlugiRegistrationFailureWithUnsupportedVersionAtKubeletStart(t *testing.T) {
defer cleanup(t)
pluginName := fmt.Sprintf("example-plugin")
socketPath := socketDir + "/plugin.sock"
// Advertise v1beta3 but don't serve anything else than the plugin service
p := NewTestExamplePlugin(pluginName, registerapi.DevicePlugin, socketPath, "v1beta3")
require.NoError(t, p.Serve())
defer func() { require.NoError(t, p.Stop()) }()
hdlr := NewExampleHandler(supportedVersions)
hdlr.AddPluginName(pluginName)
w := newWatcherWithHandler(t, hdlr)
defer func() { require.NoError(t, w.Stop()) }()
require.True(t, waitForEvent(t, exampleEventValidate, hdlr.EventChan(p.pluginName)))
require.False(t, waitForPluginRegistrationStatus(t, p.registrationStatus))
}
func waitForPluginRegistrationStatus(t *testing.T, statusChan chan registerapi.RegistrationStatus) bool {
select {
case status := <-statusCh:
case status := <-statusChan:
return status.PluginRegistered
case <-time.After(10 * time.Second):
t.Fatalf("Timed out while waiting for registration status")
}
return false
}
func waitForEvent(t *testing.T, expected examplePluginEvent, eventChan chan examplePluginEvent) bool {
select {
case event := <-eventChan:
return event == expected
case <-time.After(2 * time.Second):
t.Fatalf("Timed out while waiting for registration status %v", expected)
}
return false
}
func newWatcherWithHandler(t *testing.T, hdlr PluginHandler) *Watcher {
w := NewWatcher(socketDir)
w.AddHandler(registerapi.DevicePlugin, hdlr)
require.NoError(t, w.Start())
return w
}

View File

@ -0,0 +1,59 @@
/*
Copyright 2018 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package pluginwatcher
// PluginHandler is an interface a client of the pluginwatcher API needs to implement in
// order to consume plugins
// The PluginHandler follows the simple following state machine:
//
// +--------------------------------------+
// | ReRegistration |
// | Socket created with same plugin name |
// | |
// | |
// Socket Created v + Socket Deleted
// +------------------> Validate +---------------------------> Register +------------------> DeRegister
// + + +
// | | |
// | Error | Error |
// | | |
// v v v
// Out Out Out
//
// The pluginwatcher module follows strictly and sequentially this state machine for each *plugin name*.
// e.g: If you are Registering a plugin foo, you cannot get a DeRegister call for plugin foo
// until the Register("foo") call returns. Nor will you get a Validate("foo", "Different endpoint", ...)
// call until the Register("foo") call returns.
//
// ReRegistration: Socket created with same plugin name, usually for a plugin update
// e.g: plugin with name foo registers at foo.com/foo-1.9.7 later a plugin with name foo
// registers at foo.com/foo-1.9.9
//
// DeRegistration: When ReRegistration happens only the deletion of the new socket will trigger a DeRegister call
type PluginHandler interface {
// Validate returns an error if the information provided by
// the potential plugin is erroneous (unsupported version, ...)
ValidatePlugin(pluginName string, endpoint string, versions []string) error
// RegisterPlugin is called so that the plugin can be register by any
// plugin consumer
// Error encountered here can still be Notified to the plugin.
RegisterPlugin(pluginName, endpoint string) error
// DeRegister is called once the pluginwatcher observes that the socket has
// been deleted.
DeRegisterPlugin(pluginName string)
}

View File

@ -28,6 +28,7 @@ import (
"context"
"github.com/golang/glog"
api "k8s.io/api/core/v1"
apierrs "k8s.io/apimachinery/pkg/api/errors"
meta "k8s.io/apimachinery/pkg/apis/meta/v1"
@ -89,6 +90,10 @@ type csiDriversStore struct {
sync.RWMutex
}
// RegistrationHandler is the handler which is fed to the pluginwatcher API.
type RegistrationHandler struct {
}
// TODO (verult) consider using a struct instead of global variables
// csiDrivers map keep track of all registered CSI drivers on the node and their
// corresponding sockets
@ -96,21 +101,28 @@ var csiDrivers csiDriversStore
var nodeUpdater nodeupdater.Interface
// RegistrationCallback is called by kubelet's plugin watcher upon detection
// PluginHandler is the plugin registration handler interface passed to the
// pluginwatcher module in kubelet
var PluginHandler = &RegistrationHandler{}
// ValidatePlugin is called by kubelet's plugin watcher upon detection
// of a new registration socket opened by CSI Driver registrar side car.
func RegistrationCallback(pluginName string, endpoint string, versions []string, socketPath string) (chan bool, error) {
func (h *RegistrationHandler) ValidatePlugin(pluginName string, endpoint string, versions []string) error {
glog.Infof(log("Trying to register a new plugin with name: %s endpoint: %s versions: %s",
pluginName, endpoint, strings.Join(versions, ",")))
glog.Infof(log("Callback from kubelet with plugin name: %s endpoint: %s versions: %s socket path: %s",
pluginName, endpoint, strings.Join(versions, ","), socketPath))
return nil
}
if endpoint == "" {
endpoint = socketPath
}
// RegisterPlugin is called when a plugin can be registered
func (h *RegistrationHandler) RegisterPlugin(pluginName string, endpoint string) error {
glog.Infof(log("Register new plugin with name: %s at endpoint: %s", pluginName, endpoint))
// Storing endpoint of newly registered CSI driver into the map, where CSI driver name will be the key
// all other CSI components will be able to get the actual socket of CSI drivers by its name.
csiDrivers.Lock()
defer csiDrivers.Unlock()
csiDrivers.driversMap[pluginName] = csiDriver{driverName: pluginName, driverEndpoint: endpoint}
// Get node info from the driver.
@ -118,22 +130,27 @@ func RegistrationCallback(pluginName string, endpoint string, versions []string,
// TODO (verult) retry with exponential backoff, possibly added in csi client library.
ctx, cancel := context.WithTimeout(context.Background(), csiTimeout)
defer cancel()
driverNodeID, maxVolumePerNode, _, err := csi.NodeGetInfo(ctx)
if err != nil {
return nil, fmt.Errorf("error during CSI NodeGetInfo() call: %v", err)
return fmt.Errorf("error during CSI NodeGetInfo() call: %v", err)
}
// Calling nodeLabelManager to update annotations and labels for newly registered CSI driver
err = nodeUpdater.AddLabelsAndLimits(pluginName, driverNodeID, maxVolumePerNode)
if err != nil {
// Unregister the driver and return error
csiDrivers.Lock()
defer csiDrivers.Unlock()
delete(csiDrivers.driversMap, pluginName)
return nil, err
return fmt.Errorf("error while adding CSI labels: %v", err)
}
return nil, nil
return nil
}
// DeRegisterPlugin is called when a plugin removed it's socket, signaling
// it is no longer available
// TODO: Handle DeRegistration
func (h *RegistrationHandler) DeRegisterPlugin(pluginName string) {
}
func (p *csiPlugin) Init(host volume.VolumeHost) error {