diff --git a/pkg/kubelet/cm/devicemanager/endpoint.go b/pkg/kubelet/cm/devicemanager/endpoint.go index 0a70884b8de..8ae98d3ffca 100644 --- a/pkg/kubelet/cm/devicemanager/endpoint.go +++ b/pkg/kubelet/cm/devicemanager/endpoint.go @@ -19,60 +19,40 @@ package devicemanager import ( "context" "fmt" - "net" "sync" "time" - "google.golang.org/grpc" - "k8s.io/klog/v2" - pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" + plugin "k8s.io/kubernetes/pkg/kubelet/cm/devicemanager/plugin/v1beta1" ) // endpoint maps to a single registered device plugin. It is responsible // for managing gRPC communications with the device plugin and caching // device states reported by the device plugin. type endpoint interface { - run() - stop() getPreferredAllocation(available, mustInclude []string, size int) (*pluginapi.PreferredAllocationResponse, error) allocate(devs []string) (*pluginapi.AllocateResponse, error) preStartContainer(devs []string) (*pluginapi.PreStartContainerResponse, error) - callback(resourceName string, devices []pluginapi.Device) + setStopTime(t time.Time) isStopped() bool stopGracePeriodExpired() bool } type endpointImpl struct { - client pluginapi.DevicePluginClient - clientConn *grpc.ClientConn - - socketPath string + mutex sync.Mutex resourceName string + api pluginapi.DevicePluginClient stopTime time.Time - - mutex sync.Mutex - cb monitorCallback + client plugin.Client // for testing only } // newEndpointImpl creates a new endpoint for the given resourceName. // This is to be used during normal device plugin registration. -func newEndpointImpl(socketPath, resourceName string, callback monitorCallback) (*endpointImpl, error) { - client, c, err := dial(socketPath) - if err != nil { - klog.ErrorS(err, "Can't create new endpoint with socket path", "path", socketPath) - return nil, err - } - +func newEndpointImpl(p plugin.DevicePlugin) *endpointImpl { return &endpointImpl{ - client: client, - clientConn: c, - - socketPath: socketPath, - resourceName: resourceName, - - cb: callback, - }, nil + api: p.Api(), + resourceName: p.Resource(), + } } // newStoppedEndpointImpl creates a new endpoint for the given resourceName with stopTime set. @@ -84,42 +64,6 @@ func newStoppedEndpointImpl(resourceName string) *endpointImpl { } } -func (e *endpointImpl) callback(resourceName string, devices []pluginapi.Device) { - e.cb(resourceName, devices) -} - -// run initializes ListAndWatch gRPC call for the device plugin and -// blocks on receiving ListAndWatch gRPC stream updates. Each ListAndWatch -// stream update contains a new list of device states. -// It then issues a callback to pass this information to the device manager which -// will adjust the resource available information accordingly. -func (e *endpointImpl) run() { - stream, err := e.client.ListAndWatch(context.Background(), &pluginapi.Empty{}) - if err != nil { - klog.ErrorS(err, "listAndWatch ended unexpectedly for device plugin", "resourceName", e.resourceName) - - return - } - - for { - response, err := stream.Recv() - if err != nil { - klog.ErrorS(err, "listAndWatch ended unexpectedly for device plugin", "resourceName", e.resourceName) - return - } - - devs := response.Devices - klog.V(2).InfoS("State pushed for device plugin", "resourceName", e.resourceName, "resourceCapacity", len(devs)) - - var newDevs []pluginapi.Device - for _, d := range devs { - newDevs = append(newDevs, *d) - } - - e.callback(e.resourceName, newDevs) - } -} - func (e *endpointImpl) isStopped() bool { e.mutex.Lock() defer e.mutex.Unlock() @@ -132,7 +76,6 @@ func (e *endpointImpl) stopGracePeriodExpired() bool { return !e.stopTime.IsZero() && time.Since(e.stopTime) > endpointStopGracePeriod } -// used for testing only func (e *endpointImpl) setStopTime(t time.Time) { e.mutex.Lock() defer e.mutex.Unlock() @@ -144,7 +87,7 @@ func (e *endpointImpl) getPreferredAllocation(available, mustInclude []string, s if e.isStopped() { return nil, fmt.Errorf(errEndpointStopped, e) } - return e.client.GetPreferredAllocation(context.Background(), &pluginapi.PreferredAllocationRequest{ + return e.api.GetPreferredAllocation(context.Background(), &pluginapi.PreferredAllocationRequest{ ContainerRequests: []*pluginapi.ContainerPreferredAllocationRequest{ { AvailableDeviceIDs: available, @@ -160,7 +103,7 @@ func (e *endpointImpl) allocate(devs []string) (*pluginapi.AllocateResponse, err if e.isStopped() { return nil, fmt.Errorf(errEndpointStopped, e) } - return e.client.Allocate(context.Background(), &pluginapi.AllocateRequest{ + return e.api.Allocate(context.Background(), &pluginapi.AllocateRequest{ ContainerRequests: []*pluginapi.ContainerAllocateRequest{ {DevicesIDs: devs}, }, @@ -174,34 +117,7 @@ func (e *endpointImpl) preStartContainer(devs []string) (*pluginapi.PreStartCont } ctx, cancel := context.WithTimeout(context.Background(), pluginapi.KubeletPreStartContainerRPCTimeoutInSecs*time.Second) defer cancel() - return e.client.PreStartContainer(ctx, &pluginapi.PreStartContainerRequest{ + return e.api.PreStartContainer(ctx, &pluginapi.PreStartContainerRequest{ DevicesIDs: devs, }) } - -func (e *endpointImpl) stop() { - e.mutex.Lock() - defer e.mutex.Unlock() - if e.clientConn != nil { - e.clientConn.Close() - } - e.stopTime = time.Now() -} - -// dial establishes the gRPC communication with the registered device plugin. https://godoc.org/google.golang.org/grpc#Dial -func dial(unixSocketPath string) (pluginapi.DevicePluginClient, *grpc.ClientConn, error) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - c, err := grpc.DialContext(ctx, unixSocketPath, grpc.WithInsecure(), grpc.WithBlock(), - grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { - return (&net.Dialer{}).DialContext(ctx, "unix", addr) - }), - ) - - if err != nil { - return nil, nil, fmt.Errorf(errFailedToDialDevicePlugin+" %v", err) - } - - return pluginapi.NewDevicePluginClient(c), c, nil -} diff --git a/pkg/kubelet/cm/devicemanager/endpoint_test.go b/pkg/kubelet/cm/devicemanager/endpoint_test.go index f20550985a1..419002f06c0 100644 --- a/pkg/kubelet/cm/devicemanager/endpoint_test.go +++ b/pkg/kubelet/cm/devicemanager/endpoint_test.go @@ -19,14 +19,53 @@ package devicemanager import ( "fmt" "path" + "sync" "testing" "time" "github.com/stretchr/testify/require" pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" + plugin "k8s.io/kubernetes/pkg/kubelet/cm/devicemanager/plugin/v1beta1" ) +// monitorCallback is the function called when a device's health state changes, +// or new devices are reported, or old devices are deleted. +// Updated contains the most recent state of the Device. +type monitorCallback func(resourceName string, devices []pluginapi.Device) + +func newMockPluginManager() *mockPluginManager { + return &mockPluginManager{ + func(string) error { return nil }, + func(string, plugin.DevicePlugin) error { return nil }, + func(string) {}, + func(string, *pluginapi.ListAndWatchResponse) {}, + } +} + +type mockPluginManager struct { + cleanupPluginDirectory func(string) error + pluginConnected func(string, plugin.DevicePlugin) error + pluginDisconnected func(string) + pluginListAndWatchReceiver func(string, *pluginapi.ListAndWatchResponse) +} + +func (m *mockPluginManager) CleanupPluginDirectory(r string) error { + return m.cleanupPluginDirectory(r) +} + +func (m *mockPluginManager) PluginConnected(r string, p plugin.DevicePlugin) error { + return m.pluginConnected(r, p) +} + +func (m *mockPluginManager) PluginDisconnected(r string) { + m.pluginDisconnected(r) +} + +func (m *mockPluginManager) PluginListAndWatchReceiver(r string, lr *pluginapi.ListAndWatchResponse) { + m.pluginListAndWatchReceiver(r, lr) +} + func esocketName() string { return fmt.Sprintf("mock%d.sock", time.Now().UnixNano()) } @@ -95,7 +134,7 @@ func TestRun(t *testing.T) { p, e := esetup(t, devs, socket, "mock", callback) defer ecleanup(t, p, e) - go e.run() + go e.client.Run() // Wait for the first callback to be issued. <-callbackChan @@ -146,7 +185,7 @@ func TestAllocate(t *testing.T) { return resp, nil }) - go e.run() + go e.client.Run() // Wait for the callback to be issued. select { case <-callbackChan: @@ -180,7 +219,7 @@ func TestGetPreferredAllocation(t *testing.T) { return resp, nil }) - go e.run() + go e.client.Run() // Wait for the callback to be issued. select { case <-callbackChan: @@ -194,19 +233,47 @@ func TestGetPreferredAllocation(t *testing.T) { require.Equal(t, resp, respOut) } -func esetup(t *testing.T, devs []*pluginapi.Device, socket, resourceName string, callback monitorCallback) (*Stub, *endpointImpl) { - p := NewDevicePluginStub(devs, socket, resourceName, false, false) +func esetup(t *testing.T, devs []*pluginapi.Device, socket, resourceName string, callback monitorCallback) (*plugin.Stub, *endpointImpl) { + m := newMockPluginManager() + m.pluginListAndWatchReceiver = func(r string, resp *pluginapi.ListAndWatchResponse) { + var newDevs []pluginapi.Device + for _, d := range resp.Devices { + newDevs = append(newDevs, *d) + } + callback(resourceName, newDevs) + } + + var dp plugin.DevicePlugin + var wg sync.WaitGroup + wg.Add(1) + m.pluginConnected = func(r string, c plugin.DevicePlugin) error { + dp = c + wg.Done() + return nil + } + + p := plugin.NewDevicePluginStub(devs, socket, resourceName, false, false) err := p.Start() require.NoError(t, err) - e, err := newEndpointImpl(socket, resourceName, callback) + c := plugin.NewPluginClient(resourceName, socket, m) + err = c.Connect() require.NoError(t, err) + wg.Wait() + + e := newEndpointImpl(dp) + e.client = c + + m.pluginDisconnected = func(r string) { + e.setStopTime(time.Now()) + } + return p, e } -func ecleanup(t *testing.T, p *Stub, e *endpointImpl) { +func ecleanup(t *testing.T, p *plugin.Stub, e *endpointImpl) { p.Stop() - e.stop() + e.client.Disconnect() } diff --git a/pkg/kubelet/cm/devicemanager/manager.go b/pkg/kubelet/cm/devicemanager/manager.go index b367aa30781..7b0283d93e4 100644 --- a/pkg/kubelet/cm/devicemanager/manager.go +++ b/pkg/kubelet/cm/devicemanager/manager.go @@ -19,7 +19,6 @@ package devicemanager import ( "context" "fmt" - "net" "os" "path/filepath" "runtime" @@ -28,8 +27,6 @@ import ( "time" cadvisorapi "github.com/google/cadvisor/info/v1" - "github.com/opencontainers/selinux/go-selinux" - "google.golang.org/grpc" "k8s.io/klog/v2" v1 "k8s.io/api/core/v1" @@ -38,11 +35,11 @@ import ( "k8s.io/apimachinery/pkg/util/sets" utilfeature "k8s.io/apiserver/pkg/util/feature" pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" - v1helper "k8s.io/kubernetes/pkg/apis/core/v1/helper" "k8s.io/kubernetes/pkg/features" "k8s.io/kubernetes/pkg/kubelet/checkpointmanager" "k8s.io/kubernetes/pkg/kubelet/checkpointmanager/errors" "k8s.io/kubernetes/pkg/kubelet/cm/devicemanager/checkpoint" + plugin "k8s.io/kubernetes/pkg/kubelet/cm/devicemanager/plugin/v1beta1" "k8s.io/kubernetes/pkg/kubelet/cm/topologymanager" "k8s.io/kubernetes/pkg/kubelet/config" "k8s.io/kubernetes/pkg/kubelet/lifecycle" @@ -56,21 +53,14 @@ const nodeWithoutTopology = -1 // ActivePodsFunc is a function that returns a list of pods to reconcile. type ActivePodsFunc func() []*v1.Pod -// monitorCallback is the function called when a device's health state changes, -// or new devices are reported, or old devices are deleted. -// Updated contains the most recent state of the Device. -type monitorCallback func(resourceName string, devices []pluginapi.Device) - // ManagerImpl is the structure in charge of managing Device Plugins. type ManagerImpl struct { - socketname string - socketdir string + checkpointdir string endpoints map[string]endpointInfo // Key is ResourceName mutex sync.Mutex - server *grpc.Server - wg sync.WaitGroup + server plugin.Server // activePods is a method for listing active pods on the node // so the amount of pluginResources requested by existing pods @@ -81,10 +71,6 @@ type ManagerImpl struct { // We use it to determine when we can purge inactive pods from checkpointed state. sourcesReady config.SourcesReady - // callback is used for updating devices' states in one time call. - // e.g. a new device is advertised, two old devices are deleted and a running device fails. - callback monitorCallback - // allDevices holds all the devices currently registered to the device manager allDevices ResourceDeviceInstances @@ -140,21 +126,14 @@ func NewManagerImpl(topology []cadvisorapi.Node, topologyAffinityStore topologym func newManagerImpl(socketPath string, topology []cadvisorapi.Node, topologyAffinityStore topologymanager.Store) (*ManagerImpl, error) { klog.V(2).InfoS("Creating Device Plugin manager", "path", socketPath) - if socketPath == "" || !filepath.IsAbs(socketPath) { - return nil, fmt.Errorf(errBadSocket+" %s", socketPath) - } - var numaNodes []int for _, node := range topology { numaNodes = append(numaNodes, node.Id) } - dir, file := filepath.Split(socketPath) manager := &ManagerImpl{ endpoints: make(map[string]endpointInfo), - socketname: file, - socketdir: dir, allDevices: NewResourceDeviceInstances(), healthyDevices: make(map[string]sets.String), unhealthyDevices: make(map[string]sets.String), @@ -164,13 +143,20 @@ func newManagerImpl(socketPath string, topology []cadvisorapi.Node, topologyAffi topologyAffinityStore: topologyAffinityStore, devicesToReuse: make(PodReusableDevices), } - manager.callback = manager.genericDeviceUpdateCallback + + server, err := plugin.NewServer(socketPath, manager, manager) + if err != nil { + return nil, fmt.Errorf("failed to create plugin server: %v", err) + } + + manager.server = server + manager.checkpointdir, _ = filepath.Split(server.SocketPath()) // The following structures are populated with real implementations in manager.Start() // Before that, initializes them to perform no-op operations. manager.activePods = func() []*v1.Pod { return []*v1.Pod{} } manager.sourcesReady = &sourcesReadyStub{} - checkpointManager, err := checkpointmanager.NewCheckpointManager(dir) + checkpointManager, err := checkpointmanager.NewCheckpointManager(manager.checkpointdir) if err != nil { return nil, fmt.Errorf("failed to initialize checkpoint manager: %v", err) } @@ -179,26 +165,7 @@ func newManagerImpl(socketPath string, topology []cadvisorapi.Node, topologyAffi return manager, nil } -func (m *ManagerImpl) genericDeviceUpdateCallback(resourceName string, devices []pluginapi.Device) { - m.mutex.Lock() - m.healthyDevices[resourceName] = sets.NewString() - m.unhealthyDevices[resourceName] = sets.NewString() - m.allDevices[resourceName] = make(map[string]pluginapi.Device) - for _, dev := range devices { - m.allDevices[resourceName][dev.ID] = dev - if dev.Health == pluginapi.Healthy { - m.healthyDevices[resourceName].Insert(dev.ID) - } else { - m.unhealthyDevices[resourceName].Insert(dev.ID) - } - } - m.mutex.Unlock() - if err := m.writeCheckpoint(); err != nil { - klog.ErrorS(err, "Writing checkpoint encountered") - } -} - -func (m *ManagerImpl) removeContents(dir string) error { +func (m *ManagerImpl) CleanupPluginDirectory(dir string) error { d, err := os.Open(dir) if err != nil { return err @@ -235,9 +202,68 @@ func (m *ManagerImpl) removeContents(dir string) error { return errorsutil.NewAggregate(errs) } +func (m *ManagerImpl) PluginConnected(resourceName string, p plugin.DevicePlugin) error { + options, err := p.Api().GetDevicePluginOptions(context.Background(), &pluginapi.Empty{}) + if err != nil { + return fmt.Errorf("failed to get device plugin options: %v", err) + } + + e := newEndpointImpl(p) + + m.mutex.Lock() + defer m.mutex.Unlock() + m.endpoints[resourceName] = endpointInfo{e, options} + + return nil +} + +func (m *ManagerImpl) PluginDisconnected(resourceName string) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if _, exists := m.endpoints[resourceName]; exists { + m.markResourceUnhealthy(resourceName) + klog.V(2).InfoS("Endpoint became unhealthy", "resourceName", resourceName, "endpoint", m.endpoints[resourceName]) + } + + m.endpoints[resourceName].e.setStopTime(time.Now()) +} + +func (m *ManagerImpl) PluginListAndWatchReceiver(resourceName string, resp *pluginapi.ListAndWatchResponse) { + var devices []pluginapi.Device + for _, d := range resp.Devices { + devices = append(devices, *d) + } + m.genericDeviceUpdateCallback(resourceName, devices) +} + +func (m *ManagerImpl) genericDeviceUpdateCallback(resourceName string, devices []pluginapi.Device) { + m.mutex.Lock() + m.healthyDevices[resourceName] = sets.NewString() + m.unhealthyDevices[resourceName] = sets.NewString() + m.allDevices[resourceName] = make(map[string]pluginapi.Device) + for _, dev := range devices { + m.allDevices[resourceName][dev.ID] = dev + if dev.Health == pluginapi.Healthy { + m.healthyDevices[resourceName].Insert(dev.ID) + } else { + m.unhealthyDevices[resourceName].Insert(dev.ID) + } + } + m.mutex.Unlock() + if err := m.writeCheckpoint(); err != nil { + klog.ErrorS(err, "Writing checkpoint encountered") + } +} + +// GetWatcherHandler returns the plugin handler +func (m *ManagerImpl) GetWatcherHandler() cache.PluginHandler { + return m.server +} + // checkpointFile returns device plugin checkpoint file path. func (m *ManagerImpl) checkpointFile() string { - return filepath.Join(m.socketdir, kubeletDeviceManagerCheckpoint) + return filepath.Join(m.checkpointdir, kubeletDeviceManagerCheckpoint) } // Start starts the Device Plugin Manager and start initialization of @@ -255,118 +281,14 @@ func (m *ManagerImpl) Start(activePods ActivePodsFunc, sourcesReady config.Sourc klog.InfoS("Continue after failing to read checkpoint file. Device allocation info may NOT be up-to-date", "err", err) } - socketPath := filepath.Join(m.socketdir, m.socketname) - if err = os.MkdirAll(m.socketdir, 0750); err != nil { - return err - } - if selinux.GetEnabled() { - if err := selinux.SetFileLabel(m.socketdir, config.KubeletPluginsDirSELinuxLabel); err != nil { - klog.InfoS("Unprivileged containerized plugins might not work. Could not set selinux context on socket dir", "path", m.socketdir, "err", err) - } - } - - // Removes all stale sockets in m.socketdir. Device plugins can monitor - // this and use it as a signal to re-register with the new Kubelet. - if err := m.removeContents(m.socketdir); err != nil { - klog.ErrorS(err, "Fail to clean up stale content under socket dir", "path", m.socketdir) - } - - s, err := net.Listen("unix", socketPath) - if err != nil { - klog.ErrorS(err, "Failed to listen to socket while starting device plugin registry") - return err - } - - m.wg.Add(1) - m.server = grpc.NewServer([]grpc.ServerOption{}...) - - pluginapi.RegisterRegistrationServer(m.server, m) - go func() { - defer m.wg.Done() - m.server.Serve(s) - }() - - klog.V(2).InfoS("Serving device plugin registration server on socket", "path", socketPath) - - return nil + return m.server.Start() } -// GetWatcherHandler returns the plugin handler -func (m *ManagerImpl) GetWatcherHandler() cache.PluginHandler { - if f, err := os.Create(m.socketdir + "DEPRECATION"); err != nil { - klog.ErrorS(err, "Failed to create deprecation file at socket dir", "path", m.socketdir) - } else { - f.Close() - klog.V(4).InfoS("Created deprecation file", "path", f.Name()) - } - - return cache.PluginHandler(m) -} - -// ValidatePlugin validates a plugin if the version is correct and the name has the format of an extended resource -func (m *ManagerImpl) ValidatePlugin(pluginName string, endpoint string, versions []string) error { - klog.V(2).InfoS("Got Plugin at endpoint with versions", "plugin", pluginName, "endpoint", endpoint, "versions", versions) - - if !m.isVersionCompatibleWithPlugin(versions) { - return fmt.Errorf("manager version, %s, is not among plugin supported versions %v", pluginapi.Version, versions) - } - - if !v1helper.IsExtendedResourceName(v1.ResourceName(pluginName)) { - return fmt.Errorf("invalid name of device plugin socket: %s", fmt.Sprintf(errInvalidResourceName, pluginName)) - } - - return nil -} - -// RegisterPlugin starts the endpoint and registers it -// TODO: Start the endpoint and wait for the First ListAndWatch call -// before registering the plugin -func (m *ManagerImpl) RegisterPlugin(pluginName string, endpoint string, versions []string) error { - klog.V(2).InfoS("Registering plugin at endpoint", "plugin", pluginName, "endpoint", endpoint) - - e, err := newEndpointImpl(endpoint, pluginName, m.callback) - if err != nil { - return fmt.Errorf("failed to dial device plugin with socketPath %s: %v", endpoint, err) - } - - options, err := e.client.GetDevicePluginOptions(context.Background(), &pluginapi.Empty{}) - if err != nil { - return fmt.Errorf("failed to get device plugin options: %v", err) - } - - m.registerEndpoint(pluginName, options, e) - go m.runEndpoint(pluginName, e) - - return nil -} - -// DeRegisterPlugin deregisters the plugin -// TODO work on the behavior for deregistering plugins -// e.g: Should we delete the resource -func (m *ManagerImpl) DeRegisterPlugin(pluginName string) { - m.mutex.Lock() - defer m.mutex.Unlock() - - // Note: This will mark the resource unhealthy as per the behavior - // in runEndpoint - if eI, ok := m.endpoints[pluginName]; ok { - eI.e.stop() - } -} - -func (m *ManagerImpl) isVersionCompatibleWithPlugin(versions []string) bool { - // TODO(vikasc): Currently this is fine as we only have a single supported version. When we do need to support - // multiple versions in the future, we may need to extend this function to return a supported version. - // E.g., say kubelet supports v1beta1 and v1beta2, and we get v1alpha1 and v1beta1 from a device plugin, - // this function should return v1beta1 - for _, version := range versions { - for _, supportedVersion := range pluginapi.SupportedVersions { - if version == supportedVersion { - return true - } - } - } - return false +// Stop is the function that can stop the plugin server. +// Can be called concurrently, more than once, and is safe to call +// without a prior Start. +func (m *ManagerImpl) Stop() error { + return m.server.Stop() } // Allocate is the call that you can use to allocate a set of devices @@ -417,91 +339,6 @@ func (m *ManagerImpl) UpdatePluginResources(node *schedulerframework.NodeInfo, a return nil } -// Register registers a device plugin. -func (m *ManagerImpl) Register(ctx context.Context, r *pluginapi.RegisterRequest) (*pluginapi.Empty, error) { - klog.InfoS("Got registration request from device plugin with resource", "resourceName", r.ResourceName) - metrics.DevicePluginRegistrationCount.WithLabelValues(r.ResourceName).Inc() - var versionCompatible bool - for _, v := range pluginapi.SupportedVersions { - if r.Version == v { - versionCompatible = true - break - } - } - if !versionCompatible { - err := fmt.Errorf(errUnsupportedVersion, r.Version, pluginapi.SupportedVersions) - klog.InfoS("Bad registration request from device plugin with resource", "resourceName", r.ResourceName, "err", err) - return &pluginapi.Empty{}, err - } - - if !v1helper.IsExtendedResourceName(v1.ResourceName(r.ResourceName)) { - err := fmt.Errorf(errInvalidResourceName, r.ResourceName) - klog.InfoS("Bad registration request from device plugin", "err", err) - return &pluginapi.Empty{}, err - } - - // TODO: for now, always accepts newest device plugin. Later may consider to - // add some policies here, e.g., verify whether an old device plugin with the - // same resource name is still alive to determine whether we want to accept - // the new registration. - go m.addEndpoint(r) - - return &pluginapi.Empty{}, nil -} - -// Stop is the function that can stop the gRPC server. -// Can be called concurrently, more than once, and is safe to call -// without a prior Start. -func (m *ManagerImpl) Stop() error { - m.mutex.Lock() - defer m.mutex.Unlock() - for _, eI := range m.endpoints { - eI.e.stop() - } - - if m.server == nil { - return nil - } - m.server.Stop() - m.wg.Wait() - m.server = nil - return nil -} - -func (m *ManagerImpl) registerEndpoint(resourceName string, options *pluginapi.DevicePluginOptions, e endpoint) { - m.mutex.Lock() - defer m.mutex.Unlock() - - m.endpoints[resourceName] = endpointInfo{e: e, opts: options} - klog.V(2).InfoS("Registered endpoint", "endpoint", e) -} - -func (m *ManagerImpl) runEndpoint(resourceName string, e endpoint) { - e.run() - e.stop() - - m.mutex.Lock() - defer m.mutex.Unlock() - - if old, ok := m.endpoints[resourceName]; ok && old.e == e { - m.markResourceUnhealthy(resourceName) - } - - klog.V(2).InfoS("Endpoint became unhealthy", "resourceName", resourceName, "endpoint", e) -} - -func (m *ManagerImpl) addEndpoint(r *pluginapi.RegisterRequest) { - new, err := newEndpointImpl(filepath.Join(m.socketdir, r.Endpoint), r.ResourceName, m.callback) - if err != nil { - klog.ErrorS(err, "Failed to dial device plugin with request", "request", r) - return - } - m.registerEndpoint(r.ResourceName, r.Options, new) - go func() { - m.runEndpoint(r.ResourceName, new) - }() -} - func (m *ManagerImpl) markResourceUnhealthy(resourceName string) { klog.V(2).InfoS("Mark all resources Unhealthy for resource", "resourceName", resourceName) healthyDevices := sets.NewString() diff --git a/pkg/kubelet/cm/devicemanager/manager_test.go b/pkg/kubelet/cm/devicemanager/manager_test.go index 0345a67f8c8..22c1c6ec5af 100644 --- a/pkg/kubelet/cm/devicemanager/manager_test.go +++ b/pkg/kubelet/cm/devicemanager/manager_test.go @@ -40,6 +40,7 @@ import ( watcherapi "k8s.io/kubelet/pkg/apis/pluginregistration/v1" "k8s.io/kubernetes/pkg/kubelet/checkpointmanager" "k8s.io/kubernetes/pkg/kubelet/cm/devicemanager/checkpoint" + plugin "k8s.io/kubernetes/pkg/kubelet/cm/devicemanager/plugin/v1beta1" "k8s.io/kubernetes/pkg/kubelet/cm/topologymanager" "k8s.io/kubernetes/pkg/kubelet/cm/topologymanager/bitmask" "k8s.io/kubernetes/pkg/kubelet/config" @@ -52,6 +53,30 @@ const ( testResourceName = "fake-domain/resource" ) +func newWrappedManagerImpl(socketPath string, manager *ManagerImpl) *wrappedManagerImpl { + w := &wrappedManagerImpl{ + ManagerImpl: manager, + callback: manager.genericDeviceUpdateCallback, + } + w.socketdir, _ = filepath.Split(socketPath) + w.server, _ = plugin.NewServer(socketPath, w, w) + return w +} + +type wrappedManagerImpl struct { + *ManagerImpl + socketdir string + callback func(string, []pluginapi.Device) +} + +func (m *wrappedManagerImpl) PluginListAndWatchReceiver(r string, resp *pluginapi.ListAndWatchResponse) { + var devices []pluginapi.Device + for _, d := range resp.Devices { + devices = append(devices, *d) + } + m.callback(r, devices) +} + func tmpSocketDir() (socketDir, socketName, pluginSocketName string, err error) { socketDir, err = ioutil.TempDir("", "device_plugin") if err != nil { @@ -121,7 +146,7 @@ func TestDevicePluginReRegistration(t *testing.T) { require.Equal(t, resourceCapacity.Value(), resourceAllocatable.Value(), "capacity should equal to allocatable") require.Equal(t, int64(2), resourceAllocatable.Value(), "Devices are not updated.") - p2 := NewDevicePluginStub(devs, pluginSocketName+".new", testResourceName, preStartContainerFlag, getPreferredAllocationFlag) + p2 := plugin.NewDevicePluginStub(devs, pluginSocketName+".new", testResourceName, preStartContainerFlag, getPreferredAllocationFlag) err = p2.Start() require.NoError(t, err) p2.Register(socketName, testResourceName, "") @@ -138,7 +163,7 @@ func TestDevicePluginReRegistration(t *testing.T) { require.Equal(t, int64(2), resourceAllocatable.Value(), "Devices shouldn't change.") // Test the scenario that a plugin re-registers with different devices. - p3 := NewDevicePluginStub(devsForRegistration, pluginSocketName+".third", testResourceName, preStartContainerFlag, getPreferredAllocationFlag) + p3 := plugin.NewDevicePluginStub(devsForRegistration, pluginSocketName+".third", testResourceName, preStartContainerFlag, getPreferredAllocationFlag) err = p3.Start() require.NoError(t, err) p3.Register(socketName, testResourceName, "") @@ -191,7 +216,7 @@ func TestDevicePluginReRegistrationProbeMode(t *testing.T) { require.Equal(t, resourceCapacity.Value(), resourceAllocatable.Value(), "capacity should equal to allocatable") require.Equal(t, int64(2), resourceAllocatable.Value(), "Devices are not updated.") - p2 := NewDevicePluginStub(devs, pluginSocketName+".new", testResourceName, false, false) + p2 := plugin.NewDevicePluginStub(devs, pluginSocketName+".new", testResourceName, false, false) err = p2.Start() require.NoError(t, err) // Wait for the second callback to be issued. @@ -208,7 +233,7 @@ func TestDevicePluginReRegistrationProbeMode(t *testing.T) { require.Equal(t, int64(2), resourceAllocatable.Value(), "Devices are not updated.") // Test the scenario that a plugin re-registers with different devices. - p3 := NewDevicePluginStub(devsForRegistration, pluginSocketName+".third", testResourceName, false, false) + p3 := plugin.NewDevicePluginStub(devsForRegistration, pluginSocketName+".third", testResourceName, false, false) err = p3.Start() require.NoError(t, err) // Wait for the third callback to be issued. @@ -234,12 +259,13 @@ func setupDeviceManager(t *testing.T, devs []*pluginapi.Device, callback monitor require.NoError(t, err) updateChan := make(chan interface{}) + w := newWrappedManagerImpl(socketName, m) if callback != nil { - m.callback = callback + w.callback = callback } - originalCallback := m.callback - m.callback = func(resourceName string, devices []pluginapi.Device) { + originalCallback := w.callback + w.callback = func(resourceName string, devices []pluginapi.Device) { originalCallback(resourceName, devices) updateChan <- new(interface{}) } @@ -247,14 +273,14 @@ func setupDeviceManager(t *testing.T, devs []*pluginapi.Device, callback monitor return []*v1.Pod{} } - err = m.Start(activePods, &sourcesReadyStub{}) + err = w.Start(activePods, &sourcesReadyStub{}) require.NoError(t, err) - return m, updateChan + return w, updateChan } -func setupDevicePlugin(t *testing.T, devs []*pluginapi.Device, pluginSocketName string) *Stub { - p := NewDevicePluginStub(devs, pluginSocketName, testResourceName, false, false) +func setupDevicePlugin(t *testing.T, devs []*pluginapi.Device, pluginSocketName string) *plugin.Stub { + p := plugin.NewDevicePluginStub(devs, pluginSocketName, testResourceName, false, false) err := p.Start() require.NoError(t, err) return p @@ -276,20 +302,20 @@ func runPluginManager(pluginManager pluginmanager.PluginManager) { go pluginManager.Run(sourcesReady, wait.NeverStop) } -func setup(t *testing.T, devs []*pluginapi.Device, callback monitorCallback, socketName string, pluginSocketName string) (Manager, <-chan interface{}, *Stub) { +func setup(t *testing.T, devs []*pluginapi.Device, callback monitorCallback, socketName string, pluginSocketName string) (Manager, <-chan interface{}, *plugin.Stub) { m, updateChan := setupDeviceManager(t, devs, callback, socketName) p := setupDevicePlugin(t, devs, pluginSocketName) return m, updateChan, p } -func setupInProbeMode(t *testing.T, devs []*pluginapi.Device, callback monitorCallback, socketName string, pluginSocketName string) (Manager, <-chan interface{}, *Stub, pluginmanager.PluginManager) { +func setupInProbeMode(t *testing.T, devs []*pluginapi.Device, callback monitorCallback, socketName string, pluginSocketName string) (Manager, <-chan interface{}, *plugin.Stub, pluginmanager.PluginManager) { m, updateChan := setupDeviceManager(t, devs, callback, socketName) p := setupDevicePlugin(t, devs, pluginSocketName) pm := setupPluginManager(t, pluginSocketName, m) return m, updateChan, p, pm } -func cleanup(t *testing.T, m Manager, p *Stub) { +func cleanup(t *testing.T, m Manager, p *plugin.Stub) { p.Stop() m.Stop() } @@ -365,6 +391,7 @@ func TestUpdateCapacityAllocatable(t *testing.T) { // Tests adding another resource. resourceName2 := "resource2" e2 := &endpointImpl{} + e2.client = plugin.NewPluginClient(resourceName2, socketName, testManager) testManager.endpoints[resourceName2] = endpointInfo{e: e2, opts: nil} callback(resourceName2, devs) capacity, allocatable, removedResources = testManager.GetCapacity() @@ -394,7 +421,7 @@ func TestUpdateCapacityAllocatable(t *testing.T) { // Stops resourceName2 endpoint. Verifies its stopTime is set, allocate and // preStartContainer calls return errors. - e2.stop() + e2.client.Disconnect() as.False(e2.stopTime.IsZero()) _, err = e2.allocate([]string{"Device1"}) reflect.DeepEqual(err, fmt.Errorf(errEndpointStopped, e2)) @@ -661,11 +688,6 @@ type MockEndpoint struct { initChan chan []string } -func (m *MockEndpoint) stop() {} -func (m *MockEndpoint) run() {} - -func (m *MockEndpoint) callback(resourceName string, devices []pluginapi.Device) {} - func (m *MockEndpoint) preStartContainer(devs []string) (*pluginapi.PreStartContainerResponse, error) { m.initChan <- devs return &pluginapi.PreStartContainerResponse{}, nil @@ -685,6 +707,8 @@ func (m *MockEndpoint) allocate(devs []string) (*pluginapi.AllocateResponse, err return nil, nil } +func (m *MockEndpoint) setStopTime(t time.Time) {} + func (m *MockEndpoint) isStopped() bool { return false } func (m *MockEndpoint) stopGracePeriodExpired() bool { return false } @@ -706,15 +730,13 @@ func makePod(limits v1.ResourceList) *v1.Pod { } } -func getTestManager(tmpDir string, activePods ActivePodsFunc, testRes []TestResource) (*ManagerImpl, error) { +func getTestManager(tmpDir string, activePods ActivePodsFunc, testRes []TestResource) (*wrappedManagerImpl, error) { monitorCallback := func(resourceName string, devices []pluginapi.Device) {} ckm, err := checkpointmanager.NewCheckpointManager(tmpDir) if err != nil { return nil, err } - testManager := &ManagerImpl{ - socketdir: tmpDir, - callback: monitorCallback, + m := &ManagerImpl{ healthyDevices: make(map[string]sets.String), unhealthyDevices: make(map[string]sets.String), allocatedDevices: make(map[string]sets.String), @@ -727,6 +749,11 @@ func getTestManager(tmpDir string, activePods ActivePodsFunc, testRes []TestReso checkpointManager: ckm, allDevices: NewResourceDeviceInstances(), } + testManager := &wrappedManagerImpl{ + ManagerImpl: m, + socketdir: tmpDir, + callback: monitorCallback, + } for _, res := range testRes { testManager.healthyDevices[res.resourceName] = sets.NewString(res.devs.Devices().UnsortedList()...) @@ -1141,13 +1168,16 @@ func TestUpdatePluginResources(t *testing.T) { ckm, err := checkpointmanager.NewCheckpointManager(tmpDir) as.Nil(err) - testManager := &ManagerImpl{ - callback: monitorCallback, + m := &ManagerImpl{ allocatedDevices: make(map[string]sets.String), healthyDevices: make(map[string]sets.String), podDevices: newPodDevices(), checkpointManager: ckm, } + testManager := wrappedManagerImpl{ + ManagerImpl: m, + callback: monitorCallback, + } testManager.podDevices.devs[string(pod.UID)] = make(containerDevices) // require one of resource1 and one of resource2 diff --git a/pkg/kubelet/cm/devicemanager/plugin/v1beta1/api.go b/pkg/kubelet/cm/devicemanager/plugin/v1beta1/api.go new file mode 100644 index 00000000000..b869183786a --- /dev/null +++ b/pkg/kubelet/cm/devicemanager/plugin/v1beta1/api.go @@ -0,0 +1,46 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package v1beta1 + +import ( + api "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" +) + +type RegistrationHandler interface { + CleanupPluginDirectory(string) error +} + +type ClientHandler interface { + PluginConnected(string, DevicePlugin) error + PluginDisconnected(string) + PluginListAndWatchReceiver(string, *api.ListAndWatchResponse) +} + +// TODO: evaluate whether we need these error definitions. +const ( + // errFailedToDialDevicePlugin is the error raised when the device plugin could not be + // reached on the registered socket + errFailedToDialDevicePlugin = "failed to dial device plugin:" + // errUnsupportedVersion is the error raised when the device plugin uses an API version not + // supported by the Kubelet registry + errUnsupportedVersion = "requested API version %q is not supported by kubelet. Supported version is %q" + // errInvalidResourceName is the error raised when a device plugin is registering + // itself with an invalid ResourceName + errInvalidResourceName = "the ResourceName %q is invalid" + // errBadSocket is the error raised when the registry socket path is not absolute + errBadSocket = "bad socketPath, must be an absolute path:" +) diff --git a/pkg/kubelet/cm/devicemanager/plugin/v1beta1/client.go b/pkg/kubelet/cm/devicemanager/plugin/v1beta1/client.go new file mode 100644 index 00000000000..13b1249009b --- /dev/null +++ b/pkg/kubelet/cm/devicemanager/plugin/v1beta1/client.go @@ -0,0 +1,131 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package v1beta1 + +import ( + "context" + "fmt" + "net" + "sync" + "time" + + "google.golang.org/grpc" + + "k8s.io/klog/v2" + api "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" +) + +type DevicePlugin interface { + Api() api.DevicePluginClient + Resource() string + SocketPath() string +} + +type Client interface { + Connect() error + Run() + Disconnect() error +} + +type client struct { + mutex sync.Mutex + resource string + socket string + grpc *grpc.ClientConn + handler ClientHandler + client api.DevicePluginClient +} + +func NewPluginClient(r string, socketPath string, h ClientHandler) Client { + return &client{ + resource: r, + socket: socketPath, + handler: h, + } +} + +func (c *client) Connect() error { + client, conn, err := dial(c.socket) + if err != nil { + klog.ErrorS(err, "Unable to connect to device plugin client with socket path", "path", c.socket) + return err + } + c.grpc = conn + c.client = client + return c.handler.PluginConnected(c.resource, c) +} + +func (c *client) Run() { + stream, err := c.client.ListAndWatch(context.Background(), &api.Empty{}) + if err != nil { + klog.ErrorS(err, "ListAndWatch ended unexpectedly for device plugin", "resource", c.resource) + return + } + + for { + response, err := stream.Recv() + if err != nil { + klog.ErrorS(err, "ListAndWatch ended unexpectedly for device plugin", "resource", c.resource) + return + } + klog.V(2).InfoS("State pushed for device plugin", "resource", c.resource, "resourceCapacity", len(response.Devices)) + c.handler.PluginListAndWatchReceiver(c.resource, response) + } +} + +func (c *client) Disconnect() error { + c.mutex.Lock() + if c.grpc != nil { + if err := c.grpc.Close(); err != nil { + klog.V(2).ErrorS(err, "Failed to close grcp connection", "resource", c.Resource()) + } + c.grpc = nil + } + c.mutex.Unlock() + c.handler.PluginDisconnected(c.resource) + return nil +} + +func (c *client) Resource() string { + return c.resource +} + +func (c *client) Api() api.DevicePluginClient { + return c.client +} + +func (c *client) SocketPath() string { + return c.socket +} + +// dial establishes the gRPC communication with the registered device plugin. https://godoc.org/google.golang.org/grpc#Dial +func dial(unixSocketPath string) (api.DevicePluginClient, *grpc.ClientConn, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + c, err := grpc.DialContext(ctx, unixSocketPath, grpc.WithInsecure(), grpc.WithBlock(), + grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + return (&net.Dialer{}).DialContext(ctx, "unix", addr) + }), + ) + + if err != nil { + return nil, nil, fmt.Errorf(errFailedToDialDevicePlugin+" %v", err) + } + + return api.NewDevicePluginClient(c), c, nil +} diff --git a/pkg/kubelet/cm/devicemanager/plugin/v1beta1/handler.go b/pkg/kubelet/cm/devicemanager/plugin/v1beta1/handler.go new file mode 100644 index 00000000000..cea6317257e --- /dev/null +++ b/pkg/kubelet/cm/devicemanager/plugin/v1beta1/handler.go @@ -0,0 +1,120 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package v1beta1 + +import ( + "fmt" + "os" + + core "k8s.io/api/core/v1" + "k8s.io/klog/v2" + api "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" + v1helper "k8s.io/kubernetes/pkg/apis/core/v1/helper" + "k8s.io/kubernetes/pkg/kubelet/pluginmanager/cache" +) + +func (s *server) GetPluginHandler() cache.PluginHandler { + if f, err := os.Create(s.socketDir + "DEPRECATION"); err != nil { + klog.ErrorS(err, "Failed to create deprecation file at socket dir", "path", s.socketDir) + } else { + f.Close() + klog.V(4).InfoS("Created deprecation file", "path", f.Name()) + } + return s +} + +func (s *server) RegisterPlugin(pluginName string, endpoint string, versions []string) error { + klog.V(2).InfoS("Registering plugin at endpoint", "plugin", pluginName, "endpoint", endpoint) + return s.connectClient(pluginName, endpoint) +} + +func (s *server) DeRegisterPlugin(pluginName string) { + klog.V(2).InfoS("Deregistering plugin", "plugin", pluginName) + s.mutex.Lock() + defer s.mutex.Unlock() + if _, exists := s.clients[pluginName]; exists { + s.disconnectClient(pluginName) + } +} + +func (s *server) ValidatePlugin(pluginName string, endpoint string, versions []string) error { + klog.V(2).InfoS("Got plugin at endpoint with versions", "plugin", pluginName, "endpoint", endpoint, "versions", versions) + + if !s.isVersionCompatibleWithPlugin(versions...) { + return fmt.Errorf("manager version, %s, is not among plugin supported versions %v", api.Version, versions) + } + + if !v1helper.IsExtendedResourceName(core.ResourceName(pluginName)) { + return fmt.Errorf("invalid name of device plugin socket: %s", fmt.Sprintf(errInvalidResourceName, pluginName)) + } + + return nil +} + +func (s *server) connectClient(name string, socketPath string) error { + c := NewPluginClient(name, socketPath, s.chandler) + + s.registerClient(name, c) + if err := c.Connect(); err != nil { + s.deregisterClient(name) + klog.ErrorS(err, "Failed to connect to new client", "resource", name) + return err + } + + go func() { + s.runClient(name, c) + }() + + return nil +} + +func (s *server) disconnectClient(name string) error { + c := s.clients[name] + s.deregisterClient(name) + return c.Disconnect() +} + +func (s *server) registerClient(name string, c Client) { + s.mutex.Lock() + defer s.mutex.Unlock() + + s.clients[name] = c + klog.V(2).InfoS("Registered client", "name", name) +} + +func (s *server) deregisterClient(name string) { + s.mutex.Lock() + defer s.mutex.Unlock() + + delete(s.clients, name) + klog.V(2).InfoS("Deregistered client", "name", name) +} + +func (s *server) runClient(name string, c Client) { + c.Run() + + s.mutex.Lock() + if _, exists := s.clients[name]; !exists { + s.mutex.Unlock() + return + } + s.mutex.Unlock() + + if err := s.disconnectClient(name); err != nil { + klog.V(2).InfoS("Unable to disconnect client", "resource", name, "client", c, "err", err) + } +} diff --git a/pkg/kubelet/cm/devicemanager/plugin/v1beta1/server.go b/pkg/kubelet/cm/devicemanager/plugin/v1beta1/server.go new file mode 100644 index 00000000000..b848858d09f --- /dev/null +++ b/pkg/kubelet/cm/devicemanager/plugin/v1beta1/server.go @@ -0,0 +1,188 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package v1beta1 + +import ( + "context" + "fmt" + "net" + "os" + "path/filepath" + "sync" + + "github.com/opencontainers/selinux/go-selinux" + "google.golang.org/grpc" + + core "k8s.io/api/core/v1" + "k8s.io/klog/v2" + api "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" + v1helper "k8s.io/kubernetes/pkg/apis/core/v1/helper" + "k8s.io/kubernetes/pkg/kubelet/config" + "k8s.io/kubernetes/pkg/kubelet/metrics" + "k8s.io/kubernetes/pkg/kubelet/pluginmanager/cache" +) + +type Server interface { + cache.PluginHandler + Start() error + Stop() error + SocketPath() string +} + +type server struct { + socketName string + socketDir string + mutex sync.Mutex + wg sync.WaitGroup + grpc *grpc.Server + rhandler RegistrationHandler + chandler ClientHandler + clients map[string]Client +} + +func NewServer(socketPath string, rh RegistrationHandler, ch ClientHandler) (Server, error) { + if socketPath == "" || !filepath.IsAbs(socketPath) { + return nil, fmt.Errorf(errBadSocket+" %s", socketPath) + } + + dir, name := filepath.Split(socketPath) + + klog.V(2).InfoS("Creating device plugin registration server", "version", api.Version, "socket", socketPath) + s := &server{ + socketName: name, + socketDir: dir, + rhandler: rh, + chandler: ch, + clients: make(map[string]Client), + } + + return s, nil +} + +func (s *server) Start() error { + klog.V(2).InfoS("Starting device plugin registration server") + + if err := os.MkdirAll(s.socketDir, 0750); err != nil { + klog.ErrorS(err, "Failed to create the device plugin socket directory", "directory", s.socketDir) + return err + } + + if selinux.GetEnabled() { + if err := selinux.SetFileLabel(s.socketDir, config.KubeletPluginsDirSELinuxLabel); err != nil { + klog.InfoS("Unprivileged containerized plugins might not work. Could not set selinux context on socket dir", "path", s.socketDir, "err", err) + } + } + + // For now we leave cleanup of the *entire* directory up to the Handler + // (even though we should in theory be able to just wipe the whole directory) + // because the Handler stores its checkpoint file (amongst others) in here. + if err := s.rhandler.CleanupPluginDirectory(s.socketDir); err != nil { + klog.ErrorS(err, "Failed to cleanup the device plugin directory", "directory", s.socketDir) + return err + } + + ln, err := net.Listen("unix", s.SocketPath()) + if err != nil { + klog.ErrorS(err, "Failed to listen to socket while starting device plugin registry") + return err + } + + s.wg.Add(1) + s.grpc = grpc.NewServer([]grpc.ServerOption{}...) + + api.RegisterRegistrationServer(s.grpc, s) + go func() { + defer s.wg.Done() + s.grpc.Serve(ln) + }() + + return nil +} + +func (s *server) Stop() error { + for _, r := range s.clientResources() { + if err := s.disconnectClient(r); err != nil { + klog.InfoS("Error disconnecting device plugin client", "resourceName", r, "err", err) + } + } + + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.grpc == nil { + return nil + } + + s.grpc.Stop() + s.wg.Wait() + s.grpc = nil + + return nil +} + +func (s *server) SocketPath() string { + return filepath.Join(s.socketDir, s.socketName) +} + +func (s *server) Register(ctx context.Context, r *api.RegisterRequest) (*api.Empty, error) { + klog.InfoS("Got registration request from device plugin with resource", "resourceName", r.ResourceName) + metrics.DevicePluginRegistrationCount.WithLabelValues(r.ResourceName).Inc() + + if !s.isVersionCompatibleWithPlugin(r.Version) { + err := fmt.Errorf(errUnsupportedVersion, r.Version, api.SupportedVersions) + klog.InfoS("Bad registration request from device plugin with resource", "resourceName", r.ResourceName, "err", err) + return &api.Empty{}, err + } + + if !v1helper.IsExtendedResourceName(core.ResourceName(r.ResourceName)) { + err := fmt.Errorf(errInvalidResourceName, r.ResourceName) + klog.InfoS("Bad registration request from device plugin", "err", err) + return &api.Empty{}, err + } + + if err := s.connectClient(r.ResourceName, filepath.Join(s.socketDir, r.Endpoint)); err != nil { + klog.InfoS("Error connecting to device plugin client", "err", err) + return &api.Empty{}, err + } + + return &api.Empty{}, nil +} + +func (s *server) isVersionCompatibleWithPlugin(versions ...string) bool { + // TODO(vikasc): Currently this is fine as we only have a single supported version. When we do need to support + // multiple versions in the future, we may need to extend this function to return a supported version. + // E.g., say kubelet supports v1beta1 and v1beta2, and we get v1alpha1 and v1beta1 from a device plugin, + // this function should return v1beta1 + for _, version := range versions { + for _, supportedVersion := range api.SupportedVersions { + if version == supportedVersion { + return true + } + } + } + return false +} + +func (s *server) clientResources() []string { + s.mutex.Lock() + defer s.mutex.Unlock() + var resources []string + for r := range s.clients { + resources = append(resources, r) + } + return resources +} diff --git a/pkg/kubelet/cm/devicemanager/device_plugin_stub.go b/pkg/kubelet/cm/devicemanager/plugin/v1beta1/stub.go similarity index 99% rename from pkg/kubelet/cm/devicemanager/device_plugin_stub.go rename to pkg/kubelet/cm/devicemanager/plugin/v1beta1/stub.go index ef1ca4569e6..dee1a9414aa 100644 --- a/pkg/kubelet/cm/devicemanager/device_plugin_stub.go +++ b/pkg/kubelet/cm/devicemanager/plugin/v1beta1/stub.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package devicemanager +package v1beta1 import ( "context" diff --git a/pkg/kubelet/cm/devicemanager/types.go b/pkg/kubelet/cm/devicemanager/types.go index 7b17ac4619d..d508e8c9969 100644 --- a/pkg/kubelet/cm/devicemanager/types.go +++ b/pkg/kubelet/cm/devicemanager/types.go @@ -93,21 +93,9 @@ type DeviceRunContainerOptions struct { Annotations []kubecontainer.Annotation } -// TODO: evaluate whether we need these error definitions. +// TODO: evaluate whether we need this error definition. const ( - // errFailedToDialDevicePlugin is the error raised when the device plugin could not be - // reached on the registered socket - errFailedToDialDevicePlugin = "failed to dial device plugin:" - // errUnsupportedVersion is the error raised when the device plugin uses an API version not - // supported by the Kubelet registry - errUnsupportedVersion = "requested API version %q is not supported by kubelet. Supported version is %q" - // errInvalidResourceName is the error raised when a device plugin is registering - // itself with an invalid ResourceName - errInvalidResourceName = "the ResourceName %q is invalid" - // errEndpointStopped indicates that the endpoint has been stopped errEndpointStopped = "endpoint %v has been stopped" - // errBadSocket is the error raised when the registry socket path is not absolute - errBadSocket = "bad socketPath, must be an absolute path:" ) // endpointStopGracePeriod indicates the grace period after an endpoint is stopped diff --git a/test/images/sample-device-plugin/VERSION b/test/images/sample-device-plugin/VERSION index 7e32cd56983..c068b2447cc 100644 --- a/test/images/sample-device-plugin/VERSION +++ b/test/images/sample-device-plugin/VERSION @@ -1 +1 @@ -1.3 +1.4 diff --git a/test/images/sample-device-plugin/sampledeviceplugin.go b/test/images/sample-device-plugin/sampledeviceplugin.go index c3750488561..824b19788cf 100644 --- a/test/images/sample-device-plugin/sampledeviceplugin.go +++ b/test/images/sample-device-plugin/sampledeviceplugin.go @@ -24,7 +24,7 @@ import ( "k8s.io/klog/v2" pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" - dm "k8s.io/kubernetes/pkg/kubelet/cm/devicemanager" + plugin "k8s.io/kubernetes/pkg/kubelet/cm/devicemanager/plugin/v1beta1" ) const ( @@ -86,7 +86,7 @@ func main() { } socketPath := pluginSocksDir + "/dp." + fmt.Sprintf("%d", time.Now().Unix()) - dp1 := dm.NewDevicePluginStub(devs, socketPath, resourceName, false, false) + dp1 := plugin.NewDevicePluginStub(devs, socketPath, resourceName, false, false) if err := dp1.Start(); err != nil { panic(err)