diff --git a/staging/src/k8s.io/dynamic-resource-allocation/kubeletplugin/draplugin.go b/staging/src/k8s.io/dynamic-resource-allocation/kubeletplugin/draplugin.go index da434256def..89c6a70790b 100644 --- a/staging/src/k8s.io/dynamic-resource-allocation/kubeletplugin/draplugin.go +++ b/staging/src/k8s.io/dynamic-resource-allocation/kubeletplugin/draplugin.go @@ -135,10 +135,11 @@ func KubeletPluginSocketPath(path string) Option { } } -// GRPCInterceptor is called for each incoming gRPC method call. +// GRPCInterceptor is called for each incoming gRPC method call. This option +// may be used more than once and each interceptor will get called. func GRPCInterceptor(interceptor grpc.UnaryServerInterceptor) Option { return func(o *options) error { - o.interceptor = interceptor + o.interceptors = append(o.interceptors, interceptor) return nil } } @@ -150,7 +151,7 @@ type options struct { draEndpoint endpoint draAddress string pluginRegistrationEndpoint endpoint - interceptor grpc.UnaryServerInterceptor + interceptors []grpc.UnaryServerInterceptor } // draPlugin combines the kubelet registration service and the DRA node plugin @@ -190,7 +191,7 @@ func Start(nodeServer drapbv1.NodeServer, opts ...Option) (result DRAPlugin, fin } // Run the node plugin gRPC server first to ensure that it is ready. - plugin, err := startGRPCServer(klog.LoggerWithName(o.logger, "dra"), o.grpcVerbosity, o.interceptor, o.draEndpoint, func(grpcServer *grpc.Server) { + plugin, err := startGRPCServer(klog.LoggerWithName(o.logger, "dra"), o.grpcVerbosity, o.interceptors, o.draEndpoint, func(grpcServer *grpc.Server) { drapbv1.RegisterNodeServer(grpcServer, nodeServer) }) if err != nil { @@ -209,7 +210,7 @@ func Start(nodeServer drapbv1.NodeServer, opts ...Option) (result DRAPlugin, fin }() // Now make it available to kubelet. - registrar, err := startRegistrar(klog.LoggerWithName(o.logger, "registrar"), o.grpcVerbosity, o.interceptor, o.driverName, o.draAddress, o.pluginRegistrationEndpoint) + registrar, err := startRegistrar(klog.LoggerWithName(o.logger, "registrar"), o.grpcVerbosity, o.interceptors, o.driverName, o.draAddress, o.pluginRegistrationEndpoint) if err != nil { return nil, fmt.Errorf("start registrar: %v", err) } diff --git a/staging/src/k8s.io/dynamic-resource-allocation/kubeletplugin/noderegistrar.go b/staging/src/k8s.io/dynamic-resource-allocation/kubeletplugin/noderegistrar.go index 1f37db22fe9..f5148e4c9c1 100644 --- a/staging/src/k8s.io/dynamic-resource-allocation/kubeletplugin/noderegistrar.go +++ b/staging/src/k8s.io/dynamic-resource-allocation/kubeletplugin/noderegistrar.go @@ -31,7 +31,7 @@ type nodeRegistrar struct { } // startRegistrar returns a running instance. -func startRegistrar(logger klog.Logger, grpcVerbosity int, interceptor grpc.UnaryServerInterceptor, driverName string, endpoint string, pluginRegistrationEndpoint endpoint) (*nodeRegistrar, error) { +func startRegistrar(logger klog.Logger, grpcVerbosity int, interceptors []grpc.UnaryServerInterceptor, driverName string, endpoint string, pluginRegistrationEndpoint endpoint) (*nodeRegistrar, error) { n := &nodeRegistrar{ logger: logger, registrationServer: registrationServer{ @@ -40,7 +40,7 @@ func startRegistrar(logger klog.Logger, grpcVerbosity int, interceptor grpc.Unar supportedVersions: []string{"1.0.0"}, // TODO: is this correct? }, } - s, err := startGRPCServer(logger, grpcVerbosity, interceptor, pluginRegistrationEndpoint, func(grpcServer *grpc.Server) { + s, err := startGRPCServer(logger, grpcVerbosity, interceptors, pluginRegistrationEndpoint, func(grpcServer *grpc.Server) { registerapi.RegisterRegistrationServer(grpcServer, n) }) if err != nil { diff --git a/staging/src/k8s.io/dynamic-resource-allocation/kubeletplugin/nonblockinggrpcserver.go b/staging/src/k8s.io/dynamic-resource-allocation/kubeletplugin/nonblockinggrpcserver.go index 16df39710de..e6a835d9695 100644 --- a/staging/src/k8s.io/dynamic-resource-allocation/kubeletplugin/nonblockinggrpcserver.go +++ b/staging/src/k8s.io/dynamic-resource-allocation/kubeletplugin/nonblockinggrpcserver.go @@ -54,7 +54,7 @@ type endpoint struct { // startGRPCServer sets up the GRPC server on a Unix domain socket and spawns a goroutine // which handles requests for arbitrary services. -func startGRPCServer(logger klog.Logger, grpcVerbosity int, interceptor grpc.UnaryServerInterceptor, endpoint endpoint, services ...registerService) (*grpcServer, error) { +func startGRPCServer(logger klog.Logger, grpcVerbosity int, interceptors []grpc.UnaryServerInterceptor, endpoint endpoint, services ...registerService) (*grpcServer, error) { s := &grpcServer{ logger: logger, endpoint: endpoint, @@ -79,15 +79,13 @@ func startGRPCServer(logger klog.Logger, grpcVerbosity int, interceptor grpc.Una // Run a gRPC server. It will close the listening socket when // shutting down, so we don't need to do that. var opts []grpc.ServerOption - var interceptors []grpc.UnaryServerInterceptor + var finalInterceptors []grpc.UnaryServerInterceptor if grpcVerbosity >= 0 { - interceptors = append(interceptors, s.interceptor) + finalInterceptors = append(finalInterceptors, s.interceptor) } - if interceptor != nil { - interceptors = append(interceptors, interceptor) - } - if len(interceptors) >= 0 { - opts = append(opts, grpc.ChainUnaryInterceptor(interceptors...)) + finalInterceptors = append(finalInterceptors, interceptors...) + if len(finalInterceptors) >= 0 { + opts = append(opts, grpc.ChainUnaryInterceptor(finalInterceptors...)) } s.server = grpc.NewServer(opts...) for _, service := range services { diff --git a/test/e2e/dra/deploy.go b/test/e2e/dra/deploy.go index e115d1ffbd2..a7583462d67 100644 --- a/test/e2e/dra/deploy.go +++ b/test/e2e/dra/deploy.go @@ -240,14 +240,14 @@ func (d *Driver) SetUp(nodes *Nodes, resources app.Resources) { // Wait for registration. ginkgo.By("wait for plugin registration") - gomega.Eventually(func() []string { - var notRegistered []string + gomega.Eventually(func() map[string][]app.GRPCCall { + notRegistered := make(map[string][]app.GRPCCall) for nodename, plugin := range d.Nodes { - if !plugin.IsRegistered() { - notRegistered = append(notRegistered, nodename) + calls := plugin.GetGRPCCalls() + if contains, err := app.BeRegistered.Match(calls); err != nil || !contains { + notRegistered[nodename] = calls } } - sort.Strings(notRegistered) return notRegistered }).WithTimeout(time.Minute).Should(gomega.BeEmpty(), "hosts where the plugin has not been registered yet") } diff --git a/test/e2e/dra/test-driver/app/gomega.go b/test/e2e/dra/test-driver/app/gomega.go new file mode 100644 index 00000000000..4caa7805041 --- /dev/null +++ b/test/e2e/dra/test-driver/app/gomega.go @@ -0,0 +1,32 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package app + +import ( + "github.com/onsi/gomega/gcustom" +) + +// BeRegistered checks that plugin registration has completed. +var BeRegistered = gcustom.MakeMatcher(func(actualCalls []GRPCCall) (bool, error) { + for _, call := range actualCalls { + if call.FullMethod == "/pluginregistration.Registration/NotifyRegistrationStatus" && + call.Err == nil { + return true, nil + } + } + return false, nil +}).WithMessage("contain successful NotifyRegistrationStatus call") diff --git a/test/e2e/dra/test-driver/app/kubeletplugin.go b/test/e2e/dra/test-driver/app/kubeletplugin.go index 88065ae60c1..f899e8da404 100644 --- a/test/e2e/dra/test-driver/app/kubeletplugin.go +++ b/test/e2e/dra/test-driver/app/kubeletplugin.go @@ -24,6 +24,8 @@ import ( "path/filepath" "sync" + "google.golang.org/grpc" + "k8s.io/dynamic-resource-allocation/kubeletplugin" "k8s.io/klog/v2" drapbv1 "k8s.io/kubelet/pkg/apis/dra/v1alpha2" @@ -38,8 +40,23 @@ type ExamplePlugin struct { driverName string nodeName string - mutex sync.Mutex - prepared map[ClaimID]bool + mutex sync.Mutex + prepared map[ClaimID]bool + gRPCCalls []GRPCCall +} + +type GRPCCall struct { + // FullMethod is the fully qualified, e.g. /package.service/method. + FullMethod string + + // Request contains the parameters of the call. + Request interface{} + + // Response contains the reply of the plugin. It is nil for calls that are in progress. + Response interface{} + + // Err contains the error return value of the plugin. It is nil for calls that are in progress or succeeded. + Err error } // ClaimID contains both claim name and UID to simplify debugging. The @@ -94,6 +111,7 @@ func StartPlugin(logger klog.Logger, cdiDir, driverName string, nodeName string, opts = append(opts, kubeletplugin.Logger(logger), kubeletplugin.DriverName(driverName), + kubeletplugin.GRPCInterceptor(ex.recordGRPCCall), ) d, err := kubeletplugin.Start(ex, opts...) if err != nil { @@ -206,3 +224,35 @@ func (ex *ExamplePlugin) GetPreparedResources() []ClaimID { } return prepared } + +func (ex *ExamplePlugin) recordGRPCCall(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + call := GRPCCall{ + FullMethod: info.FullMethod, + Request: req, + } + ex.mutex.Lock() + ex.gRPCCalls = append(ex.gRPCCalls, call) + index := len(ex.gRPCCalls) - 1 + ex.mutex.Unlock() + + // We don't hold the mutex here to allow concurrent calls. + call.Response, call.Err = handler(ctx, req) + + ex.mutex.Lock() + ex.gRPCCalls[index] = call + ex.mutex.Unlock() + + return call.Response, call.Err +} + +func (ex *ExamplePlugin) GetGRPCCalls() []GRPCCall { + ex.mutex.Lock() + defer ex.mutex.Unlock() + + // We must return a new slice, otherwise adding new calls would become + // visible to the caller. We also need to copy the entries because + // they get mutated by recordGRPCCall. + calls := make([]GRPCCall, 0, len(ex.gRPCCalls)) + calls = append(calls, ex.gRPCCalls...) + return calls +}