diff --git a/pkg/kubelet/cm/dra/plugin/client.go b/pkg/kubelet/cm/dra/plugin/client.go index e3a1e96756c..bbf0b1e9235 100644 --- a/pkg/kubelet/cm/dra/plugin/client.go +++ b/pkg/kubelet/cm/dra/plugin/client.go @@ -24,18 +24,107 @@ import ( "google.golang.org/grpc" grpccodes "google.golang.org/grpc/codes" grpcstatus "google.golang.org/grpc/status" - "k8s.io/klog/v2" + "k8s.io/klog/v2" drapbv1alpha2 "k8s.io/kubelet/pkg/apis/dra/v1alpha2" drapb "k8s.io/kubelet/pkg/apis/dra/v1alpha3" ) const PluginClientTimeout = 45 * time.Second -// draPluginClient encapsulates all dra plugin methods. -type draPluginClient struct { - pluginName string - plugin *Plugin +type ( + nodeResourceManager interface { + Prepare(context.Context, *grpc.ClientConn, *plugin, *drapb.NodePrepareResourcesRequest) (*drapb.NodePrepareResourcesResponse, error) + Unprepare(context.Context, *grpc.ClientConn, *plugin, *drapb.NodeUnprepareResourcesRequest) (*drapb.NodeUnprepareResourcesResponse, error) + } + + v1alpha2NodeResourceManager struct{} + v1alpha3NodeResourceManager struct{} +) + +var nodeResourceManagers = map[string]nodeResourceManager{ + v1alpha2Version: v1alpha2NodeResourceManager{}, + v1alpha3Version: v1alpha3NodeResourceManager{}, +} + +func (v1alpha2rm v1alpha2NodeResourceManager) Prepare(ctx context.Context, conn *grpc.ClientConn, _ *plugin, req *drapb.NodePrepareResourcesRequest) (*drapb.NodePrepareResourcesResponse, error) { + nodeClient := drapbv1alpha2.NewNodeClient(conn) + response := &drapb.NodePrepareResourcesResponse{ + Claims: make(map[string]*drapb.NodePrepareResourceResponse), + } + + for _, claim := range req.Claims { + res, err := nodeClient.NodePrepareResource(ctx, + &drapbv1alpha2.NodePrepareResourceRequest{ + Namespace: claim.Namespace, + ClaimUid: claim.Uid, + ClaimName: claim.Name, + ResourceHandle: claim.ResourceHandle, + }) + result := &drapb.NodePrepareResourceResponse{} + if err != nil { + result.Error = err.Error() + } else { + result.CDIDevices = res.CdiDevices + } + response.Claims[claim.Uid] = result + } + + return response, nil +} + +func (v1alpha2rm v1alpha2NodeResourceManager) Unprepare(ctx context.Context, conn *grpc.ClientConn, _ *plugin, req *drapb.NodeUnprepareResourcesRequest) (*drapb.NodeUnprepareResourcesResponse, error) { + nodeClient := drapbv1alpha2.NewNodeClient(conn) + response := &drapb.NodeUnprepareResourcesResponse{ + Claims: make(map[string]*drapb.NodeUnprepareResourceResponse), + } + + for _, claim := range req.Claims { + _, err := nodeClient.NodeUnprepareResource(ctx, + &drapbv1alpha2.NodeUnprepareResourceRequest{ + Namespace: claim.Namespace, + ClaimUid: claim.Uid, + ClaimName: claim.Name, + ResourceHandle: claim.ResourceHandle, + }) + result := &drapb.NodeUnprepareResourceResponse{} + if err != nil { + result.Error = err.Error() + } + response.Claims[claim.Uid] = result + } + + return response, nil +} + +func (v1alpha3rm v1alpha3NodeResourceManager) Prepare(ctx context.Context, conn *grpc.ClientConn, p *plugin, req *drapb.NodePrepareResourcesRequest) (*drapb.NodePrepareResourcesResponse, error) { + nodeClient := drapb.NewNodeClient(conn) + response, err := nodeClient.NodePrepareResources(ctx, req) + if err != nil { + status, _ := grpcstatus.FromError(err) + if status.Code() == grpccodes.Unimplemented { + p.setVersion(v1alpha2Version) + return nodeResourceManagers[v1alpha2Version].Prepare(ctx, conn, p, req) + } + return nil, err + } + + return response, nil +} + +func (v1alpha3rm v1alpha3NodeResourceManager) Unprepare(ctx context.Context, conn *grpc.ClientConn, p *plugin, req *drapb.NodeUnprepareResourcesRequest) (*drapb.NodeUnprepareResourcesResponse, error) { + nodeClient := drapb.NewNodeClient(conn) + response, err := nodeClient.NodeUnprepareResources(ctx, req) + if err != nil { + status, _ := grpcstatus.FromError(err) + if status.Code() == grpccodes.Unimplemented { + p.setVersion(v1alpha2Version) + return nodeResourceManagers[v1alpha2Version].Unprepare(ctx, conn, p, req) + } + return nil, err + } + + return response, nil } func NewDRAPluginClient(pluginName string) (drapb.NodeClient, error) { @@ -43,111 +132,68 @@ func NewDRAPluginClient(pluginName string) (drapb.NodeClient, error) { return nil, fmt.Errorf("plugin name is empty") } - existingPlugin := draPlugins.Get(pluginName) + existingPlugin := draPlugins.get(pluginName) if existingPlugin == nil { return nil, fmt.Errorf("plugin name %s not found in the list of registered DRA plugins", pluginName) } - return &draPluginClient{ - pluginName: pluginName, - plugin: existingPlugin, - }, nil + return existingPlugin, nil } -func (r *draPluginClient) NodePrepareResources( +func (p *plugin) NodePrepareResources( ctx context.Context, req *drapb.NodePrepareResourcesRequest, opts ...grpc.CallOption, -) (resp *drapb.NodePrepareResourcesResponse, err error) { +) (*drapb.NodePrepareResourcesResponse, error) { logger := klog.FromContext(ctx) logger.V(4).Info(log("calling NodePrepareResources rpc"), "request", req) - defer logger.V(4).Info(log("done calling NodePrepareResources rpc"), "response", resp, "err", err) - conn, err := r.plugin.getOrCreateGRPCConn() + conn, err := p.getOrCreateGRPCConn() if err != nil { return nil, err } - nodeClient := drapb.NewNodeClient(conn) - nodeClientOld := drapbv1alpha2.NewNodeClient(conn) ctx, cancel := context.WithTimeout(ctx, PluginClientTimeout) defer cancel() - resp, err = nodeClient.NodePrepareResources(ctx, req) - if err != nil { - status, _ := grpcstatus.FromError(err) - if status.Code() == grpccodes.Unimplemented { - // Fall back to the older gRPC API. - resp = &drapb.NodePrepareResourcesResponse{ - Claims: make(map[string]*drapb.NodePrepareResourceResponse), - } - err = nil - for _, claim := range req.Claims { - respOld, errOld := nodeClientOld.NodePrepareResource(ctx, - &drapbv1alpha2.NodePrepareResourceRequest{ - Namespace: claim.Namespace, - ClaimUid: claim.Uid, - ClaimName: claim.Name, - ResourceHandle: claim.ResourceHandle, - }) - result := &drapb.NodePrepareResourceResponse{} - if errOld != nil { - result.Error = errOld.Error() - } else { - result.CDIDevices = respOld.CdiDevices - } - resp.Claims[claim.Uid] = result - } - } + version := p.getVersion() + resourceManager, exists := nodeResourceManagers[version] + if !exists { + err := fmt.Errorf("unsupported plugin version: %s", version) + logger.V(4).Info(log("done calling NodePrepareResources rpc"), "response", nil, "err", err) + return nil, err } - return + response, err := resourceManager.Prepare(ctx, conn, p, req) + logger.V(4).Info(log("done calling NodePrepareResources rpc"), "response", response, "err", err) + return response, err } -func (r *draPluginClient) NodeUnprepareResources( +func (p *plugin) NodeUnprepareResources( ctx context.Context, req *drapb.NodeUnprepareResourcesRequest, opts ...grpc.CallOption, -) (resp *drapb.NodeUnprepareResourcesResponse, err error) { +) (*drapb.NodeUnprepareResourcesResponse, error) { logger := klog.FromContext(ctx) logger.V(4).Info(log("calling NodeUnprepareResource rpc"), "request", req) - defer logger.V(4).Info(log("done calling NodeUnprepareResources rpc"), "response", resp, "err", err) - conn, err := r.plugin.getOrCreateGRPCConn() + conn, err := p.getOrCreateGRPCConn() if err != nil { return nil, err } - nodeClient := drapb.NewNodeClient(conn) - nodeClientOld := drapbv1alpha2.NewNodeClient(conn) ctx, cancel := context.WithTimeout(ctx, PluginClientTimeout) defer cancel() - resp, err = nodeClient.NodeUnprepareResources(ctx, req) - if err != nil { - status, _ := grpcstatus.FromError(err) - if status.Code() == grpccodes.Unimplemented { - // Fall back to the older gRPC API. - resp = &drapb.NodeUnprepareResourcesResponse{ - Claims: make(map[string]*drapb.NodeUnprepareResourceResponse), - } - err = nil - for _, claim := range req.Claims { - _, errOld := nodeClientOld.NodeUnprepareResource(ctx, - &drapbv1alpha2.NodeUnprepareResourceRequest{ - Namespace: claim.Namespace, - ClaimUid: claim.Uid, - ClaimName: claim.Name, - ResourceHandle: claim.ResourceHandle, - }) - result := &drapb.NodeUnprepareResourceResponse{} - if errOld != nil { - result.Error = errOld.Error() - } - resp.Claims[claim.Uid] = result - } - } + version := p.getVersion() + resourceManager, exists := nodeResourceManagers[version] + if !exists { + err := fmt.Errorf("unsupported plugin version: %s", version) + logger.V(4).Info(log("done calling NodeUnprepareResources rpc"), "response", nil, "err", err) + return nil, err } - return + response, err := resourceManager.Unprepare(ctx, conn, p, req) + logger.V(4).Info(log("done calling NodeUnprepareResources rpc"), "response", response, "err", err) + return response, err } diff --git a/pkg/kubelet/cm/dra/plugin/client_test.go b/pkg/kubelet/cm/dra/plugin/client_test.go index 18c37a1d1ee..3e1889e2c97 100644 --- a/pkg/kubelet/cm/dra/plugin/client_test.go +++ b/pkg/kubelet/cm/dra/plugin/client_test.go @@ -18,31 +18,46 @@ package plugin import ( "context" + "fmt" "net" "os" "path/filepath" "sync" "testing" + "github.com/stretchr/testify/assert" "google.golang.org/grpc" - drapbv1 "k8s.io/kubelet/pkg/apis/dra/v1alpha3" + drapbv1alpha2 "k8s.io/kubelet/pkg/apis/dra/v1alpha2" + drapbv1alpha3 "k8s.io/kubelet/pkg/apis/dra/v1alpha3" ) -type fakeGRPCServer struct { - drapbv1.UnimplementedNodeServer +type fakeV1alpha3GRPCServer struct { + drapbv1alpha3.UnimplementedNodeServer } -func (f *fakeGRPCServer) NodePrepareResource(ctx context.Context, in *drapbv1.NodePrepareResourcesRequest) (*drapbv1.NodePrepareResourcesResponse, error) { - return &drapbv1.NodePrepareResourcesResponse{Claims: map[string]*drapbv1.NodePrepareResourceResponse{"dummy": {CDIDevices: []string{"dummy"}}}}, nil +func (f *fakeV1alpha3GRPCServer) NodePrepareResource(ctx context.Context, in *drapbv1alpha3.NodePrepareResourcesRequest) (*drapbv1alpha3.NodePrepareResourcesResponse, error) { + return &drapbv1alpha3.NodePrepareResourcesResponse{Claims: map[string]*drapbv1alpha3.NodePrepareResourceResponse{"dummy": {CDIDevices: []string{"dummy"}}}}, nil } -func (f *fakeGRPCServer) NodeUnprepareResource(ctx context.Context, in *drapbv1.NodeUnprepareResourcesRequest) (*drapbv1.NodeUnprepareResourcesResponse, error) { - return &drapbv1.NodeUnprepareResourcesResponse{}, nil +func (f *fakeV1alpha3GRPCServer) NodeUnprepareResource(ctx context.Context, in *drapbv1alpha3.NodeUnprepareResourcesRequest) (*drapbv1alpha3.NodeUnprepareResourcesResponse, error) { + return &drapbv1alpha3.NodeUnprepareResourcesResponse{}, nil +} + +type fakeV1alpha2GRPCServer struct { + drapbv1alpha2.UnimplementedNodeServer +} + +func (f *fakeV1alpha2GRPCServer) NodePrepareResource(ctx context.Context, in *drapbv1alpha2.NodePrepareResourceRequest) (*drapbv1alpha2.NodePrepareResourceResponse, error) { + return &drapbv1alpha2.NodePrepareResourceResponse{CdiDevices: []string{"dummy"}}, nil +} + +func (f *fakeV1alpha2GRPCServer) NodeUnprepareResource(ctx context.Context, in *drapbv1alpha2.NodeUnprepareResourceRequest) (*drapbv1alpha2.NodeUnprepareResourceResponse, error) { + return &drapbv1alpha2.NodeUnprepareResourceResponse{}, nil } type tearDown func() -func setupFakeGRPCServer() (string, tearDown, error) { +func setupFakeGRPCServer(version string) (string, tearDown, error) { p, err := os.MkdirTemp("", "dra_plugin") if err != nil { return "", nil, err @@ -62,8 +77,16 @@ func setupFakeGRPCServer() (string, tearDown, error) { } s := grpc.NewServer() - fakeGRPCServer := &fakeGRPCServer{} - drapbv1.RegisterNodeServer(s, fakeGRPCServer) + switch version { + case v1alpha2Version: + fakeGRPCServer := &fakeV1alpha2GRPCServer{} + drapbv1alpha2.RegisterNodeServer(s, fakeGRPCServer) + case v1alpha3Version: + fakeGRPCServer := &fakeV1alpha3GRPCServer{} + drapbv1alpha3.RegisterNodeServer(s, fakeGRPCServer) + default: + return "", nil, fmt.Errorf("unsupported version: %s", version) + } go func() { go s.Serve(listener) @@ -75,7 +98,7 @@ func setupFakeGRPCServer() (string, tearDown, error) { } func TestGRPCConnIsReused(t *testing.T) { - addr, teardown, err := setupFakeGRPCServer() + addr, teardown, err := setupFakeGRPCServer(v1alpha3Version) if err != nil { t.Fatal(err) } @@ -85,11 +108,12 @@ func TestGRPCConnIsReused(t *testing.T) { wg := sync.WaitGroup{} m := sync.Mutex{} - plugin := &Plugin{ + p := &plugin{ endpoint: addr, + version: v1alpha3Version, } - conn, err := plugin.getOrCreateGRPCConn() + conn, err := p.getOrCreateGRPCConn() defer func() { err := conn.Close() if err != nil { @@ -101,7 +125,8 @@ func TestGRPCConnIsReused(t *testing.T) { } // ensure the plugin we are using is registered - draPlugins.Set("dummy-plugin", plugin) + draPlugins.add("dummy-plugin", p) + defer draPlugins.delete("dummy-plugin") // we call `NodePrepareResource` 2 times and check whether a new connection is created or the same is reused for i := 0; i < 2; i++ { @@ -114,8 +139,8 @@ func TestGRPCConnIsReused(t *testing.T) { return } - req := &drapbv1.NodePrepareResourcesRequest{ - Claims: []*drapbv1.Claim{ + req := &drapbv1alpha3.NodePrepareResourcesRequest{ + Claims: []*drapbv1alpha3.Claim{ { Namespace: "dummy-namespace", Uid: "dummy-uid", @@ -126,9 +151,9 @@ func TestGRPCConnIsReused(t *testing.T) { } client.NodePrepareResources(context.TODO(), req) - client.(*draPluginClient).plugin.Lock() - conn := client.(*draPluginClient).plugin.conn - client.(*draPluginClient).plugin.Unlock() + client.(*plugin).Lock() + conn := client.(*plugin).conn + client.(*plugin).Unlock() m.Lock() defer m.Unlock() @@ -144,6 +169,122 @@ func TestGRPCConnIsReused(t *testing.T) { if counter, ok := reusedConns[conn]; ok && counter != 2 { t.Errorf("expected counter to be 2 but got %d", counter) } - - draPlugins.Delete("dummy-plugin") +} + +func TestNewDRAPluginClient(t *testing.T) { + for _, test := range []struct { + description string + setup func(string) tearDown + pluginName string + shouldError bool + }{ + { + description: "plugin name is empty", + setup: func(_ string) tearDown { + return func() {} + }, + pluginName: "", + shouldError: true, + }, + { + description: "plugin name not found in the list", + setup: func(_ string) tearDown { + return func() {} + }, + pluginName: "plugin-name-not-found-in-the-list", + shouldError: true, + }, + { + description: "plugin exists", + setup: func(name string) tearDown { + draPlugins.add(name, &plugin{}) + return func() { + draPlugins.delete(name) + } + }, + pluginName: "dummy-plugin", + }, + } { + t.Run(test.description, func(t *testing.T) { + teardown := test.setup(test.pluginName) + defer teardown() + + client, err := NewDRAPluginClient(test.pluginName) + if test.shouldError { + assert.Nil(t, client) + assert.Error(t, err) + } else { + assert.NotNil(t, client) + assert.Nil(t, err) + } + }) + } +} + +func TestNodeUnprepareResource(t *testing.T) { + for _, test := range []struct { + description string + serverSetup func(string) (string, tearDown, error) + serverVersion string + request *drapbv1alpha3.NodeUnprepareResourcesRequest + }{ + { + description: "server supports v1alpha3", + serverSetup: setupFakeGRPCServer, + serverVersion: v1alpha3Version, + request: &drapbv1alpha3.NodeUnprepareResourcesRequest{}, + }, + { + description: "server supports v1alpha2, plugin client should fallback", + serverSetup: setupFakeGRPCServer, + serverVersion: v1alpha2Version, + request: &drapbv1alpha3.NodeUnprepareResourcesRequest{ + Claims: []*drapbv1alpha3.Claim{ + { + Namespace: "dummy-namespace", + Uid: "dummy-uid", + Name: "dummy-claim", + ResourceHandle: "dummy-resource", + }, + }, + }, + }, + } { + t.Run(test.description, func(t *testing.T) { + addr, teardown, err := setupFakeGRPCServer(test.serverVersion) + if err != nil { + t.Fatal(err) + } + defer teardown() + + p := &plugin{ + endpoint: addr, + version: v1alpha3Version, + } + + conn, err := p.getOrCreateGRPCConn() + defer func() { + err := conn.Close() + if err != nil { + t.Error(err) + } + }() + if err != nil { + t.Fatal(err) + } + + draPlugins.add("dummy-plugin", p) + defer draPlugins.delete("dummy-plugin") + + client, err := NewDRAPluginClient("dummy-plugin") + if err != nil { + t.Fatal(err) + } + + _, err = client.NodeUnprepareResources(context.TODO(), test.request) + if err != nil { + t.Fatal(err) + } + }) + } } diff --git a/pkg/kubelet/cm/dra/plugin/plugin.go b/pkg/kubelet/cm/dra/plugin/plugin.go index e9da2dd4ad7..94a9c7354de 100644 --- a/pkg/kubelet/cm/dra/plugin/plugin.go +++ b/pkg/kubelet/cm/dra/plugin/plugin.go @@ -17,22 +17,81 @@ limitations under the License. package plugin import ( + "context" "errors" "fmt" + "net" "strings" + "sync" + "time" + "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials/insecure" utilversion "k8s.io/apimachinery/pkg/util/version" "k8s.io/klog/v2" ) const ( // DRAPluginName is the name of the in-tree DRA Plugin. - DRAPluginName = "kubernetes.io/dra" + DRAPluginName = "kubernetes.io/dra" + v1alpha3Version = "v1alpha3" + v1alpha2Version = "v1alpha2" ) -// draPlugins map keeps track of all registered DRA plugins on the node -// and their corresponding sockets. -var draPlugins = &PluginsStore{} +// Plugin is a description of a DRA Plugin, defined by an endpoint +// and the highest DRA version supported. +type plugin struct { + sync.Mutex + conn *grpc.ClientConn + endpoint string + version string + highestSupportedVersion *utilversion.Version +} + +func (p *plugin) getOrCreateGRPCConn() (*grpc.ClientConn, error) { + p.Lock() + defer p.Unlock() + + if p.conn != nil { + return p.conn, nil + } + + network := "unix" + klog.V(4).InfoS(log("creating new gRPC connection"), "protocol", network, "endpoint", p.endpoint) + conn, err := grpc.Dial( + p.endpoint, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithContextDialer(func(ctx context.Context, target string) (net.Conn, error) { + return (&net.Dialer{}).DialContext(ctx, network, target) + }), + ) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + if ok := conn.WaitForStateChange(ctx, connectivity.Connecting); !ok { + return nil, errors.New("timed out waiting for gRPC connection to be ready") + } + + p.conn = conn + return p.conn, nil +} + +func (p *plugin) getVersion() string { + p.Lock() + defer p.Unlock() + return p.version +} + +func (p *plugin) setVersion(version string) { + p.Lock() + p.version = version + p.Unlock() +} // RegistrationHandler is the handler which is fed to the pluginwatcher API. type RegistrationHandler struct{} @@ -53,9 +112,11 @@ func (h *RegistrationHandler) RegisterPlugin(pluginName string, endpoint string, // Storing endpoint of newly registered DRA Plugin into the map, where plugin name will be the key // all other DRA components will be able to get the actual socket of DRA plugins by its name. - draPlugins.Set(pluginName, &Plugin{ + // By default we assume the supported plugin version is v1alpha3 + draPlugins.add(pluginName, &plugin{ conn: nil, endpoint: endpoint, + version: v1alpha3Version, highestSupportedVersion: highestSupportedVersion, }) @@ -91,32 +152,32 @@ func (h *RegistrationHandler) validateVersions( ) } - existingPlugin := draPlugins.Get(pluginName) - if existingPlugin != nil { - if !existingPlugin.highestSupportedVersion.LessThan(newPluginHighestVersion) { - return nil, errors.New( - log( - "%s for DRA plugin %q failed. Another plugin with the same name is already registered with a higher supported version: %q", - callerName, - pluginName, - existingPlugin.highestSupportedVersion, - ), - ) - } + existingPlugin := draPlugins.get(pluginName) + if existingPlugin == nil { + return newPluginHighestVersion, nil } - - return newPluginHighestVersion, nil + if existingPlugin.highestSupportedVersion.LessThan(newPluginHighestVersion) { + return newPluginHighestVersion, nil + } + return nil, errors.New( + log( + "%s for DRA plugin %q failed. Another plugin with the same name is already registered with a higher supported version: %q", + callerName, + pluginName, + existingPlugin.highestSupportedVersion, + ), + ) } -func unregisterPlugin(pluginName string) { - draPlugins.Delete(pluginName) +func deregisterPlugin(pluginName string) { + draPlugins.delete(pluginName) } // DeRegisterPlugin is called when a plugin has removed its socket, // signaling it is no longer available. func (h *RegistrationHandler) DeRegisterPlugin(pluginName string) { klog.InfoS("DeRegister DRA plugin", "name", pluginName) - unregisterPlugin(pluginName) + deregisterPlugin(pluginName) } // ValidatePlugin is called by kubelet's plugin watcher upon detection diff --git a/pkg/kubelet/cm/dra/plugin/plugin_test.go b/pkg/kubelet/cm/dra/plugin/plugin_test.go new file mode 100644 index 00000000000..70499b260c8 --- /dev/null +++ b/pkg/kubelet/cm/dra/plugin/plugin_test.go @@ -0,0 +1,81 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package plugin + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRegistrationHandler_ValidatePlugin(t *testing.T) { + for _, test := range []struct { + description string + handler func() *RegistrationHandler + pluginName string + endpoint string + versions []string + shouldError bool + }{ + { + description: "no versions provided", + handler: NewRegistrationHandler, + shouldError: true, + }, + { + description: "unsupported version", + handler: NewRegistrationHandler, + versions: []string{"v2.0.0"}, + shouldError: true, + }, + { + description: "plugin already registered with a higher supported version", + handler: func() *RegistrationHandler { + handler := NewRegistrationHandler() + if err := handler.RegisterPlugin("this-plugin-already-exists-and-has-a-long-name-so-it-doesnt-collide", "", []string{"v1.1.0"}); err != nil { + t.Fatal(err) + } + return handler + }, + pluginName: "this-plugin-already-exists-and-has-a-long-name-so-it-doesnt-collide", + versions: []string{"v1.0.0"}, + shouldError: true, + }, + { + description: "should validate the plugin", + handler: NewRegistrationHandler, + pluginName: "this-is-a-dummy-plugin-with-a-long-name-so-it-doesnt-collide", + versions: []string{"v1.3.0"}, + }, + } { + t.Run(test.description, func(t *testing.T) { + handler := test.handler() + err := handler.ValidatePlugin(test.pluginName, test.endpoint, test.versions) + if test.shouldError { + assert.Error(t, err) + } else { + assert.Nil(t, err) + } + }) + } + + t.Cleanup(func() { + handler := NewRegistrationHandler() + handler.DeRegisterPlugin("this-plugin-already-exists-and-has-a-long-name-so-it-doesnt-collide") + handler.DeRegisterPlugin("this-is-a-dummy-plugin-with-a-long-name-so-it-doesnt-collide") + }) +} diff --git a/pkg/kubelet/cm/dra/plugin/plugins_store.go b/pkg/kubelet/cm/dra/plugin/plugins_store.go index 32f750af80d..aa1449e5913 100644 --- a/pkg/kubelet/cm/dra/plugin/plugins_store.go +++ b/pkg/kubelet/cm/dra/plugin/plugins_store.go @@ -17,69 +17,24 @@ limitations under the License. package plugin import ( - "context" - "errors" - "net" "sync" - "time" - "google.golang.org/grpc" - "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/credentials/insecure" - utilversion "k8s.io/apimachinery/pkg/util/version" "k8s.io/klog/v2" ) -// Plugin is a description of a DRA Plugin, defined by an endpoint -// and the highest DRA version supported. -type Plugin struct { - sync.RWMutex - conn *grpc.ClientConn - endpoint string - highestSupportedVersion *utilversion.Version -} - -func (p *Plugin) getOrCreateGRPCConn() (*grpc.ClientConn, error) { - p.Lock() - defer p.Unlock() - - if p.conn != nil { - return p.conn, nil - } - - network := "unix" - klog.V(4).InfoS(log("creating new gRPC connection"), "protocol", network, "endpoint", p.endpoint) - conn, err := grpc.Dial( - p.endpoint, - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithContextDialer(func(ctx context.Context, target string) (net.Conn, error) { - return (&net.Dialer{}).DialContext(ctx, network, target) - }), - ) - if err != nil { - return nil, err - } - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - if ok := conn.WaitForStateChange(ctx, connectivity.Connecting); !ok { - return nil, errors.New("timed out waiting for gRPC connection to be ready") - } - - p.conn = conn - return p.conn, nil -} - // PluginsStore holds a list of DRA Plugins. -type PluginsStore struct { +type pluginsStore struct { sync.RWMutex - store map[string]*Plugin + store map[string]*plugin } +// draPlugins map keeps track of all registered DRA plugins on the node +// and their corresponding sockets. +var draPlugins = &pluginsStore{} + // Get lets you retrieve a DRA Plugin by name. // This method is protected by a mutex. -func (s *PluginsStore) Get(pluginName string) *Plugin { +func (s *pluginsStore) get(pluginName string) *plugin { s.RLock() defer s.RUnlock() @@ -88,31 +43,26 @@ func (s *PluginsStore) Get(pluginName string) *Plugin { // Set lets you save a DRA Plugin to the list and give it a specific name. // This method is protected by a mutex. -func (s *PluginsStore) Set(pluginName string, plugin *Plugin) { +func (s *pluginsStore) add(pluginName string, p *plugin) { s.Lock() defer s.Unlock() if s.store == nil { - s.store = make(map[string]*Plugin) + s.store = make(map[string]*plugin) } - s.store[pluginName] = plugin + _, exists := s.store[pluginName] + if exists { + klog.V(1).InfoS(log("plugin: %s already registered, previous plugin will be overridden", pluginName)) + } + s.store[pluginName] = p } // Delete lets you delete a DRA Plugin by name. // This method is protected by a mutex. -func (s *PluginsStore) Delete(pluginName string) { +func (s *pluginsStore) delete(pluginName string) { s.Lock() defer s.Unlock() delete(s.store, pluginName) } - -// Clear deletes all entries in the store. -// This methiod is protected by a mutex. -func (s *PluginsStore) Clear() { - s.Lock() - defer s.Unlock() - - s.store = make(map[string]*Plugin) -}